Add memory_format in torch.Tensor.to (#157)
jjsjann123 authored Apr 13, 2024
1 parent ff199c2 commit 139cc22
Showing 3 changed files with 41 additions and 5 deletions.
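
This commit threads a memory_format keyword through thunder's torch.Tensor.to so channels-last requests reach the Torch executor. A minimal usage sketch of the new kwarg (assuming the thunder.jit entry point; this snippet is illustrative and not part of the diff below):

    import torch
    import thunder

    def fn(x):
        # Request NHWC (channels-last) layout through the newly supported kwarg
        return x.to(memory_format=torch.channels_last)

    jfn = thunder.jit(fn)
    out = jfn(torch.randn(2, 4, 5, 3))
    print(out.stride())  # expected channels-last strides: (60, 1, 12, 4)
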
12 changes: 8 additions & 4 deletions thunder/executors/torchex.py
@@ -106,6 +106,7 @@ def _to_transform(
     device: None | DeviceLike = None,
     dtype: None | dtypeLike = None,
     copy: bool = False,
+    memory_format: None | torch.memory_format = None,
 ) -> TensorLike:
     device: None | devices.Device
     dtype: None | dtypes.dtype
@@ -116,11 +117,14 @@ def _to_transform(
     torch_device: None | torch.device = to_torch_device(device)
     torch_dtype: None | torch.dtype = to_torch_dtype(dtype)
 
-    if torch_device is not None and torch_dtype is not None:
-        return to(a, torch_device, torch_dtype, copy=copy)
+    kwargs = {"copy": copy}
     if torch_device is not None:
-        return to(a, torch_device, copy=copy)
-    return to(a, torch_dtype, copy=copy)
+        kwargs["device"] = torch_device
+    if torch_dtype is not None:
+        kwargs["dtype"] = torch_dtype
+    if memory_format is not None:
+        kwargs["memory_format"] = memory_format
+    return to(a, **kwargs)
 
 
 def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy:
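
The transform now builds a kwargs dict and forwards only the arguments that were actually provided, instead of enumerating device/dtype combinations. A standalone sketch of that forwarding pattern against plain torch.Tensor.to (the function name here is illustrative, not from the diff):

    import torch

    def forward_to(a, device=None, dtype=None, copy=False, memory_format=None):
        # Pass through only the arguments the caller supplied
        kwargs = {"copy": copy}
        if device is not None:
            kwargs["device"] = device
        if dtype is not None:
            kwargs["dtype"] = dtype
        if memory_format is not None:
            kwargs["memory_format"] = memory_format
        return a.to(**kwargs)

    x = torch.randn(2, 4, 5, 3)
    y = forward_to(x, dtype=torch.float16, memory_format=torch.channels_last)
    print(y.dtype, y.stride())  # torch.float16 (60, 1, 12, 4)
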
18 changes: 18 additions & 0 deletions thunder/tests/test_core.py
@@ -1953,6 +1953,24 @@ def f(a):
     assert "thunder.computation" in excinfo.traceback[-1].path
 
 
+@instantiate(
+    dtypes=NOTHING,
+    executors=(TorchExecutor,),
+)
+def test_torch_tensor_to_memory_format(executor: TestExecutor, device: str, _):
+    inp = torch.randn(2, 4, 5, 3, device=device, dtype=torch.float32)
+
+    def torch_to(a, memory_format):
+        return a.to(memory_format=memory_format)
+
+    cfn = executor.make_callable(torch_to, disable_preprocessing=False)
+
+    for m_format in [torch.contiguous_format, torch.channels_last, torch.preserve_format]:
+        thunder_result = cfn(inp, m_format)
+        torch_result = torch_to(inp, m_format)
+        assert_close(torch_result, thunder_result, check_stride=True)
+
+
 # TODO See issue "Add contiguous and clang.stride_order OpInfos that check stride
 # consistency with PyTorch"
 @instantiate(
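
For reference, an eager-PyTorch sketch (not part of the commit) of the strides the test expects for each memory_format on the (2, 4, 5, 3) input:

    import torch

    inp = torch.randn(2, 4, 5, 3)
    for fmt in (torch.contiguous_format, torch.channels_last, torch.preserve_format):
        out = inp.to(memory_format=fmt)
        print(fmt, out.stride())
    # torch.contiguous_format -> (60, 15, 3, 1)
    # torch.channels_last     -> (60, 1, 12, 4)
    # torch.preserve_format   -> (60, 15, 3, 1)  (the input is already contiguous)
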
16 changes: 15 additions & 1 deletion thunder/torch/__init__.py
@@ -236,7 +236,7 @@ def _parse_to_device_and_dtype(
     return device, dtype
 
 
-# TODO Model non_blocking, copy, and memory_format (as kwargs)
+# TODO Model non_blocking (as kwargs)
 @torchsymbol(torch.Tensor.to, is_method=True)
 def to(
     a: TensorLike,
@@ -247,6 +247,7 @@ def to(
     device: None | DeviceLike = None,
     dtype: None | dtypeLike = None,
     copy: bool = False,
+    memory_format: None | torch.memory_format = None,
 ) -> TensorLike:
     device, dtype = _parse_to_device_and_dtype(
         tensor_dtype_or_device, optional_positional_dtype, device=device, dtype=dtype
@@ -259,6 +260,12 @@ def to(
         if dtype is not None:
             dtype = to_dtype(dtype)
             a = prims.convert_element_type(a, dtype)
+        if memory_format is not None:
+            # NOTE not sure if we need to handle torch.preserve_format explicitly
+            if memory_format == torch.channels_last:
+                a = prims.stride_order(a, (3, 0, 2, 1))
+            elif memory_format == torch.channels_last_3d:
+                a = prims.stride_order(a, (4, 0, 3, 2, 1))
         return a
 
     # NOTE copy == False
@@ -270,6 +277,13 @@ def to(
     if dtype is not None:
         return clang.maybe_convert_to_dtype(a, dtype)
 
+    if memory_format is not None:
+        # NOTE not sure if we need to handle torch.preserve_format explicitly
+        if memory_format == torch.channels_last:
+            a = prims.stride_order(a, (3, 0, 2, 1))
+        elif memory_format == torch.channels_last_3d:
+            a = prims.stride_order(a, (4, 0, 3, 2, 1))
+
     return a
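
The permutation passed to prims.stride_order encodes how large each dimension's stride should be, with 0 marking the fastest-varying (stride-1) dimension. A small eager-PyTorch check (illustrative, assuming that convention) that channels_last on an NCHW tensor corresponds to (3, 0, 2, 1):

    import torch

    x = torch.randn(2, 4, 5, 3).to(memory_format=torch.channels_last)
    strides = x.stride()                                 # (60, 1, 12, 4)
    order = sorted(range(4), key=lambda d: strides[d])   # dims from smallest to largest stride
    rank = tuple(order.index(d) for d in range(4))       # per-dim position in that ordering
    print(strides, rank)                                 # (60, 1, 12, 4) (3, 0, 2, 1)

channels_last_3d follows the same pattern one dimension up, giving (4, 0, 3, 2, 1) for NCDHW.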
