Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files — browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 26, 2024
1 parent 6a1ce53 commit e4c0efb
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions thunder/executors/cudagraphex.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,7 +40,7 @@ def extract_descriptor(arg):
@lru_cache
def build_cuda_graph(
fn: Callable, args_descriptor: ArgsDescriptor, static_args_mask: tuple[bool, ...]
) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:
) -> tuple[torch.cuda.CUDAGraph, Sequence[torch.Tensor | Any], Sequence[torch.Tensor | Any]]:

def get_static_buffer(x):
if isinstance(x, torch.Tensor):
Expand All @@ -54,7 +54,9 @@ def get_static_buffer(x):
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
static_inputs = tuple(get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask))
static_inputs = tuple(
get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask)
)
for _ in range(3):
fn(*static_inputs)
stream.synchronize()
Expand Down Expand Up @@ -106,8 +108,7 @@ def make_callable(
from inspect import Parameter, Signature

region_fn_params = (
Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY)
for i, param in enumerate(inputs)
Parameter(getattr(param, "name", f"arg{i}"), Parameter.POSITIONAL_ONLY) for i, param in enumerate(inputs)
)

region_fn_signature = Signature(region_fn_params)
Expand Down Expand Up @@ -143,7 +144,11 @@ def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | in

fusion_sym = Symbol(fusion_name, meta=None, is_fusion=True, executor=self)
fusion_bsym = BoundSymbol(
fusion_sym, inputs, {}, outputs, region.bound_symbols,
fusion_sym,
inputs,
{},
outputs,
region.bound_symbols,
_call_ctx={fusion_name: fusion_callable},
)

Expand Down Expand Up @@ -219,7 +224,7 @@ def _can_fuse_node(n: Node):
fused_trace.bound_symbols = fused_bsyms

end_time_ns = time.time_ns()
elapsed_time_ns = (end_time_ns - start_time_ns)
elapsed_time_ns = end_time_ns - start_time_ns
elapsed_time_ms = elapsed_time_ns // 1000000
fused_trace.set_provenance(TraceProvenance(f"CUDAGraph fusion (took {elapsed_time_ms} milliseconds)"))

Expand Down

0 comments on commit e4c0efb

Please sign in to comment.