diff --git a/thunder/__init__.py b/thunder/__init__.py
index 79866004e6..ca70a86407 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -374,6 +374,9 @@ def get_computation_and_inputs(*args, **kwargs):
         # default dtype (for factory functions)
         cache_info["default_dtype"] = pytorch.get_default_dtype()
 
+        # default device (for factory functions)
+        cache_info["default_device"] = pytorch.get_default_device()
+
         # autocast related operations
         is_autocast_enabled = False
         if pytorch.is_autocast_enabled() or pytorch.is_autocast_cpu_enabled():
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 8fc4ece414..723957771f 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -1629,7 +1629,7 @@ def from_provenance(provenance, *, new_output=False):
             clang.check_string_value(p, v)
         elif isinstance(v, (int, bool, float)):
             clang.check_number_type_and_value(p, v)
-        elif isinstance(v, torch.dtype):
+        elif isinstance(v, (torch.dtype, torch.device)):
             clang.check_literal_like(p, v)
         else:
             raise NotImplementedError(f"cache info of type {type(v).__name__}")
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index cf81cf93c2..2838031876 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2996,6 +2996,67 @@ def fn(x):
         torch.set_default_dtype(default_dtype)
 
 
+@requiresCUDA
+def test_factory_functions_default_device():
+
+    def fn(x):
+        o = torch.ones(x.shape)
+        return o.device
+
+    x = torch.randn(3, 3)
+    jfn = thunder.jit(fn)
+    actual_device = jfn(x)
+
+    assert fn(x) == jfn(x)
+    assert actual_device == torch.device("cpu")
+
+    # Check with a different default device.
+    org_device = torch.get_default_device()
+    torch.set_default_device("cuda")
+    try:
+        actual_device = jfn(x)
+        assert actual_device == fn(x)
+    finally:
+        torch.set_default_device(org_device)
+
+    assert thunder.cache_misses(jfn) == 2
+
+
+@requiresCUDA
+def test_change_default_device_in_jitted_fn():
+    default_device = torch.get_default_device()
+    try:
+
+        def fn(x):
+            torch.set_default_device("cuda")
+            o = torch.ones(x.shape)
+            return o.device
+
+        jfn = thunder.jit(fn)
+        with pytest.raises(RuntimeError, match="Default device is changed during the execution of jitted function"):
+            jfn(torch.randn(3, 3))
+    finally:
+        torch.set_default_device(default_device)
+
+
+@requiresCUDA
+@pytest.mark.xfail(
+    reason="When using device as context in PyTorch, it doesn't reflect in torch.get_default_device - see https://github.com/pytorch/pytorch/issues/131328",
+    strict=True,
+)
+def test_change_default_device_with_ctx():
+    def fn(x):
+        o = torch.ones(x.shape)
+        return o.device
+
+    x = torch.randn(3)
+
+    with torch.device("cuda"):
+        jfn = thunder.jit(fn)
+        actual_device = jfn(x)
+        assert actual_device == fn(x)
+
+
 def test_arange_default_dtype():
     # If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
     # Otherwise, the dtype is inferred to be torch.int64.
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 2b09d4c528..cddfffdb16 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -85,7 +85,7 @@
 _inplace_to_out_of_place: dict[Callable, tuple[Callable, int]] = {}
 
 
-# Helpers for factory functions to get default dtypes.
+# Helpers for factory functions to get default dtype and device.
 def get_default_dtype():
     # `thunder.jit` will create cache info and stash the default dtype
     # observed at the beginning of jitting.
@@ -103,6 +103,23 @@ def maybe_get_default_dtype(dtype):
     return dtype or get_default_dtype()
 
 
+def get_default_device():
+    # `thunder.jit` will create cache info and stash the default device
+    # observed at the beginning of jitting.
+    cache_info = thunder._get_cache_info()
+
+    # Currently, changing device during the jitted function is unsupported.
+    utils.check(
+        cache_info["default_device"] == torch.get_default_device(),
+        lambda: "Default device is changed during the execution of jitted function. This is currently unsupported.",
+    )
+    return torch.get_default_device()
+
+
+def maybe_get_default_device(device):
+    return device or get_default_device()
+
+
 # A wrapper that executes the operations within the torch language context
 # NOTE because this module defines the torch language context, a reference to itself
 # is acquired by inspecting the __module__ attribute of the is_available function defined
@@ -520,9 +537,7 @@ def arange(
     device: None | DeviceLike = None,
     dtype: None | dtypeLike = None,
 ) -> TensorLike:
-    if device is None:
-        device = "cpu"
-
+    device = maybe_get_default_device(device)
     device = to_device(device)
     # From torch docs - https://pytorch.org/docs/stable/generated/torch.arange.html
     # If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
@@ -545,12 +560,8 @@ def arange(
 
 
 def full(
     shape: Sequence[int], fill_value: NumberLike, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None
 ) -> TensorLike:
-    if device is None:
-        device = "cpu"
-
-    device = to_device(device)
+    device = to_device(maybe_get_default_device(device))
     dtype = to_dtype(maybe_get_default_dtype(dtype))
-
     return clang.full(shape, fill_value, device=device, dtype=dtype)
@@ -616,7 +627,7 @@ def uniform(
     device: DeviceLike,
     dtype: dtypeLike,
 ) -> TensorLike:
-    device = to_device(device)
+    device = to_device(maybe_get_default_device(device))
     dtype = to_dtype(maybe_get_default_dtype(dtype))
 
     return clang.uniform(shape, minval, maxval, device=device, dtype=dtype)
@@ -673,7 +684,7 @@ def uniform_philox(
     seed: int | TensorProxy,
     offset: int | TensorProxy,
 ) -> TensorLike:
-    device = to_device(device)
+    device = to_device(maybe_get_default_device(device))
     dtype = to_dtype(maybe_get_default_dtype(dtype))
 
     return clang.uniform_philox(shape, minval, maxval, device=device, dtype=dtype, seed=seed, offset=offset)
@@ -698,10 +709,8 @@ def randn(
     # NOTE: Currently, we don't model randomness
     utils.check(generator is None, lambda: "generator is not None which is currently unsupported", NotImplementedError)
     utils.check(out is None, lambda: "out is not None which is currently unsupported", NotImplementedError)
-    if device is None:
-        device = "cpu"
 
-    device = to_device(device)
+    device = to_device(maybe_get_default_device(device))
    dtype = to_dtype(maybe_get_default_dtype(dtype))
     shape = utils.extract_shape_from_varargs(shape)
     return prims.randn(shape, device=device, dtype=dtype)
@@ -788,15 +797,8 @@ def empty(
         NotImplementedError,
     )
 
-    # For now we default to `float32`,
-    # however, we should add a default dtype or rely on `torch.get_default_dtype`.
     dtype = to_dtype(maybe_get_default_dtype(dtype))
-
-    # For now we default to "cpu",
-    # however, we should add a default device or rely on `torch.get_default_device`.
-    if device is None:
-        device = "cpu"
-    device = to_device(device)
+    device = to_device(maybe_get_default_device(device))
     return clang.empty(size, device=device, dtype=dtype)
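
Usage sketch (not part of the patch): a rough illustration of the behavior exercised by test_factory_functions_default_device above, assuming a CUDA-capable environment. Factory calls inside a jitted function that pass no explicit device follow torch.get_default_device(), and changing the default device between calls triggers a retrace, visible as a cache miss.

import torch
import thunder


def fn(x):
    # No explicit device, so the factory call should follow the default device.
    return torch.ones(x.shape).device


jfn = thunder.jit(fn)
x = torch.randn(3, 3)

assert jfn(x) == torch.device("cpu")  # initial default device

org_device = torch.get_default_device()
torch.set_default_device("cuda")
try:
    assert jfn(x) == fn(x)  # retraced: the factory op now lands on the CUDA device
finally:
    torch.set_default_device(org_device)

assert thunder.cache_misses(jfn) == 2  # one miss per observed default device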