diff --git a/src/chronos/utils.py b/src/chronos/utils.py index 9248d5a..1c6b7e9 100644 --- a/src/chronos/utils.py +++ b/src/chronos/utils.py @@ -17,4 +17,4 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor: size=(max_len - len(c),), fill_value=torch.nan, device=c.device ) padded.append(torch.concat((padding, c), dim=-1)) - return torch.stack(padded).to(tensors[0]) + return torch.stack(padded) diff --git a/test/__init__.py b/test/__init__.py index 03f633a..04f8b7b 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,2 +1,2 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..c6a91b1 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from chronos.utils import left_pad_and_stack_1D + + +@pytest.mark.parametrize( + "tensors", + [ + [ + torch.tensor([2.0, 3.0], dtype=dtype), + torch.tensor([4.0, 5.0, 6.0], dtype=dtype), + torch.tensor([7.0, 8.0, 9.0, 10.0], dtype=dtype), + ] + for dtype in [torch.int, torch.float16, torch.float32] + ], +) +def test_pad_and_stack(tensors: list): + stacked_and_padded = left_pad_and_stack_1D(tensors) + + assert stacked_and_padded.dtype == torch.float32 + assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors)) + + ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype) + + assert torch.sum(torch.nan_to_num(stacked_and_padded, nan=0)) == torch.sum(ref) diff --git a/test/util.py b/test/util.py index 37a2c3b..78c2e93 100644 --- a/test/util.py +++ b/test/util.py @@ -10,4 +10,4 @@ def validate_tensor( assert a.shape == shape if dtype is not None: - assert a.dtype == dtype \ No newline at end of file + assert a.dtype == dtype