From e4c0efbae129758a0bec36d6ab2f3767d9c7cfc2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:51:33 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/executors/cudagraphex.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/thunder/executors/cudagraphex.py b/thunder/executors/cudagraphex.py index 6432273fa1..a5fe3e3930 100644 --- a/thunder/executors/cudagraphex.py +++ b/thunder/executors/cudagraphex.py @@ -40,7 +40,7 @@ def extract_descriptor(arg): @lru_cache def build_cuda_graph( fn: Callable, args_descriptor: ArgsDescriptor, static_args_mask: tuple[bool, ...] - ) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]: +) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]: def get_static_buffer(x): if isinstance(x, torch.Tensor): @@ -54,7 +54,9 @@ def get_static_buffer(x): stream = torch.cuda.Stream() stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream): - static_inputs = tuple(get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask)) + static_inputs = tuple( + get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask) + ) for _ in range(3): fn(*static_inputs) stream.synchronize() @@ -106,8 +108,7 @@ def make_callable( from inspect import Parameter, Signature region_fn_params = ( - Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY) - for i, param in enumerate(inputs) + Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY) for i, param in enumerate(inputs) ) region_fn_signature = Signature(region_fn_params) @@ -143,7 +144,11 @@ def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | in fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self) fusion_bsym = BoundSymbol( - fusion_sym, inputs, {}, outputs, region.bound_symbols, + fusion_sym, + inputs, + {}, + outputs, + region.bound_symbols, _call_ctx={fusion_name: fusion_callable}, ) @@ -219,7 +224,7 @@ def _can_fuse_node(n: Node): fused_trace.bound_symbols = fused_bsyms end_time_ns = time.time_ns() - elapsed_time_ns = (end_time_ns - start_time_ns) + elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_ms = elapsed_time_ns // 1000000 fused_trace.set_provenance(TraceProvenance(f"CUDAGraph fusion (took {elapsed_time_ms} milliseconds)"))