cat: support inputs with mixed dtypes #819

Merged 4 commits on Jul 23, 2024
thunder/clang/__init__.py (15 additions, 2 deletions)
@@ -1286,7 +1286,20 @@ def unfold(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
 @clangop()
 def cat(tensors: list[TensorProxy], dim: int):
     """Concatenates the given sequence of tensors in the given dimension."""
-    return prims.cat(tensors, dim)
+    # Upcast tensors only if we have more than 1 tensor.
+    # NumPy and PyTorch support upcasting with mixed dtypes.
+    if len(tensors) > 1:
+        _, output_dtype = utils.elementwise_type_promotion(
+            *tensors, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.PRESERVE
+        )
+        promoted_tensors = []
+        for t in tensors:
+            if t.dtype != output_dtype:
+                t = prims.convert_element_type(t, output_dtype)
+            promoted_tensors.append(t)
+    else:
+        promoted_tensors = tensors
+    return prims.cat(promoted_tensors, dim)


 @clangop()
@@ -1299,7 +1312,7 @@ def stack(tensors: list[TensorProxy], dim: int):
         s == shapes[0], lambda: f"tensors must be of the same shape, tensor at {i} is {s} instead of {shapes[0]}"
     )
     tensors_ = [unsqueeze(t, dim) for t in tensors]
-    return prims.cat(tensors_, dim)
+    return cat(tensors_, dim)


 @clangop()
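For context: the PRESERVE promotion kind returns the promoted common dtype of the inputs unchanged (as opposed to kinds that, for example, force integer inputs to float), so a float16/float32 pair promotes to float32. Below is a minimal sketch of what the change enables, modeled on the new test (assumes a working thunder install):

    import torch
    import thunder

    def fn(tensors):
        return torch.cat(tensors, dim=0)

    jfn = thunder.jit(fn)

    # Mixed float16/float32 inputs now promote to float32 inside the jitted
    # function, matching eager torch.cat's type-promotion behavior.
    out = jfn((torch.randn(3, dtype=torch.float16), torch.randn(3)))
    assert out.dtype == torch.float32

Because stack now routes through this cat rather than calling prims.cat directly, stacking tensors with mixed dtypes picks up the same promotion for free.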
thunder/tests/test_core.py (29 additions)
@@ -3012,3 +3012,32 @@ def fn():
     jfn = thunder.jit(fn)
     assert fn() == jfn()
     assert jfn() == torch.int64
+
+
+def test_cat_mixed_dtypes():
+    # We add a special test here instead of a sample in OpInfo.
+    # If we added a mixed-dtype sample to OpInfo, it would also be picked up by the test that
+    # computes the numerical Jacobian-vector product and compares it with the analytical one. That test
+    # produces failures when run in precision lower than double (and we can't disable a sample on a per-test basis).
+    # See comment - https://github.com/Lightning-AI/lightning-thunder/pull/819#issuecomment-2244761476
+    def fn(tensors):
+        return torch.cat(tensors, dim=0)
+
+    tensors = (torch.randn(3, requires_grad=True), torch.randn(3, dtype=torch.float16, requires_grad=True))
+    with torch.no_grad():
+        tensors_jit = tuple(t.detach().clone() for t in tensors)
+    for t in tensors_jit:
+        t.requires_grad_(True)
+
+    # Compare forward
+    jfn = thunder.jit(fn)
+    expected = fn(tensors)
+    actual = jfn(tensors_jit)
+    torch.testing.assert_close(actual, expected)
+
+    # Compare backward
+    cotangent = torch.randn_like(expected)
+    expected.backward(cotangent)
+    actual.backward(cotangent)
+
+    torch.testing.assert_close(tuple(t.grad for t in tensors), tuple(t.grad for t in tensors_jit))
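A note on the backward comparison: in eager PyTorch, a leaf tensor's gradient arrives in that tensor's own dtype, so the upcast that cat inserts in the forward is mirrored by a downcast in the backward. A minimal eager-mode sketch of the behavior the test expects thunder to reproduce (illustrative, not part of the PR):

    import torch

    a = torch.randn(3, requires_grad=True)                       # float32
    b = torch.randn(3, dtype=torch.float16, requires_grad=True)  # float16

    out = torch.cat([a, b], dim=0)  # result is promoted to float32
    out.backward(torch.randn_like(out))

    # Gradients come back in each input's own dtype.
    assert a.grad.dtype == torch.float32
    assert b.grad.dtype == torch.float16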