diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index 06d508ccae..51f38bb381 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2656,3 +2656,73 @@ def foo_torch(x):
     assert str(trace_thunder).count("return thunder.torch.softmax(x, 0)") == 1
     # torch.softmax should be traced as usual
     assert str(trace_torch).count(f"return torch.softmax(x, 0)") == 1
+
+
+def test_torch_device():
+    # Test `thunder.jit` support for `torch.device()`.
+    if not torch.cuda.is_available():
+        # thunder.core.devices.Device.__init__ calls `torch.cuda.device_count()` when the DeviceType is CUDA.
+        # https://github.com/Lightning-AI/lightning-thunder/blob/067f15aae47ad71229732ca6c35a5d190135e48c/thunder/core/devices.py#L96-L101
+        pytest.skip("CUDA not available")
+
+    # Check the device of the jitted output against that of the PyTorch eager output.
+    def _test(foo, inputs):
+        for input in inputs:
+            actual = thunder.jit(foo)(input)
+            expected = foo(input)
+            assert actual.device == expected.device
+
+    # Test with `str` input
+    device_strs = ("cpu", "cuda", "cuda:0", "meta")
+
+    def foo1(dev):
+        # If we returned the device itself here, the `thunder.jit` version would return a
+        # `thunder.device` while eager would return a `torch.device`.
+        # https://github.com/Lightning-AI/lightning-thunder/issues/573
+        return torch.ones(3, 3, device=torch.device(dev))
+
+    _test(foo1, device_strs)
+
+    # Test with `str` and index input
+    device_strs_and_idxs = (("cpu", 0), ("cpu", 1), ("cuda", 0), ("meta", 0), ("meta", 1))
+
+    def foo2(dev_and_idx):
+        return torch.ones(3, 3, device=torch.device(*dev_and_idx))
+
+    _test(foo2, device_strs_and_idxs)
+
+    # Test with `torch.device` as input
+    torch_devices = (torch.device("cpu"), torch.device("cuda"), torch.device("meta"))
+
+    def foo3(device):
+        return torch.ones(3, 3, device=torch.device(device))
+
+    _test(foo3, torch_devices)
+
+    # Test with `thunder.device` as input
+    tensor_proxy_devices = (
+        torch.ones(1, device=torch.device("cpu")),
+        torch.ones(1, device=torch.device("cuda")),
+        torch.ones(1, device=torch.device("meta")),
+    )
+
+    # Here `torch.device()` will see a `thunder.device` as input.
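+    # (Inside the jitted function, `ref_t` is a TensorProxy whose `.device` is a
+    # `thunder.core.devices.Device`, so this exercises the `devices.Device` branch
+    # of `torch_device` in thunder/torch/__init__.py below.)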
+    def foo4(ref_t):
+        return torch.ones(3, 3, device=torch.device(ref_t.device))
+
+    _test(foo4, tensor_proxy_devices)
+
+    # Error inputs
+    error_inputs = (
+        ((torch.device("cpu"), 0), RuntimeError),
+        (("cuda:0", 0), RuntimeError),
+        (("cpu:",), ValueError),
+        (("cuda:",), ValueError),
+    )
+
+    def foo_error(args):
+        return torch.device(*args)
+
+    for inp, err in error_inputs:
+        with pytest.raises(err):
+            thunder.jit(foo_error)(inp)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 21e8df6247..ffc5bece5a 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -4371,6 +4371,33 @@ def softmax(a: TensorLike, dim: int, dtype: None | dtypeLike = None, _stacklevel
     return _softmax(a, dim=dim, dtype=dtype)
 
 
+def torch_device(device_or_str: DeviceLike, /, index: int | None = None) -> devices.Device:
+    if isinstance(device_or_str, (devices.Device, torch.device)):
+        # Matches the PyTorch behavior:
+        # >>> torch.device(torch.device("cuda"), 0)
+        # TypeError: device(): argument 'type' (position 1) must be str, not torch.device
+        utils.check(index is None, lambda: "device(): `index` is only allowed when `device` is a `str`.")
+        return to_device(device_or_str)
+
+    # NOTE: At this point `device_or_str` must be a `str`.
+    if index is not None:
+        # Matches the PyTorch behavior:
+        # >>> torch.device("cuda:0", 0)
+        # RuntimeError: type (string) must not include an index because index was passed explicitly: cuda:0
+        has_device_idx = len(device_or_str.split(":")) > 1
+        utils.check(
+            not has_device_idx,
+            lambda: f"device string must not include an index because index was passed explicitly: {device_or_str}",
+        )
+
+    return devices.Device(device_or_str, index)
+
+
+# We don't use @torchsymbol because we don't want `torch.device()` to appear in the trace as a symbol.
+# Because of this, we register the implementation in `_torch_to_thunder_function_map` manually.
+_torch_to_thunder_function_map[torch.device] = torch_device
+
+
 #
 # Distributed operations
 #
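
A minimal usage sketch of the behavior this patch enables (assumes the patch is
applied and a CUDA device is available; `foo` and `jfoo` are illustrative names,
not part of the patch):

    import torch
    import thunder

    def foo(x):
        # Inside the jitted function, the torch.device(...) call is intercepted
        # via _torch_to_thunder_function_map and resolved by torch_device above,
        # while execution still produces ordinary torch tensors.
        return torch.ones(2, 2, device=torch.device("cuda", 0)) + x

    jfoo = thunder.jit(foo)
    out = jfoo(torch.zeros(2, 2, device="cuda:0"))
    assert out.device == torch.device("cuda:0")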