Add cat upcasting support for PyTorch compat
t-vi committed Mar 22, 2024
1 parent 3492f14 commit 83ca40b
Showing 2 changed files with 19 additions and 1 deletion.
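For reference, the PyTorch behavior this commit brings thunder in line with: torch.cat applies elementwise type promotion across its inputs instead of requiring them to share a dtype. A minimal illustration in plain PyTorch, independent of the changes below:

import torch

a = torch.ones(2, dtype=torch.float32)
b = torch.ones(2, dtype=torch.float16)

# torch.cat promotes to the common dtype rather than raising on the mismatch
out = torch.cat([a, b])
print(out.dtype)  # torch.float32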
thunder/clang/__init__.py: 15 changes (14 additions & 1 deletion)
@@ -1117,6 +1117,12 @@ def unsqueeze(a, /, dims: int | Sequence[int]) -> TensorProxy:
 @clangop()
 def cat(tensors: list[TensorProxy], dim: int):
     """Concatenates the given sequence of tensors in the given dimension."""
+
+    if tensors:
+        _, result_dtype = utils.elementwise_type_promotion(
+            *tensors, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.PRESERVE
+        )
+        tensors = tuple(maybe_convert_to_dtype(a, result_dtype) for a in tensors)
     return prims.cat(tensors, dim)


@@ -1129,7 +1135,14 @@ def stack(tensors: list[TensorProxy], dim: int):
         utils.check(
             s == shapes[0], lambda: f"tensors must be of the same shape, tensor at {i} is {s} instead of {shapes[0]}"
         )
-    tensors_ = [unsqueeze(t, dim) for t in tensors]
+
+    tensors_ = tuple(unsqueeze(t, dim) for t in tensors)
+    if tensors:
+        _, result_dtype = utils.elementwise_type_promotion(
+            *tensors, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.PRESERVE
+        )
+        tensors_ = tuple(maybe_convert_to_dtype(a, result_dtype) for a in tensors_)
+
     return prims.cat(tensors_, dim)


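Both hunks above follow the same pattern: compute a common result dtype over the inputs, convert every tensor to it, then hand the now-uniform tensors to prims.cat. A rough standalone sketch of that pattern using only public PyTorch utilities — torch.promote_types stands in for utils.elementwise_type_promotion with the PRESERVE kind, and .to() for maybe_convert_to_dtype; this approximates, rather than reproduces, thunder's promotion rule:

import functools
import torch

def cat_with_upcast(tensors, dim=0):
    # common dtype across all inputs (stand-in for elementwise type promotion)
    result_dtype = functools.reduce(torch.promote_types, (t.dtype for t in tensors))
    # convert each tensor before concatenating (stand-in for maybe_convert_to_dtype)
    tensors = tuple(t.to(result_dtype) for t in tensors)
    return torch.cat(tensors, dim)

x = torch.zeros(3, dtype=torch.int64)
y = torch.zeros(3, dtype=torch.float16)
print(cat_with_upcast([x, y]).dtype)  # torch.float16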
thunder/tests/opinfos.py: 5 changes (5 additions & 0 deletions)
@@ -2764,6 +2764,11 @@ def cat_sample_generator(op, device, dtype, requires_grad, **kwargs):
     for shapes, dim in cases:
         yield SampleInput([make(s) for s in shapes], dim)

+    # Test upcasting (when the dtype is not float16); PyTorch has upcasting logic in cat.
+    if dtype != torch.float16:
+        for shapes, dim in cases:
+            yield SampleInput([make(s) if i != 1 else make(s, dtype=torch.float16) for i, s in enumerate(shapes)], dim)
+
     # Tests concatenating with a tensor broadcast along the concatenation dimension
     a = make((5,))
     b = make((1,)).expand((5,))
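The new sample-generator branch exercises one mixed-dtype input per case: the tensor at index 1 is forced to float16 while the others keep the test's base dtype, so the torch.cat reference result (and therefore thunder's cat) must upcast. A standalone illustration of such a sample outside the opinfo machinery, with hypothetical shapes and an assumed base dtype of float32:

import torch
from torch.testing import make_tensor

shapes, dim = ((2, 3), (4, 3)), 0
base_dtype, device = torch.float32, "cpu"

# the element at index 1 is float16, the others keep the base dtype
tensors = [
    make_tensor(s, dtype=torch.float16 if i == 1 else base_dtype, device=device)
    for i, s in enumerate(shapes)
]

out = torch.cat(tensors, dim)
print(out.shape, out.dtype)  # torch.Size([6, 3]) torch.float32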
