Skip to content

Commit

Permalink
follow-up for cudagraphs changes (#1009)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Aug 22, 2024
1 parent 6e1cb40 commit 7daa781
Showing 1 changed file with 43 additions and 19 deletions.
62 changes: 43 additions & 19 deletions thunder/transforms/cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,31 @@ def extract_descriptor(arg):


class CUDAGraphRunner:
"""A class to facilitate creating and running cudagraphs for a CUDAGraphTransform.
Key methods:
.make_cuda_graph_callable_from_symbols
the entry point from the CUDAGraphTransform that returns a callable
(mapping to .call_cuda_graph) for a given series of bound symbols and
inputs and outputs.
.call_cuda_graph
runs (and builds on cache miss) a cuda graph. This is (via the callable
from .make_cuda_graph_callable_from_symbols) the entry point during
execution
There are two cache/information layers, one mapping to the callables via the name
of the fusion in the trace (.python_callables acts as a cache, .trace_symbols is just for
inspection). There is a separate cuda_graph_cache as there could be reaons to
generate multiple graphs for inputs (e.g. changes in strides for inputs), this
is .cuda_graph_cache.
Note that these are good for inspection but are considered internals and might
change.
"""

def __init__(self):
self.cuda_graph_cache = {}
self.python_callables = {} # fn_name -> (callable. num_static_inputs)
self.cuda_graph_cache = {} # cahce_key (.make_cache_key) -> (graph, static_inputs, static_outputs)
self.python_callables = {} # fn_name -> (callable. static_input_mask (or None))
self.trace_symbols = {} # fn_name -> (bsyms, inputs, outputs)
self.name_counter = 1

Expand Down Expand Up @@ -92,7 +114,7 @@ def make_cache_key(self, fn_name, *args):
return (fn_name, self.make_static_inputs_mask(fn_name, *args), to_arg_descriptor(*args))

def call_cuda_graph(self, fn_name, *args):
fn, num_static_inputs = self.python_callables[fn_name]
fn, _ = self.python_callables[fn_name]

cache_key = self.make_cache_key(fn_name, *args)

Expand Down Expand Up @@ -142,34 +164,36 @@ def region_fn():

def make_cuda_graph_callable_from_symbols(
self,
fn_name: str,
bsyms: list[BoundSymbol],
inputs: list[Proxy],
outputs: list[Proxy],
num_static_inputs: int | None = None,
static_inputs_mask: Sequence[bool] | None = None,
) -> Callable:
# previously, one could pass the number of static inputs to get an automatic static_inputs_mask,
# but chances are that the transform could have more detailed information, so we take a mask
# static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs)

if num_static_inputs is not None:
static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs)
if static_inputs_mask is not None:
static_inputs_mask = tuple(static_inputs_mask) # ensure hashability
else:
static_inputs_mask = None

x_fn_name = f"{fn_name}_{self.name_counter}"
fn_name = f"CUDAGraph{self.name_counter}"
self.name_counter += 1

self.python_callables[x_fn_name] = (
self.python_callables[fn_name] = (
self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs),
static_inputs_mask,
)
self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs)
self.trace_symbols[fn_name] = (bsyms, inputs, outputs)

def callable(*args):
return self.call_cuda_graph(x_fn_name, *args)
return self.call_cuda_graph(fn_name, *args)

callable.__name__ = fn_name
callable.__qualname__ = fn_name
callable.__name__ = f"{fn_name}_fn"
callable.__qualname__ = f"{fn_name}_fn"

return callable
return callable, fn_name


class CUDAGraphTransform(Transform):
Expand All @@ -184,18 +208,18 @@ def __init__(self):
super().__init__()
self.cuda_graph_runner = CUDAGraphRunner()

def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | int = None) -> BoundSymbol:
def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol:
inputs = [unvariableify(inp) for inp in region.inputs]
outputs = [unvariableify(out) for out in region.outputs]

from thunder.executors.passes import _del_last_used

region.bound_symbols = _del_last_used(region.bound_symbols, outputs)

fusion_name = f"CUDAGraph{fusion_counter}"

fusion_callable = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols(
f"{fusion_name}_fn", region.bound_symbols, inputs, outputs, num_static_inputs
fusion_callable, fusion_name = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols(
region.bound_symbols,
inputs,
outputs,
)

fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self)
Expand Down

0 comments on commit 7daa781

Please sign in to comment.