Skip to content

Commit

Permalink
fix cudagraphs without tensor inputs (#1324)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Oct 18, 2024
1 parent c32f499 commit 45bd025
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
15 changes: 15 additions & 0 deletions thunder/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,18 @@ def forward(x):
break
else:
raise RuntimeError("Failed to find `add` symbol in trace")


@requiresCUDA
def test_cudagraph_empty_inputs():
def fn():
a = torch.ones(5, 5, device="cuda")
b = a * 2
return b

from thunder.transforms.cudagraph import CUDAGraphTransform

jfn = thunder.jit(fn, transforms=(CUDAGraphTransform(),), executors=())
assert_close(jfn(), fn())

assert any(("CUDAGraph" in bsym.sym.name) for bsym in thunder.last_traces(jfn)[-1].bound_symbols)
5 changes: 4 additions & 1 deletion thunder/transforms/cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ def extract_descriptor(arg):
else:
return type(arg), None, None, arg

dtypes, sizes, strides, non_tensor_args = zip(*map(extract_descriptor, args))
if args:
dtypes, sizes, strides, non_tensor_args = zip(*map(extract_descriptor, args))
else:
dtypes = sizes = strides = non_tensor_args = None
return ArgsDescriptor(dtypes, sizes, strides, non_tensor_args)


Expand Down

0 comments on commit 45bd025

Please sign in to comment.