Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cat for nvfuser >= 0.1.7 #35

Merged
merged 9 commits into from
Mar 24, 2024
10 changes: 2 additions & 8 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,20 +1087,14 @@ def broadcast_in_dim(


def _cat_check(tensors: list[TensorProxy], dim: int) -> bool:
# nvFuser cat fusion is currently disabled due to issue:
# "nvFuser doesn't support cating with an empty tensor"
return False
if nv_version < LooseVersion("0.1.7"):
return False

# Validates tensors and concatenated dimension lengths
for t in tensors:
if not is_supported_tensor(t):
return False

# See https://github.com/NVIDIA/Fuser/issues/21
# nvFuser cannot concatenate dimensions of length 1
if t.shape[dim] == 1:
return False

return True


Expand Down
35 changes: 26 additions & 9 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2762,12 +2762,12 @@ def cat_sample_generator(op, device, dtype, requires_grad, **kwargs):
]

for shapes, dim in cases:
yield SampleInput([make(s) for s in shapes], dim)
yield SampleInput(*[make(s) for s in shapes], dim=dim)

# Tests concatenating with a tensor broadcast along the concatenation dimension
a = make((5,))
b = make((1,)).expand((5,))
yield SampleInput((a, b))
yield SampleInput(a, b, dim=0)


def cat_error_generator(op, device, dtype=torch.float32, **kwargs):
Expand All @@ -2783,15 +2783,22 @@ def cat_error_generator(op, device, dtype=torch.float32, **kwargs):
]

for shapes, dim, exc_type, err_msg_match in cases:
yield SampleInput([make(s) for s in shapes], dim), exc_type, err_msg_match
yield SampleInput(*[make(s) for s in shapes], dim=dim), exc_type, err_msg_match


# nvfuserex_impl.to_descriptors refuses to take a **nested** list of tensors,
# reporting `ValueError: unrecognized type in arguments: <class 'list'>`.
# `cat_wrapper` is created to work around that.
def cat_wrapper(*args, dim):
return ltorch.cat(args, dim=dim)


cat_opinfo = OpInfo(
ltorch.cat,
cat_wrapper,
supports_grad=True,
sample_input_generator=cat_sample_generator,
error_input_generator=cat_error_generator,
torch_reference=torch.cat,
torch_reference=lambda *args, dim: torch.cat(args, dim=dim),
test_torch_compile_executor=True,
test_directives=(
# There's a bug in torch.compile + torch.cat for empty tensors in 2.1.0
Expand All @@ -2807,6 +2814,11 @@ def cat_error_generator(op, device, dtype=torch.float32, **kwargs):
active_if=(LooseVersion(torch.__version__) < "2.2.0"),
executors=("torchcompile",),
),
DecorateInfo(
pytest.mark.xfail(strict=True),
active_if=(nvfuser_version < "0.1.7"),
executors=("nvFuser",),
),
),
)
shape_ops.append(cat_opinfo)
Expand Down Expand Up @@ -3699,7 +3711,7 @@ def stack_sample_generator(op, device, dtype, requires_grad, **kwargs):
]

for shapes, dim in cases:
yield SampleInput([make(s) for s in shapes], dim)
yield SampleInput(*[make(s) for s in shapes], dim=dim)


def stack_error_generator(op, device, dtype=torch.float32, **kwargs):
Expand All @@ -3718,14 +3730,19 @@ def stack_error_generator(op, device, dtype=torch.float32, **kwargs):
]

for shapes, dim, exc_type, err_msg_match in cases:
yield SampleInput([make(s) for s in shapes], dim), exc_type, err_msg_match
yield SampleInput(*[make(s) for s in shapes], dim=dim), exc_type, err_msg_match


# `stack_wrapper` is created for the same reason as `cat_wrapper`.
def stack_wrapper(*args, dim):
return ltorch.stack(args, dim=dim)


stack_opinfo = OpInfo(
ltorch.stack,
stack_wrapper,
sample_input_generator=stack_sample_generator,
error_input_generator=stack_error_generator,
torch_reference=torch.stack,
torch_reference=lambda *args, dim: torch.stack(args, dim=dim),
test_directives=(
# vjp and jvp not yet implemented
DecorateInfo(pytest.mark.xfail, "test_jvp_correctness"),
Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def test_cse_rematerialization(executor, device, _):

fw_trace = thunder.last_traces(compiled_func)[-1]
fusion_bsyms = tuple(filter(lambda a: a.sym.is_fusion, fw_trace.bound_symbols))
assert len(fusion_bsyms) == 13
# fusion groups 1 and 7 correspond with the apply_rotary_emb function
assert len(fusion_bsyms) == 11
# fusion groups 1 and 6 correspond with the apply_rotary_emb function
# Nvfuser with recomputation should use precomputed cos and sin values.
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[7].args)
assert fusion_bsyms[1].subsymbols[0].output.name == "freqs_cos"
Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_errors(op, device, _, executor, comp):
# Snippets run a single test using a single sample
# TODO: should snippets be able to access the original opinfo? -- No?
# TODO: revisit atol/rtol, maybe be more selective about which ops need a more permissive check
def snippet_torch_consistency(op, torch_op, sample, comp):
def snippet_torch_consistency(op: OpInfo, torch_op, sample: SampleInput, comp: Callable):
thunder_result = op(*sample.args, **sample.kwargs)
torch_result = torch_op(*sample.args, **sample.kwargs)

Expand Down
Loading