From 83ca40b1a34ec659d0ae0900d5e9ea1f4fd43b57 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Thu, 21 Mar 2024 17:18:01 -0700
Subject: [PATCH] Add cat upcasting support for PyTorch compat

---
 thunder/clang/__init__.py | 15 ++++++++++++++-
 thunder/tests/opinfos.py  |  5 +++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py
index 3fd192802f..af89b1f7b6 100644
--- a/thunder/clang/__init__.py
+++ b/thunder/clang/__init__.py
@@ -1117,6 +1117,12 @@ def unsqueeze(a, /, dims: int | Sequence[int]) -> TensorProxy:
 @clangop()
 def cat(tensors: list[TensorProxy], dim: int):
     """Concatenates the given sequence of tensors in the given dimension."""
+
+    if tensors:
+        _, result_dtype = utils.elementwise_type_promotion(
+            *tensors, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.PRESERVE
+        )
+        tensors = tuple(maybe_convert_to_dtype(a, result_dtype) for a in tensors)
     return prims.cat(tensors, dim)
 
 
@@ -1129,7 +1135,14 @@ def stack(tensors: list[TensorProxy], dim: int):
         utils.check(
             s == shapes[0], lambda: f"tensors must be of the same shape, tensor at {i} is {s} instead of {shapes[0]}"
         )
-    tensors_ = [unsqueeze(t, dim) for t in tensors]
+
+    if tensors:
+        _, result_dtype = utils.elementwise_type_promotion(
+            *tensors, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.PRESERVE
+        )
+        tensors = tuple(maybe_convert_to_dtype(a, result_dtype) for a in tensors)
+
+    tensors_ = tuple(unsqueeze(t, dim) for t in tensors)
     return prims.cat(tensors_, dim)
 
 
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index d934e18dfd..bc8429500f 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -2764,6 +2764,11 @@ def cat_sample_generator(op, device, dtype, requires_grad, **kwargs):
     for shapes, dim in cases:
         yield SampleInput([make(s) for s in shapes], dim)
 
+    # test upcasting (in case the dtype is not float16). PyTorch has upcasting logic in cat.
+    if dtype != torch.float16:
+        for shapes, dim in cases:
+            yield SampleInput([make(s) if i != 1 else make(s, dtype=torch.float16) for i, s in enumerate(shapes)], dim)
+
     # Tests concatenating with a tensor broadcast along the concatenation dimension
     a = make((5,))
     b = make((1,)).expand((5,))