Skip to content

Commit

Permalink
Enable cat for nvfuser >= 0.1.7 (PR1844) (#35)
Browse files Browse the repository at this point in the history
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
  • Loading branch information
jacobhinkle and wujingyue authored Mar 24, 2024
1 parent ada6dc7 commit f36faaa
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 23 deletions.
10 changes: 2 additions & 8 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,20 +1087,14 @@ def broadcast_in_dim(


def _cat_check(tensors: list[TensorProxy], dim: int) -> bool:
    """Return True when nvFuser can execute ``cat`` on ``tensors`` along ``dim``.

    cat support landed in nvFuser 0.1.7 (NVIDIA/Fuser PR #1844), so older
    versions are rejected outright. Note that ``dim`` itself is not otherwise
    validated here — since 0.1.7 nvFuser handles concatenation along any
    dimension, including length-1 (broadcast) dimensions.
    """
    if nv_version < LooseVersion("0.1.7"):
        return False

    # Every input tensor must itself be representable in nvFuser.
    for t in tensors:
        if not is_supported_tensor(t):
            return False

    return True


Expand Down
35 changes: 26 additions & 9 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2762,12 +2762,12 @@ def cat_sample_generator(op, device, dtype, requires_grad, **kwargs):
]

for shapes, dim in cases:
yield SampleInput([make(s) for s in shapes], dim)
yield SampleInput(*[make(s) for s in shapes], dim=dim)

# Tests concatenating with a tensor broadcast along the concatenation dimension
a = make((5,))
b = make((1,)).expand((5,))
yield SampleInput((a, b))
yield SampleInput(a, b, dim=0)


def cat_error_generator(op, device, dtype=torch.float32, **kwargs):
Expand All @@ -2783,15 +2783,22 @@ def cat_error_generator(op, device, dtype=torch.float32, **kwargs):
]

for shapes, dim, exc_type, err_msg_match in cases:
yield SampleInput([make(s) for s in shapes], dim), exc_type, err_msg_match
yield SampleInput(*[make(s) for s in shapes], dim=dim), exc_type, err_msg_match


def cat_wrapper(*tensors, dim):
    """Concatenate ``tensors`` along ``dim`` via ``ltorch.cat``.

    ``nvfuserex_impl.to_descriptors`` refuses a **nested** list of tensors,
    reporting ``ValueError: unrecognized type in arguments: <class 'list'>``.
    Taking the tensors as positional arguments and forwarding them as a single
    sequence works around that.
    """
    return ltorch.cat(tensors, dim=dim)

cat_opinfo = OpInfo(
ltorch.cat,
cat_wrapper,
supports_grad=True,
sample_input_generator=cat_sample_generator,
error_input_generator=cat_error_generator,
torch_reference=torch.cat,
torch_reference=lambda *args, dim: torch.cat(args, dim=dim),
test_torch_compile_executor=True,
test_directives=(
# There's a bug in torch.compile + torch.cat for empty tensors in 2.1.0
Expand All @@ -2807,6 +2814,11 @@ def cat_error_generator(op, device, dtype=torch.float32, **kwargs):
active_if=(LooseVersion(torch.__version__) < "2.2.0"),
executors=("torchcompile",),
),
DecorateInfo(
pytest.mark.xfail(strict=True),
active_if=(nvfuser_version < "0.1.7"),
executors=("nvFuser",),
),
),
)
shape_ops.append(cat_opinfo)
Expand Down Expand Up @@ -3699,7 +3711,7 @@ def stack_sample_generator(op, device, dtype, requires_grad, **kwargs):
]

for shapes, dim in cases:
yield SampleInput([make(s) for s in shapes], dim)
yield SampleInput(*[make(s) for s in shapes], dim=dim)


def stack_error_generator(op, device, dtype=torch.float32, **kwargs):
Expand All @@ -3718,14 +3730,19 @@ def stack_error_generator(op, device, dtype=torch.float32, **kwargs):
]

for shapes, dim, exc_type, err_msg_match in cases:
yield SampleInput([make(s) for s in shapes], dim), exc_type, err_msg_match
yield SampleInput(*[make(s) for s in shapes], dim=dim), exc_type, err_msg_match


def stack_wrapper(*tensors, dim):
    """Stack ``tensors`` along ``dim`` via ``ltorch.stack``.

    Exists for the same reason as ``cat_wrapper``: the executor's argument
    handling rejects a nested list of tensors, so they are taken as positional
    arguments and forwarded as one sequence.
    """
    return ltorch.stack(tensors, dim=dim)


stack_opinfo = OpInfo(
ltorch.stack,
stack_wrapper,
sample_input_generator=stack_sample_generator,
error_input_generator=stack_error_generator,
torch_reference=torch.stack,
torch_reference=lambda *args, dim: torch.stack(args, dim=dim),
test_directives=(
# vjp and jvp not yet implemented
DecorateInfo(pytest.mark.xfail, "test_jvp_correctness"),
Expand Down
10 changes: 5 additions & 5 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,14 @@ def test_cse_rematerialization(executor, device, _):

fw_trace = thunder.last_traces(compiled_func)[-1]
fusion_bsyms = tuple(filter(lambda a: a.sym.is_fusion, fw_trace.bound_symbols))
assert len(fusion_bsyms) == 13
# fusion groups 1 and 7 correspond with the apply_rotary_emb function
assert len(fusion_bsyms) == 11
# fusion groups 1 and 6 correspond with the apply_rotary_emb function
# Nvfuser with recomputation should use precomputed cos and sin values.
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[7].args)
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[6].args)
assert fusion_bsyms[1].subsymbols[0].output.name == "freqs_cos"
assert fusion_bsyms[1].subsymbols[1].output.name == "freqs_sin"
assert fusion_bsyms[7].subsymbols[0].output.name == "freqs_cos"
assert fusion_bsyms[7].subsymbols[1].output.name == "freqs_sin"
assert fusion_bsyms[6].subsymbols[0].output.name == "freqs_cos"
assert fusion_bsyms[6].subsymbols[1].output.name == "freqs_sin"


# Tests that two separated nvFuser regions can be merged when they don't depend
Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_errors(op, device, _, executor, comp):
# Snippets run a single test using a single sample
# TODO: should snippets be able to access the original opinfo? -- No?
# TODO: revisit atol/rtol, maybe be more selective about which ops need a more permissive check
def snippet_torch_consistency(op, torch_op, sample, comp):
def snippet_torch_consistency(op: OpInfo, torch_op, sample: SampleInput, comp: Callable):
thunder_result = op(*sample.args, **sample.kwargs)
torch_result = torch_op(*sample.args, **sample.kwargs)

Expand Down

0 comments on commit f36faaa

Please sign in to comment.