Skip to content

Commit

Permalink
factory functions - fix handling of default device (#820)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 authored Jul 22, 2024
1 parent f04a88f commit 85df4f6
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 23 deletions.
3 changes: 3 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ def get_computation_and_inputs(*args, **kwargs):
# default dtype (for factory functions)
cache_info["default_dtype"] = pytorch.get_default_dtype()

# default device (for factory functions)
cache_info["default_device"] = pytorch.get_default_device()

# autocast related operations
is_autocast_enabled = False
if pytorch.is_autocast_enabled() or pytorch.is_autocast_cpu_enabled():
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,7 +1629,7 @@ def from_provenance(provenance, *, new_output=False):
clang.check_string_value(p, v)
elif isinstance(v, (int, bool, float)):
clang.check_number_type_and_value(p, v)
elif isinstance(v, torch.dtype):
elif isinstance(v, (torch.dtype, torch.device)):
clang.check_literal_like(p, v)
else:
raise NotImplementedError(f"cache info of type {type(v).__name__}")
Expand Down
61 changes: 61 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2996,6 +2996,67 @@ def fn(x):
torch.set_default_dtype(default_dtype)


@requiresCUDA
def test_factory_functions_default_device():

def fn(x):
o = torch.ones(x.shape)
return o.device

x = torch.randn(3, 3)
jfn = thunder.jit(fn)
actual_device = jfn(x)

assert fn(x) == jfn(x)
assert actual_device == torch.device("cpu")

# Check with a different default device.
org_device = torch.get_default_device()
torch.set_default_device("cuda")
try:
actual_device = jfn(x)
assert actual_device == fn(x)
finally:
torch.set_default_device(org_device)

assert thunder.cache_misses(jfn) == 2


@requiresCUDA
def test_change_default_device_in_jitted_fn():
default_device = torch.get_default_device()
try:

def fn(x):
torch.set_default_device("cuda")
o = torch.ones(x.shape)
return o.device

jfn = thunder.jit(fn)
with pytest.raises(RuntimeError, match="Default device is changed during the execution of jitted function"):
jfn(torch.randn(3, 3))
finally:
torch.set_default_device(default_device)


@requiresCUDA
@pytest.mark.xfail(
reason="When using device as context in PyTorch, it doesn't reflect in torch.get_default_device - see https://github.com/pytorch/pytorch/issues/131328",
strict=True,
)
def test_change_default_device_with_ctx():
def fn(x):
o = torch.ones(x.shape)
return o.device

x = torch.randn(3)

with torch.device("cuda"):
jfn = thunder.jit(fn)
actual_device = jfn(x)
assert actual_device == fn(x)


def test_arange_default_dtype():
# If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
# Otherwise, the dtype is inferred to be torch.int64.
Expand Down
46 changes: 24 additions & 22 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
_inplace_to_out_of_place: dict[Callable, tuple[Callable, int]] = {}


# Helpers for factory functions to get default dtypes.
# Helpers for factory functions to get default dtype and device.
def get_default_dtype():
# `thunder.jit` will create cache info and stash the default dtype
# observed at the beginning of jitting.
Expand All @@ -103,6 +103,23 @@ def maybe_get_default_dtype(dtype):
return dtype or get_default_dtype()


def get_default_device():
# `thunder.jit` will create cache info and stash the default device
# observed at the beginning of jitting.
cache_info = thunder._get_cache_info()

# Currently, changing device during the jitted function is unsupported.
utils.check(
cache_info["default_device"] == torch.get_default_device(),
lambda: "Default device is changed during the execution of jitted function. This is currently unsupported.",
)
return torch.get_default_device()


def maybe_get_default_device(device):
return device or get_default_device()


# A wrapper that executes the operations within the torch language context
# NOTE because this module defines the torch language context, a reference to itself
# is acquired by inspecting the __module__ attribute of the is_available function defined
Expand Down Expand Up @@ -520,9 +537,7 @@ def arange(
device: None | DeviceLike = None,
dtype: None | dtypeLike = None,
) -> TensorLike:
if device is None:
device = "cpu"

device = maybe_get_default_device(device)
device = to_device(device)
# From torch docs - https://pytorch.org/docs/stable/generated/torch.arange.html
# If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
Expand All @@ -545,12 +560,8 @@ def arange(
def full(
shape: Sequence[int], fill_value: NumberLike, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None
) -> TensorLike:
if device is None:
device = "cpu"

device = to_device(device)
device = to_device(maybe_get_default_device(device))
dtype = to_dtype(maybe_get_default_dtype(dtype))

return clang.full(shape, fill_value, device=device, dtype=dtype)


Expand Down Expand Up @@ -616,7 +627,7 @@ def uniform(
device: DeviceLike,
dtype: dtypeLike,
) -> TensorLike:
device = to_device(device)
device = to_device(maybe_get_default_device(device))
dtype = to_dtype(maybe_get_default_dtype(dtype))

return clang.uniform(shape, minval, maxval, device=device, dtype=dtype)
Expand Down Expand Up @@ -673,7 +684,7 @@ def uniform_philox(
seed: int | TensorProxy,
offset: int | TensorProxy,
) -> TensorLike:
device = to_device(device)
device = to_device(maybe_get_default_device(device))
dtype = to_dtype(maybe_get_default_dtype(dtype))

return clang.uniform_philox(shape, minval, maxval, device=device, dtype=dtype, seed=seed, offset=offset)
Expand All @@ -698,10 +709,8 @@ def randn(
# NOTE: Currently, we don't model randomness
utils.check(generator is None, lambda: "generator is not None which is currently unsupported", NotImplementedError)
utils.check(out is None, lambda: "out is not None which is currently unsupported", NotImplementedError)
if device is None:
device = "cpu"
device = to_device(device)

device = to_device(maybe_get_default_device(device))
dtype = to_dtype(maybe_get_default_dtype(dtype))
shape = utils.extract_shape_from_varargs(shape)
return prims.randn(shape, device=device, dtype=dtype)
Expand Down Expand Up @@ -788,15 +797,8 @@ def empty(
NotImplementedError,
)

# For now we default to `float32`,
# however, we should add a default dtype or rely on `torch.get_default_dtype`.
dtype = to_dtype(maybe_get_default_dtype(dtype))

# For now we default to "cpu",
# however, we should add a default device or rely on `torch.get_default_device`.
if device is None:
device = "cpu"
device = to_device(device)
device = to_device(maybe_get_default_device(device))

return clang.empty(size, device=device, dtype=dtype)

Expand Down

0 comments on commit 85df4f6

Please sign in to comment.