Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

follow-up for cudagraphs changes #1009

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions thunder/transforms/cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,31 @@ def extract_descriptor(arg):


class CUDAGraphRunner:
"""A class to facilitate creating and running cudagraphs for a CUDAGraphTransform.

Key methods:
.make_cuda_graph_callable_from_symbols
the entry point from the CUDAGraphTransform that returns a callable
(mapping to .call_cuda_graph) for a given series of bound symbols and
inputs and outputs.
.call_cuda_graph
runs (and builds on cache miss) a cuda graph. This is (via the callable
from .make_cuda_graph_callable_from_symbols) the entry point during
execution

There are two cache/information layers: one maps the name of the fusion in the
trace to the callables (.python_callables acts as a cache, .trace_symbols is just for
inspection). There is a separate cache, .cuda_graph_cache, because there can be reasons
to generate multiple graphs for the same callable (e.g. changes in the strides of the
inputs).

Note that these are good for inspection but are considered internals and might
change.
"""

def __init__(self):
self.cuda_graph_cache = {}
self.python_callables = {} # fn_name -> (callable. num_static_inputs)
self.cuda_graph_cache = {}  # cache_key (.make_cache_key) -> (graph, static_inputs, static_outputs)
self.python_callables = {}  # fn_name -> (callable, static_input_mask (or None))
self.trace_symbols = {} # fn_name -> (bsyms, inputs, outputs)
self.name_counter = 1

Expand Down Expand Up @@ -92,7 +114,7 @@ def make_cache_key(self, fn_name, *args):
return (fn_name, self.make_static_inputs_mask(fn_name, *args), to_arg_descriptor(*args))

def call_cuda_graph(self, fn_name, *args):
fn, num_static_inputs = self.python_callables[fn_name]
fn, _ = self.python_callables[fn_name]

cache_key = self.make_cache_key(fn_name, *args)

Expand Down Expand Up @@ -142,34 +164,36 @@ def region_fn():

def make_cuda_graph_callable_from_symbols(
self,
fn_name: str,
bsyms: list[BoundSymbol],
inputs: list[Proxy],
outputs: list[Proxy],
num_static_inputs: int | None = None,
static_inputs_mask: Sequence[bool] | None = None,
) -> Callable:
# previously, one could pass the number of static inputs to get an automatic static_inputs_mask,
# but chances are that the transform could have more detailed information, so we take a mask
# static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs)

if num_static_inputs is not None:
static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs)
if static_inputs_mask is not None:
static_inputs_mask = tuple(static_inputs_mask) # ensure hashability
else:
static_inputs_mask = None

x_fn_name = f"{fn_name}_{self.name_counter}"
fn_name = f"CUDAGraph{self.name_counter}"
self.name_counter += 1

self.python_callables[x_fn_name] = (
self.python_callables[fn_name] = (
self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs),
static_inputs_mask,
)
self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs)
self.trace_symbols[fn_name] = (bsyms, inputs, outputs)

def callable(*args):
return self.call_cuda_graph(x_fn_name, *args)
return self.call_cuda_graph(fn_name, *args)

callable.__name__ = fn_name
callable.__qualname__ = fn_name
callable.__name__ = f"{fn_name}_fn"
callable.__qualname__ = f"{fn_name}_fn"

return callable
return callable, fn_name


class CUDAGraphTransform(Transform):
Expand All @@ -184,18 +208,18 @@ def __init__(self):
super().__init__()
self.cuda_graph_runner = CUDAGraphRunner()

def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | int = None) -> BoundSymbol:
def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol:
inputs = [unvariableify(inp) for inp in region.inputs]
outputs = [unvariableify(out) for out in region.outputs]

from thunder.executors.passes import _del_last_used

region.bound_symbols = _del_last_used(region.bound_symbols, outputs)

fusion_name = f"CUDAGraph{fusion_counter}"

fusion_callable = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols(
f"{fusion_name}_fn", region.bound_symbols, inputs, outputs, num_static_inputs
fusion_callable, fusion_name = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols(
region.bound_symbols,
inputs,
outputs,
)

fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self)
Expand Down
Loading