From 240412ab29e146d994c862cf927062dbe0219249 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 21 Aug 2024 12:40:42 +0200 Subject: [PATCH] follow-up for cudagraphs changes --- thunder/transforms/cudagraph.py | 62 +++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/thunder/transforms/cudagraph.py b/thunder/transforms/cudagraph.py index 07d19d9ae7..f22a5874ab 100644 --- a/thunder/transforms/cudagraph.py +++ b/thunder/transforms/cudagraph.py @@ -38,9 +38,31 @@ def extract_descriptor(arg): class CUDAGraphRunner: + """A class to facilitate creating and running cudagraphs for a CUDAGraphTransform. + + Key methods: + .make_cuda_graph_callable_from_symbols + the entry point from the CUDAGraphTransform that returns a callable + (mapping to .call_cuda_graph) for a given series of bound symbols and + inputs and outputs. + .call_cuda_graph + runs (and builds on cache miss) a cuda graph. This is (via the callable + from .make_cuda_graph_callable_from_symbols) the entry point during + execution + + There are two cache/information layers, one mapping to the callables via the name + of the fusion in the trace (.python_callables acts as a cache, .trace_symbols is just for + inspection). There is a separate cuda_graph_cache as there could be reaons to + generate multiple graphs for inputs (e.g. changes in strides for inputs), this + is .cuda_graph_cache. + + Note that these are good for inspection but are considered internals and might + change. + """ + def __init__(self): - self.cuda_graph_cache = {} - self.python_callables = {} # fn_name -> (callable. num_static_inputs) + self.cuda_graph_cache = {} # cahce_key (.make_cache_key) -> (graph, static_inputs, static_outputs) + self.python_callables = {} # fn_name -> (callable. static_input_mask (or None)) self.trace_symbols = {} # fn_name -> (bsyms, inputs, outputs) self.name_counter = 1 @@ -92,7 +114,7 @@ def make_cache_key(self, fn_name, *args): return (fn_name, self.make_static_inputs_mask(fn_name, *args), to_arg_descriptor(*args)) def call_cuda_graph(self, fn_name, *args): - fn, num_static_inputs = self.python_callables[fn_name] + fn, _ = self.python_callables[fn_name] cache_key = self.make_cache_key(fn_name, *args) @@ -142,34 +164,36 @@ def region_fn(): def make_cuda_graph_callable_from_symbols( self, - fn_name: str, bsyms: list[BoundSymbol], inputs: list[Proxy], outputs: list[Proxy], - num_static_inputs: int | None = None, + static_inputs_mask: Sequence[bool] | None = None, ) -> Callable: + # previously, one could pass the number of static inputs to get an automatic static_inputs_mask, + # but chances are that the transform could have more detailed information, so we take a mask + # static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs) - if num_static_inputs is not None: - static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs) + if static_inputs_mask is not None: + static_inputs_mask = tuple(static_inputs_mask) # ensure hashability else: static_inputs_mask = None - x_fn_name = f"{fn_name}_{self.name_counter}" + fn_name = f"CUDAGraph{self.name_counter}" self.name_counter += 1 - self.python_callables[x_fn_name] = ( + self.python_callables[fn_name] = ( self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs), static_inputs_mask, ) - self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs) + self.trace_symbols[fn_name] = (bsyms, inputs, outputs) def callable(*args): - return self.call_cuda_graph(x_fn_name, *args) + return self.call_cuda_graph(fn_name, *args) - callable.__name__ = fn_name - callable.__qualname__ = fn_name + callable.__name__ = f"{fn_name}_fn" + callable.__qualname__ = f"{fn_name}_fn" - return callable + return callable, fn_name class CUDAGraphTransform(Transform): @@ -184,7 +208,7 @@ def __init__(self): super().__init__() self.cuda_graph_runner = CUDAGraphRunner() - def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | int = None) -> BoundSymbol: + def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol: inputs = [unvariableify(inp) for inp in region.inputs] outputs = [unvariableify(out) for out in region.outputs] @@ -192,10 +216,10 @@ def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | in region.bound_symbols = _del_last_used(region.bound_symbols, outputs) - fusion_name = f"CUDAGraph{fusion_counter}" - - fusion_callable = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols( - f"{fusion_name}_fn", region.bound_symbols, inputs, outputs, num_static_inputs + fusion_callable, fusion_name = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols( + region.bound_symbols, + inputs, + outputs, ) fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self)