diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py
index ddcaea3fa8..4c4bad9084 100644
--- a/thunder/clang/__init__.py
+++ b/thunder/clang/__init__.py
@@ -1286,7 +1286,20 @@ def unfold(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
 @clangop()
 def cat(tensors: list[TensorProxy], dim: int):
     """Concatenates the given sequence of tensors in the given dimension."""
-    return prims.cat(tensors, dim)
+    # Upcast tensors only if we have more than one tensor.
+    # NumPy and PyTorch support upcasting with mixed dtypes.
+    if len(tensors) > 1:
+        _, output_dtype = utils.elementwise_type_promotion(
+            *tensors, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.PRESERVE
+        )
+        promoted_tensors = []
+        for t in tensors:
+            if t.dtype != output_dtype:
+                t = prims.convert_element_type(t, output_dtype)
+            promoted_tensors.append(t)
+    else:
+        promoted_tensors = tensors
+    return prims.cat(promoted_tensors, dim)
 
 
 @clangop()
@@ -1299,7 +1312,7 @@ def stack(tensors: list[TensorProxy], dim: int):
             s == shapes[0], lambda: f"tensors must be of the same shape, tensor at {i} is {s} instead of {shapes[0]}"
         )
     tensors_ = [unsqueeze(t, dim) for t in tensors]
-    return prims.cat(tensors_, dim)
+    return cat(tensors_, dim)
 
 
 @clangop()
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index cf81cf93c2..24de5acdaa 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -3012,3 +3012,32 @@ def fn():
     jfn = thunder.jit(fn)
     assert fn() == jfn()
     assert jfn() == torch.int64
+
+
+def test_cat_mixed_dtypes():
+    # We add a dedicated test here instead of a mixed-dtype sample in OpInfo.
+    # A mixed-dtype OpInfo sample would also be picked up by the test that computes the numerical
+    # Jacobian-vector product and compares it with the analytical one, which fails when run in
+    # precision lower than double (and we can't disable a sample for specific tests only).
+    # See comment - https://github.com/Lightning-AI/lightning-thunder/pull/819#issuecomment-2244761476
+    def fn(tensors):
+        return torch.cat(tensors, dim=0)
+
+    tensors = (torch.randn(3, requires_grad=True), torch.randn(3, dtype=torch.float16, requires_grad=True))
+    with torch.no_grad():
+        tensors_jit = tuple(t.detach().clone() for t in tensors)
+        for t in tensors_jit:
+            t.requires_grad_(True)
+
+    # Compare forward
+    jfn = thunder.jit(fn)
+    expected = fn(tensors)
+    actual = jfn(tensors_jit)
+    torch.testing.assert_close(actual, expected)
+
+    # Compare backward
+    cotangent = torch.randn_like(expected)
+    expected.backward(cotangent)
+    actual.backward(cotangent)
+
+    torch.testing.assert_close(tuple(t.grad for t in tensors), tuple(t.grad for t in tensors_jit))
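For context, a minimal sketch (not part of the patch, and assuming only a stock PyTorch install) of the eager-mode behavior the `clang.cat` change mirrors: `torch.cat` applies type promotion when its inputs have mixed dtypes, so concatenating `float32` and `float16` tensors yields a `float32` result.

```python
import torch

# Eager PyTorch promotes mixed dtypes before concatenating:
# float32 + float16 -> float32, matching what the promoted
# thunder.clang.cat now produces.
a = torch.randn(3)                       # float32
b = torch.randn(3, dtype=torch.float16)  # float16
out = torch.cat((a, b), dim=0)
print(out.dtype)  # torch.float32
```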