From 510ed00571b627adff82ac4bffa14d08dfb3dd4f Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 22 Jul 2024 13:28:16 +0200 Subject: [PATCH 1/4] cat: support inputs with mixed dtypes --- thunder/clang/__init__.py | 17 +++++++++++++++-- thunder/tests/opinfos.py | 6 ++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index ddcaea3fa8..3f61094952 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -1286,7 +1286,20 @@ def unfold(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy: @clangop() def cat(tensors: list[TensorProxy], dim: int): """Concatenates the given sequence of tensors in the given dimension.""" - return prims.cat(tensors, dim) + # Upcast to tensors only if we have more than 1 tensor. + # NumPy and PyTorch support upcasting with mixed dtypes. + if len(tensors) > 1: + _, output_dtype = utils.elementwise_type_promotion( + *tensors, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.PRESERVE + ) + promoted_tensors = [] + for t in tensors: + if t.dtype != output_dtype: + t = prims.convert_element_type(t, output_dtype) + promoted_tensors.append(t) + else: + promoted_tensors = tensors + return prims.cat(promoted_tensors, dim) @clangop() @@ -1299,7 +1312,7 @@ def stack(tensors: list[TensorProxy], dim: int): s == shapes[0], lambda: f"tensors must be of the same shape, tensor at {i} is {s} instead of {shapes[0]}" ) tensors_ = [unsqueeze(t, dim) for t in tensors] - return prims.cat(tensors_, dim) + return cat(tensors_, dim) @clangop() diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 6d181e2780..2c170186d5 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -3107,6 +3107,12 @@ def cat_sample_generator(op, device, dtype, requires_grad, **kwargs): b = make((1,)).expand((5,)) yield SampleInput(a, b, dim=0) + # test upcasting. PyTorch has upcasting logic in cat. + if dtype != torch.float: + s1 = (1,) + s2 = (1,) + yield SampleInput(*[make(s1), make(s2, dtype=torch.float)], dim=0) + def cat_error_generator(op, device, dtype=torch.float32, **kwargs): make = partial(make_tensor, device=device, dtype=dtype) From 5c0b788e6aaf8eb6c6709d81ad905ebd7feb9c25 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 22 Jul 2024 13:33:56 +0200 Subject: [PATCH 2/4] update comment --- thunder/clang/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 3f61094952..4c4bad9084 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -1286,7 +1286,7 @@ def unfold(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy: @clangop() def cat(tensors: list[TensorProxy], dim: int): """Concatenates the given sequence of tensors in the given dimension.""" - # Upcast to tensors only if we have more than 1 tensor. + # Upcast tensors only if we have more than 1 tensor. # NumPy and PyTorch support upcasting with mixed dtypes. if len(tensors) > 1: _, output_dtype = utils.elementwise_type_promotion( From 1a4860a0da748826cf8e190155b659fd3d71f4db Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 23 Jul 2024 13:46:47 +0200 Subject: [PATCH 3/4] update --- thunder/tests/opinfos.py | 6 ------ thunder/tests/test_core.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 2c170186d5..6d181e2780 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -3107,12 +3107,6 @@ def cat_sample_generator(op, device, dtype, requires_grad, **kwargs): b = make((1,)).expand((5,)) yield SampleInput(a, b, dim=0) - # test upcasting. PyTorch has upcasting logic in cat. - if dtype != torch.float: - s1 = (1,) - s2 = (1,) - yield SampleInput(*[make(s1), make(s2, dtype=torch.float)], dim=0) - def cat_error_generator(op, device, dtype=torch.float32, **kwargs): make = partial(make_tensor, device=device, dtype=dtype) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index cf81cf93c2..ca439e3ade 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -3012,3 +3012,31 @@ def fn(): jfn = thunder.jit(fn) assert fn() == jfn() assert jfn() == torch.int64 + + +def test_cat_mixed_dtypes(): + # We add a special test here instead of a sample in OpInfo. + # When we add a mixed input sample in OpInfo, it will also be picked up for the test which + # computes numerical Jacobian vector product and compares it with analytical. The test will produce failures + # when run in precision lower than double (and we can't disable a sample based on tests). + def fn(tensors): + return torch.cat(tensors, dim=0) + + tensors = (torch.randn(3, requires_grad=True), torch.randn(3, dtype=torch.float16, requires_grad=True)) + with torch.no_grad(): + tensors_jit = tuple(t.detach().clone() for t in tensors) + for t in tensors_jit: + t.requires_grad_(True) + + # Compare forward + jfn = thunder.jit(fn) + expected = fn(tensors) + actual = jfn(tensors_jit) + torch.testing.assert_close(actual, expected) + + # Compare backward + cotangent = torch.randn_like(expected) + expected.backward(cotangent) + actual.backward(cotangent) + + torch.testing.assert_close(tuple(t.grad for t in tensors), tuple(t.grad for t in tensors_jit)) From 2620479e3ef322dec86668d94bd5ae735b3ad7e8 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 23 Jul 2024 13:47:16 +0200 Subject: [PATCH 4/4] update comment --- thunder/tests/test_core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index ca439e3ade..24de5acdaa 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -3019,6 +3019,7 @@ def test_cat_mixed_dtypes(): # When we add a mixed input sample in OpInfo, it will also be picked up for the test which # computes numerical Jacobian vector product and compares it with analytical. The test will produce failures # when run in precision lower than double (and we can't disable a sample based on tests). + # See comment - https://github.com/Lightning-AI/lightning-thunder/pull/819#issuecomment-2244761476 def fn(tensors): return torch.cat(tensors, dim=0)