diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py
index 1caa41e4f5..5d49d26ef2 100644
--- a/thunder/tests/test_networks.py
+++ b/thunder/tests/test_networks.py
@@ -90,14 +90,11 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
     config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=4, n_head=6, n_embd=768)
     gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype).requires_grad_(False).eval()
 
-    from thunder.transforms.cudagraph import CUDAGraphTransform, build_cuda_graph
+    from thunder.transforms.cudagraph import CUDAGraphTransform
 
     cgtransform = CUDAGraphTransform()
     tom = executor.make_callable(gpt, transforms=[cgtransform], disable_torch_autograd=True)
 
-    # Cache stats before test runs
-    build_graph_stats_old = build_cuda_graph.cache_info()
-
     for _ in range(2):
         idx = make((4, 64), dtype=torch.int64, low=0, high=255)
         torch_result = gpt(idx)
@@ -105,20 +102,12 @@ def test_nanogpt_complete_cudagraphs(executor, device, dtype):
         thunder_result = tom(idx)
         assert_close(torch_result, thunder_result)
 
-    # Cache stats after test runs
-    build_graph_stats_new = build_cuda_graph.cache_info()
     # We ran only a single (forward) graph several times.
-    # Test that at most 1 cache miss happened after the runs.
-    assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 1
+    # Test that only 1 cache entry was created during the runs.
+    assert len(cgtransform.cuda_graph_runner.cuda_graph_cache) == 1
 
-    # Check we really use CUDA graphs {
+    # Check we really use CUDA graphs
     assert _there_is_cudagraph_sym(thunder.last_traces(tom)[-1])
-    # }
-
-    # Let's clear cache if run only in tests
-    # TODO: merge with the cache of the thunder.jit callable
-    if build_graph_stats_old.misses == 0:
-        build_cuda_graph.cache_clear()
 
 
 @instantiate(dtypes=(thunder.float32,), devicetypes=(thunder.devices.DeviceType.CUDA,))
@@ -131,15 +120,11 @@ def test_nanogpt_complete_cudagraphs_autograd(executor, device, dtype):
     config = nanogpt_model.GPTConfig(dropout=0, block_size=512, n_layer=6, n_head=6, n_embd=768)
     gpt = nanogpt_model.GPT(config).to(device=device, dtype=tdtype)
 
-    from thunder.transforms.cudagraph import CUDAGraphTransform, build_cuda_graph
+    from thunder.transforms.cudagraph import CUDAGraphTransform
 
     cgtransform = CUDAGraphTransform()
     cmodel = executor.make_callable(gpt, transforms=[cgtransform])
 
-    # Checking graph cache stats
-    # Cache stats before test runs
-    build_graph_stats_old = build_cuda_graph.cache_info()
-
     # Multiple runs to test whether static buffers are properly updated
     for i in range(3):
         x = make_tensor((4, 64), dtype=torch.int64, low=0, high=255, device=device)
@@ -154,22 +139,14 @@ def test_nanogpt_complete_cudagraphs_autograd(executor, device, dtype):
         assert_close(torch_result, thunder_result)
         assert_close(torch_grads, thunder_grads)
 
-    # Cache stats after test runs
-    build_graph_stats_new = build_cuda_graph.cache_info()
     # We ran only at most two (forward and backward) graphs several times.
     # Test that at most 2 cache misses happened after the runs
     # (at most one per each graph)
-    assert (build_graph_stats_new.misses - build_graph_stats_old.misses) <= 2
+    assert len(cgtransform.cuda_graph_runner.cuda_graph_cache) == 2
 
-    # Check we really run CUDAGraphExecutor {
+    # Check we have CUDAGraph symbols in forward and backward
     assert _there_is_cudagraph_sym(thunder.last_traces(cmodel)[-1])
     assert _there_is_cudagraph_sym(thunder.last_backward_traces(cmodel)[-1])
-    # }
-
-    # Let's clear cache if run only in tests
-    # TODO: merge with the cache of the thunder.jit callable
-    if build_graph_stats_old.misses == 0:
-        build_cuda_graph.cache_clear()
 
 
 @instantiate(dtypes=(thunder.float32,), executors=all_test_executors_and_dynamo)
diff --git a/thunder/transforms/cudagraph.py b/thunder/transforms/cudagraph.py
index f2382aef32..83afa310b6 100644
--- a/thunder/transforms/cudagraph.py
+++ b/thunder/transforms/cudagraph.py
@@ -24,7 +24,6 @@ class ArgsDescriptor:
     sizes: tuple
     strides: tuple
    non_tensor_args: tuple
-    args: list = field(hash=False, repr=False, compare=False)
 
 
 def to_arg_descriptor(*args):
@@ -35,63 +34,75 @@ def extract_descriptor(arg):
             return type(arg), None, None, arg
 
     dtypes, sizes, strides, non_tensor_args = zip(*map(extract_descriptor, args))
-    return ArgsDescriptor(dtypes, sizes, strides, non_tensor_args, args)
+    return ArgsDescriptor(dtypes, sizes, strides, non_tensor_args)
 
 
-@lru_cache
-def build_cuda_graph(
-    fn: Callable, args_descriptor: ArgsDescriptor, static_args_mask: tuple[bool, ...]
-) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:
+class CUDAGraphRunner:
+    def __init__(self):
+        self.cuda_graph_cache = {}
+        self.python_callables = {}  # fn_name -> (callable, static_inputs_mask)
+        self.trace_symbols = {}  # fn_name -> (bsyms, inputs, outputs)
+        self.name_counter = 1
 
-    def get_static_buffer(x):
+    def get_static_buffer(self, x):
         if isinstance(x, torch.Tensor):
             return torch.empty_like(x).copy_(x)
         return x
 
-    args = args_descriptor.args
-
-    # Warmup
-    torch.cuda.synchronize()
-    stream = torch.cuda.Stream()
-    stream.wait_stream(torch.cuda.current_stream())
-
-    with torch.cuda.stream(stream):
-        static_inputs = tuple(
-            get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask)
-        )
-        for _ in range(3):
-            fn(*static_inputs)
-
-    stream.synchronize()
-    torch.cuda.current_stream().wait_stream(stream)
-    torch.cuda.synchronize()
-
-    # Record
-    # NOTE: We are using default private pool here, but it is possibly better to
-    # use a custom pool for better memory management. See CUDA Graphs Tree in
-    # PyTorch's Inductor: torch/_inductor/cudagraph_trees.py
-    # Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/view
-    graph = torch.cuda.CUDAGraph()
-    with torch.cuda.graph(graph, stream=stream):
-        static_outputs = fn(*static_inputs)
-
-    return graph, static_inputs, static_outputs
+    def build_cuda_graph(
+        self, fn: Callable, args: list[Any], static_args_mask: tuple[bool, ...]
+    ) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:
+
+        # Warmup
+        torch.cuda.synchronize()
+        stream = torch.cuda.Stream()
+        stream.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(stream):
+            static_inputs = tuple(
+                self.get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask)
+            )
+            for _ in range(3):
+                fn(*static_inputs)
+
+        stream.synchronize()
+        torch.cuda.current_stream().wait_stream(stream)
+        torch.cuda.synchronize()
+
+        # Record
+        # NOTE: We are using default private pool here, but it is possibly better to
+        # use a custom pool for better memory management. See CUDA Graphs Tree in
+        # PyTorch's Inductor: torch/_inductor/cudagraph_trees.py
+        # Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/view
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph, stream=stream):
+            static_outputs = fn(*static_inputs)
+
+        return graph, static_inputs, static_outputs
+
+    def make_static_inputs_mask(self, fn_name, *args):
+        _, static_inputs_mask = self.python_callables[fn_name]
+        if static_inputs_mask is None:
+            static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)
+        return static_inputs_mask
 
+    def make_cache_key(self, fn_name, *args):
+        # if args_descriptor included torch.nn.Parameter-ness, we could
+        # use the static_inputs_mask or None as the key.
+        return (fn_name, self.make_static_inputs_mask(fn_name, *args), to_arg_descriptor(*args))
 
-class CUDAGraphCallable:
-    def __init__(self, fn: Callable, num_static_inputs: None | int = None):
-        self.fn = fn
-        self.num_static_inputs = num_static_inputs
+    def call_cuda_graph(self, fn_name, *args):
+        fn, _ = self.python_callables[fn_name]
 
-    def __call__(self, *args):
-        if self.num_static_inputs is not None:
-            static_inputs_mask = (True,) * self.num_static_inputs + (False,) * (len(args) - self.num_static_inputs)
-        else:
-            static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)
+        cache_key = self.make_cache_key(fn_name, *args)
 
-        args_descriptor = to_arg_descriptor(*args)
+        cache_entry = self.cuda_graph_cache.get(cache_key)
+        if cache_entry is None:
+            static_inputs_mask = self.make_static_inputs_mask(fn_name, *args)
+            cache_entry = self.build_cuda_graph(fn, args, static_inputs_mask)
+            self.cuda_graph_cache[cache_key] = cache_entry
 
-        graph, static_inputs, static_outputs = build_cuda_graph(self.fn, args_descriptor, static_inputs_mask)
+        graph, static_inputs, static_outputs = cache_entry
 
         for static_input, arg in utils.safe_zip(static_inputs, args):
             if id(static_input) != id(arg) and isinstance(static_input, torch.Tensor) and isinstance(arg, torch.Tensor):
@@ -100,44 +111,79 @@ def __call__(self, *args):
         graph.replay()
         return static_outputs
 
+    def make_python_callable_from_symbols(
+        self,
+        fn_name: str,
+        bsyms: list[BoundSymbol],
+        inputs: list[Proxy],
+        outputs: list[Proxy],
+    ) -> Callable:
 
-def make_callable(
-    fn_name: str,
-    bsyms: list[BoundSymbol],
-    inputs: list[Proxy],
-    outputs: list[Proxy],
-) -> Callable:
-    from inspect import Parameter, Signature
+        from inspect import Parameter, Signature
 
-    region_fn_params = (
-        Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY) for i, param in enumerate(inputs)
-    )
+        region_fn_params = (
+            Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY) for i, param in enumerate(inputs)
+        )
 
-    region_fn_signature = Signature(region_fn_params)
+        region_fn_signature = Signature(region_fn_params)
 
-    def region_fn():
-        pass
+        def region_fn():
+            pass
 
-    region_fn.__signature__ = region_fn_signature
-    region_fn.__name__ = fn_name
+        region_fn.__signature__ = region_fn_signature
+        region_fn.__name__ = fn_name
 
-    region_trace = TraceCtx(region_fn)
-    region_trace.bound_symbols = bsyms
-    region_trace.args = inputs
-    region_trace.kwargs = {}
-    region_trace.bound_symbols.append(prims.python_return.bind(outputs, output=()))
+        region_trace = TraceCtx(region_fn)
+        region_trace.bound_symbols = bsyms
+        region_trace.args = inputs
+        region_trace.kwargs = {}
+        region_trace.bound_symbols.append(prims.python_return.bind(outputs, output=()))
+        return region_trace.python_callable()
 
-    return region_trace.python_callable()
+    def make_cuda_graph_callable_from_symbols(
+        self,
+        fn_name: str,
+        bsyms: list[BoundSymbol],
+        inputs: list[Proxy],
+        outputs: list[Proxy],
+        num_static_inputs: int | None = None,
+    ) -> Callable:
+
+        if num_static_inputs is not None:
+            static_inputs_mask = (True,) * num_static_inputs + (False,) * (len(inputs) - num_static_inputs)
+        else:
+            static_inputs_mask = None
+
+        x_fn_name = f"{fn_name}_{self.name_counter}"
+        self.name_counter += 1
+
+        self.python_callables[x_fn_name] = (
+            self.make_python_callable_from_symbols(fn_name, bsyms, inputs, outputs),
+            static_inputs_mask,
+        )
+        self.trace_symbols[x_fn_name] = (bsyms, inputs, outputs)
+
+        def callable(*args):
+            return self.call_cuda_graph(x_fn_name, *args)
+
+        callable.__name__ = fn_name
+        callable.__qualname__ = fn_name
+
+        return callable
 
 
 class CUDAGraphTransform(Transform):
     """
     Transform to fuse operations into CUDA graphs post optimization.
 
-    This class provides the basic infrastructure, but it is expected that you might subclass this transformm
+    This class provides the basic infrastructure, but it is expected that you might subclass this transform
     in order to override ``can_fuse```or other methods.
     """
 
+    def __init__(self):
+        super().__init__()
+        self.cuda_graph_runner = CUDAGraphRunner()
+
     def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | int = None) -> BoundSymbol:
         inputs = [unvariableify(inp) for inp in region.inputs]
         outputs = [unvariableify(out) for out in region.outputs]
@@ -147,8 +193,10 @@ def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | in
         region.bound_symbols = _del_last_used(region.bound_symbols, outputs)
 
         fusion_name = f"CUDAGraph{fusion_counter}"
-        fusible_callable: Callable = make_callable(f"{fusion_name}_fn", region.bound_symbols, inputs, outputs)
-        fusion_callable = CUDAGraphCallable(fusible_callable, num_static_inputs)
+
+        fusion_callable = self.cuda_graph_runner.make_cuda_graph_callable_from_symbols(
+            f"{fusion_name}_fn", region.bound_symbols, inputs, outputs, num_static_inputs
+        )
 
         fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self)
         fusion_bsym = BoundSymbol(