diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index bd991da611..73b323d20c 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -106,6 +106,7 @@ def _to_transform(
     device: None | DeviceLike = None,
     dtype: None | dtypeLike = None,
     copy: bool = False,
+    memory_format: None | torch.memory_format = None,
 ) -> TensorLike:
     device: None | devices.Device
     dtype: None | dtypes.dtype
@@ -116,11 +117,14 @@ def _to_transform(
     torch_device: None | torch.device = to_torch_device(device)
     torch_dtype: None | torch.dtype = to_torch_dtype(dtype)
 
-    if torch_device is not None and torch_dtype is not None:
-        return to(a, torch_device, torch_dtype, copy=copy)
+    kwargs = {"copy": copy}
     if torch_device is not None:
-        return to(a, torch_device, copy=copy)
-    return to(a, torch_dtype, copy=copy)
+        kwargs["device"] = torch_device
+    if torch_dtype is not None:
+        kwargs["dtype"] = torch_dtype
+    if memory_format is not None:
+        kwargs["memory_format"] = memory_format
+    return to(a, **kwargs)
 
 
 def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy:
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index c8af5a58e1..3ef9d44ca8 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -1953,6 +1953,24 @@ def f(a):
     assert "thunder.computation" in excinfo.traceback[-1].path
 
 
+@instantiate(
+    dtypes=NOTHING,
+    executors=(TorchExecutor,),
+)
+def test_torch_tensor_to_memory_format(executor: TestExecutor, device: str, _):
+    inp = torch.randn(2, 4, 5, 3, device=device, dtype=torch.float32)
+
+    def torch_to(a, memory_format):
+        return a.to(memory_format=memory_format)
+
+    cfn = executor.make_callable(torch_to, disable_preprocessing=False)
+
+    for m_format in [torch.contiguous_format, torch.channels_last, torch.preserve_format]:
+        thunder_result = cfn(inp, m_format)
+        torch_result = torch_to(inp, m_format)
+        assert_close(torch_result, thunder_result, check_stride=True)
+
+
 # TODO See issue "Add contiguous and clang.stride_order OpInfos that check stride
 # consistency with PyTorch"
 @instantiate(
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 0cee3f169b..b054a9a806 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -236,7 +236,7 @@ def _parse_to_device_and_dtype(
     return device, dtype
 
 
-# TODO Model non_blocking, copy, and memory_format (as kwargs)
+# TODO Model non_blocking (as a kwarg)
 @torchsymbol(torch.Tensor.to, is_method=True)
 def to(
     a: TensorLike,
@@ -247,7 +247,8 @@ def to(
     device: None | DeviceLike = None,
     dtype: None | dtypeLike = None,
     copy: bool = False,
+    memory_format: None | torch.memory_format = None,
 ) -> TensorLike:
     device, dtype = _parse_to_device_and_dtype(
         tensor_dtype_or_device, optional_positional_dtype, device=device, dtype=dtype
     )
@@ -259,6 +260,12 @@ def to(
         if dtype is not None:
             dtype = to_dtype(dtype)
             a = prims.convert_element_type(a, dtype)
+        if memory_format is not None:
+            # NOTE It's unclear whether torch.preserve_format needs explicit handling here
+            if memory_format == torch.channels_last:
+                a = prims.stride_order(a, (3, 0, 2, 1))
+            elif memory_format == torch.channels_last_3d:
+                a = prims.stride_order(a, (4, 0, 3, 2, 1))
         return a
 
     # NOTE copy == False
@@ -270,6 +277,13 @@ def to(
     if dtype is not None:
         return clang.maybe_convert_to_dtype(a, dtype)
 
+    if memory_format is not None:
+        # NOTE It's unclear whether torch.preserve_format needs explicit handling here
+        if memory_format == torch.channels_last:
+            a = prims.stride_order(a, (3, 0, 2, 1))
+        elif memory_format == torch.channels_last_3d:
+            a = prims.stride_order(a, (4, 0, 3, 2, 1))
+
     return a
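
A note on the permutations passed to `prims.stride_order` above: the tuple gives each dimension's rank in the stride ordering, smallest stride first. For a channels_last NCHW tensor, C (dim 1) is innermost (rank 0) and N (dim 0) is outermost (rank 3), giving `(3, 0, 2, 1)`; channels_last_3d extends this to NCDHW as `(4, 0, 3, 2, 1)`. The following is a minimal sketch in plain PyTorch, independent of Thunder, that checks the 4D mapping using the same shape as the new test:

```python
import torch

# NCHW input, same shape as in the test above
x = torch.randn(2, 4, 5, 3)
y = x.to(memory_format=torch.channels_last)

# channels_last strides for shape (2, 4, 5, 3): C is innermost (stride 1),
# then W, H, and N, i.e. (60, 1, 12, 4)
print(y.stride())  # (60, 1, 12, 4)

# Rank each dimension by its stride, smallest first; this recovers the
# permutation the patch passes to prims.stride_order for channels_last
by_stride = sorted(range(y.dim()), key=lambda d: y.stride()[d])
ranks = tuple(by_stride.index(d) for d in range(y.dim()))
assert ranks == (3, 0, 2, 1)
```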
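
For end-to-end usage, something like the sketch below should exercise the new kwarg once the patch is applied. This is a hypothetical example, not part of the patch: it assumes the `thunder.jit` entry point (at the time of this change the equivalent user-facing API was `thunder.compile`).

```python
import torch
import thunder  # assumption: a Thunder release exposing thunder.jit

def f(a):
    return a.to(memory_format=torch.channels_last)

jf = thunder.jit(f)
out = jf(torch.randn(2, 4, 5, 3))

# With the torch executor the output strides should match eager PyTorch,
# which is what the new test asserts via assert_close(..., check_stride=True)
assert out.is_contiguous(memory_format=torch.channels_last)
```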