Support torch.device (#575)
kshitij12345 authored Jun 11, 2024
1 parent 067f15a commit 56e267a
Showing 2 changed files with 97 additions and 0 deletions.
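
As a quick illustration of what this commit enables: a jitted function can now construct a `torch.device` (from a string, a string plus index, a `torch.device`, or a proxy's device) and use it to allocate tensors, matching eager behavior. A minimal sketch, assuming `thunder` and `torch` are importable (CPU-only, so no GPU is needed):

import thunder
import torch

def make_ones(dev):
    # Inside a jitted function, torch.device(...) is now translated to
    # thunder's device handling instead of failing during tracing.
    return torch.ones(3, 3, device=torch.device(dev, 0))

jitted = thunder.jit(make_ones)
assert jitted("cpu").device == make_ones("cpu").device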
70 changes: 70 additions & 0 deletions thunder/tests/test_core.py
@@ -2656,3 +2656,73 @@ def foo_torch(x):
    assert str(trace_thunder).count("return thunder.torch.softmax(x, 0)") == 1
    # torch.softmax should be traced as usual
    assert str(trace_torch).count(f"return torch.softmax(x, 0)") == 1


def test_torch_device():
    # Test `thunder.jit` support for `torch.device()`.
    if not torch.cuda.is_available():
        # thunder.core.devices.Device __init__ calls `torch.cuda.device_count()` when DeviceType is CUDA.
        # https://github.com/Lightning-AI/lightning-thunder/blob/067f15aae47ad71229732ca6c35a5d190135e48c/thunder/core/devices.py#L96-L101
        pytest.skip("CUDA not available")

    # Check the output against the PyTorch eager output.
    def _test(foo, inputs):
        for input in inputs:
            actual = thunder.jit(foo)(input)
            expected = foo(input)
            assert actual.device == expected.device

    # Test with str input
    device_strs = ("cpu", "cuda", "cuda:0", "meta")

    def foo1(dev):
        # If we return the device here, the thunder.jit version will return a `thunder.device`
        # while eager will return a `torch.device`
        # https://github.com/Lightning-AI/lightning-thunder/issues/573
        return torch.ones(3, 3, device=torch.device(dev))

    _test(foo1, device_strs)

    # Test with str and index input
    device_strs_and_idxs = (("cpu", 0), ("cpu", 1), ("cuda", 0), ("meta", 0), ("meta", 1))

    def foo2(dev_and_idx):
        return torch.ones(3, 3, device=torch.device(*dev_and_idx))

    _test(foo2, device_strs_and_idxs)

    # Test with `torch.device` as input
    torch_devices = (torch.device("cpu"), torch.device("cuda"), torch.device("meta"))

    def foo3(device):
        return torch.ones(3, 3, device=torch.device(device))

    _test(foo3, torch_devices)

    # Test with `thunder.device` as input
    tensor_proxy_devices = (
        torch.ones(1, device=torch.device("cpu")),
        torch.ones(1, device=torch.device("cuda")),
        torch.ones(1, device=torch.device("meta")),
    )

    # Here `torch.device()` will see a `thunder.device` as input.
    def foo4(ref_t):
        return torch.ones(3, 3, device=torch.device(ref_t.device))

    _test(foo4, tensor_proxy_devices)

    # Error inputs
    error_inputs = (
        ((torch.device("cpu"), 0), RuntimeError),
        (("cuda:0", 0), RuntimeError),
        (("cpu:",), ValueError),
        (("cuda:",), ValueError),
    )

    def foo_error(args):
        return torch.device(*args)

    for inp, err in error_inputs:
        with pytest.raises(err):
            thunder.jit(foo_error)(inp)
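
To run just this new test from the repository root, something like the following should work (note the test skips itself when CUDA is unavailable, since it exercises "cuda" device strings):

pytest thunder/tests/test_core.py -k test_torch_device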
27 changes: 27 additions & 0 deletions thunder/torch/__init__.py
@@ -4371,6 +4371,33 @@ def softmax(a: TensorLike, dim: int, dtype: None | dtypeLike = None, _stacklevel
    return _softmax(a, dim=dim, dtype=dtype)


def torch_device(device_or_str: DeviceLike, /, index: int | None = None) -> devices.Device:
    if isinstance(device_or_str, (devices.Device, torch.device)):
        # PyTorch behavior:
        # >>> torch.device(torch.device("cuda"), 0)
        # TypeError: device(): argument 'type' (position 1) must be str, not torch.device
        utils.check(index is None, lambda: f"device(): `index` is only allowed when `device` is a `str`.")
        return to_device(device_or_str)

    # NOTE: device_or_str is `str`
    if index is not None:
        # PyTorch behavior:
        # >>> torch.device("cuda:0", 0)
        # RuntimeError: type (string) must not include an index because index was passed explicitly: cuda:0
        has_device_idx = len(device_or_str.split(":")) > 1
        utils.check(
            not has_device_idx,
            lambda: f"device string must not include an index because index was passed explicitly: {device_or_str}",
        )

    return devices.Device(device_or_str, index)


# We don't use @torchsymbol as we don't want `torch.device()` to appear in trace as a symbol.
# Because of this, we need to manually register the implementation.
_torch_to_thunder_function_map[torch.device] = torch_device


#
# Distributed operations
#
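
For reference, the explicit checks in `torch_device` mirror PyTorch's eager semantics quoted in the inline comments above. A small illustrative snippet (not part of the diff; constructing a `torch.device` does not touch the GPU, so it runs without CUDA):

import torch

torch.device("cuda", 0)  # OK: a type string plus an explicit index

try:
    torch.device("cuda:0", 0)
except RuntimeError as e:
    print(e)  # type (string) must not include an index because index was passed explicitly

try:
    torch.device(torch.device("cuda"), 0)
except TypeError as e:
    print(e)  # device(): argument 'type' (position 1) must be str, not torch.device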