support inferring dtype with torch.get_default_dtype for factory functions (#775)

kshitij12345 authored Jul 16, 2024
1 parent ba89fe7 commit 6703b35
Showing 5 changed files with 113 additions and 14 deletions.
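
In short: the factory functions touched below (arange, full, tensor, uniform, uniform_philox, randn, empty) now fall back to torch.get_default_dtype() when no dtype is given, and the default dtype seen at compile time is recorded in the cache info so that changing it later triggers recompilation instead of reusing a stale trace. A minimal sketch of the resulting behavior, mirroring test_factory_functions_default_dtype below (with float64 swapped in for the test's float16):

import torch
import thunder

def fn(x):
    # no explicit dtype, so the default dtype applies
    return torch.ones(x.shape).dtype

jfn = thunder.jit(fn)
x = torch.randn(3, 3)

assert jfn(x) == thunder.dtypes.float32      # stock default dtype

torch.set_default_dtype(torch.float64)
try:
    assert jfn(x) == thunder.dtypes.float64  # recompiled for the new default
finally:
    torch.set_default_dtype(torch.float32)

assert thunder.cache_misses(jfn) == 2        # one compilation per default dtype
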
3 changes: 3 additions & 0 deletions thunder/__init__.py
@@ -382,6 +382,9 @@ def get_computation_and_inputs(*args, **kwargs):
# this could be replaced by the respective querying in the prologues
cache_info = _get_cache_info()

# default dtype (for factory functions)
cache_info["default_dtype"] = pytorch.get_default_dtype()

# autocast related operations
is_autocast_enabled = False
if pytorch.is_autocast_enabled() or pytorch.is_autocast_cpu_enabled():
2 changes: 2 additions & 0 deletions thunder/core/jit_ext.py
@@ -1614,6 +1614,8 @@ def from_provenance(provenance, *, new_output=False):
clang.check_string_value(p, v)
elif isinstance(v, (int, bool, float)):
clang.check_number_type_and_value(p, v)
elif isinstance(v, torch.dtype):
clang.check_literal_like(p, v)
else:
raise NotImplementedError(f"cache info of type {type(v).__name__}")

11 changes: 11 additions & 0 deletions thunder/tests/framework.py
@@ -7,6 +7,7 @@
from typing import List, Optional
from collections.abc import Callable, Sequence, Iterable
import packaging.version
import contextlib

import pytest
import torch
@@ -581,3 +582,13 @@ def __init__(self, comparator):

def __call__(self, test_template):
return test_template


@contextlib.contextmanager
def set_default_dtype_ctx(dtype):
saved_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(saved_dtype)
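
A small usage sketch of the new helper (assuming the ambient default dtype is the stock float32): it restores the previous default even if the body raises, so a failing test does not leak a changed default dtype into later tests.

with set_default_dtype_ctx(torch.float64):
    assert torch.get_default_dtype() == torch.float64
assert torch.get_default_dtype() == torch.float32  # restored on exit
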
64 changes: 62 additions & 2 deletions thunder/tests/test_core.py
@@ -19,7 +19,15 @@
import thunder.torch as ltorch

import thunder.core.codeutils as codeutils
-from thunder.tests.framework import instantiate, NOTHING, TorchExecutor, nvFuserExecutor, requiresCUDA, TestExecutor
+from thunder.tests.framework import (
+instantiate,
+NOTHING,
+TorchExecutor,
+nvFuserExecutor,
+requiresCUDA,
+TestExecutor,
+set_default_dtype_ctx,
+)
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, tracectx
@@ -1302,7 +1310,7 @@ def test_boundsymbol_hash_eq_examples(executor, device, dtype: dtypes.dtype):

# Returns the bound symbols for a function and args.
def compile_bsyms(fn, args):
-fn = executor.make_callable_with_info(fn)
+fn = executor.make_callable(fn)
_ = fn(*args)
traces = thunder.last_traces(fn)
return traces[0].bound_symbols
@@ -2943,3 +2951,55 @@ def fn(x):
(pystr,) = tr.bound_symbols[1].subsymbols[0].python(0)

assert "convert_element_type(x, dtypes.float16)" in pystr


def test_factory_functions_default_dtype():

def fn(x):
o = torch.ones(x.shape)
return o.dtype

x = torch.randn(3, 3)
jfn = thunder.jit(fn)
actual_dtype = jfn(x)

assert actual_dtype == thunder.dtypes.float32

# Check with a different default dtype.
with set_default_dtype_ctx(torch.float16):
actual_dtype = jfn(x)
assert actual_dtype == thunder.dtypes.float16

assert thunder.cache_misses(jfn) == 2


def test_change_default_dtype_in_jitted_fn():
default_dtype = torch.get_default_dtype()
try:

def fn(x):
torch.set_default_dtype(torch.float16)
o = torch.ones(x.shape)
return o.dtype

jfn = thunder.jit(fn)
with pytest.raises(RuntimeError, match="Default dtype is changed during the execution of jitted function"):
jfn(torch.randn(3, 3))
finally:
torch.set_default_dtype(default_dtype)


def test_arange_default_dtype():
# If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
# Otherwise, the dtype is inferred to be torch.int64.
def fn():
return torch.arange(start=1, end=2, step=0.5).dtype

jfn = thunder.jit(fn)
assert jfn() == thunder.dtypes.float32

def fn():
return torch.arange(start=1, end=3, step=1).dtype

jfn = thunder.jit(fn)
assert jfn() == thunder.dtypes.int64
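
For reference, the eager PyTorch behavior these arange assertions mirror (assuming the stock float32 default dtype):

import torch

assert torch.arange(start=1, end=2, step=0.5).dtype == torch.float32  # a float argument -> default dtype
assert torch.arange(start=1, end=3, step=1).dtype == torch.int64      # all-integer arguments -> int64
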
47 changes: 35 additions & 12 deletions thunder/torch/__init__.py
@@ -85,6 +85,24 @@
_inplace_to_out_of_place: dict[Callable, tuple[Callable, int]] = {}


# Helpers for factory functions to get default dtypes.
def get_default_dtype():
# `thunder.jit` will create cache info and stash the default dtype
# observed at the beginning of jitting.
cache_info = thunder._get_cache_info()

# Currently, changing dtype during the jitted function is unsupported.
utils.check(
cache_info["default_dtype"] == torch.get_default_dtype(),
lambda: "Default dtype is changed during the execution of jitted function. This is currently unsupported.",
)
return torch.get_default_dtype()


def maybe_get_default_dtype(dtype):
return dtype or get_default_dtype()


# A wrapper that executes the operations within the torch language context
# NOTE because this module defines the torch language context, a reference to itself
# is acquired by inspecting the __module__ attribute of the is_available function defined
@@ -536,6 +554,15 @@ def arange(
device = "cpu"

device = to_device(device)
# From torch docs - https://pytorch.org/docs/stable/generated/torch.arange.html
# If any of start, end, or stop are floating-point, the dtype is inferred to be the default dtype, see get_default_dtype().
# Otherwise, the dtype is inferred to be torch.int64.
if dtype is None: # infer the dtype
if any(map(lambda x: isinstance(x, float), (start, end, step))):
dtype = maybe_get_default_dtype(dtype)
else:
dtype = torch.int64

dtype = to_dtype(dtype)

if end is None:
@@ -552,7 +579,7 @@ def full(
device = "cpu"

device = to_device(device)
-dtype = to_dtype(dtype)
+dtype = to_dtype(maybe_get_default_dtype(dtype))

return clang.full(shape, fill_value, device=device, dtype=dtype)

@@ -599,6 +626,9 @@ def tensor(
utils.check(not pin_memory, lambda: "pin_memory=True is not supported within thunder.jit", NotImplementedError)

if isinstance(seq_or_number, (Number, NumberProxy)):
# Infer dtype from value (as `full` will use default dtype if dtype=None).
if dtype is None:
dtype = dtypes.numbertype_to_dtype(dtypes.to_dtype(seq_or_number))
return full((), seq_or_number, dtype=dtype, device=device)

return clang.tensor_from_sequence(seq_or_number, dtype=dtype, device=device)
@@ -617,7 +647,7 @@ def uniform(
dtype: dtypeLike,
) -> TensorLike:
device = to_device(device)
-dtype = to_dtype(dtype)
+dtype = to_dtype(maybe_get_default_dtype(dtype))

return clang.uniform(shape, minval, maxval, device=device, dtype=dtype)

@@ -674,7 +704,7 @@ def uniform_philox(
offset: int | TensorProxy,
) -> TensorLike:
device = to_device(device)
-dtype = to_dtype(dtype)
+dtype = to_dtype(maybe_get_default_dtype(dtype))

return clang.uniform_philox(shape, minval, maxval, device=device, dtype=dtype, seed=seed, offset=offset)

@@ -702,12 +732,7 @@ def randn(
device = "cpu"
device = to_device(device)

-# For now we default to `float32`,
-# however, we should add a default dtype or
-# rely on `torch.get_default_dtype`.
-if dtype is None:
-dtype = torch.float
-dtype = to_dtype(dtype)
+dtype = to_dtype(maybe_get_default_dtype(dtype))
shape = utils.extract_shape_from_varargs(shape)
return prims.randn(shape, device=device, dtype=dtype)

@@ -795,9 +820,7 @@ def empty(

# For now we default to `float32`,
# however, we should add a default dtype or rely on `torch.get_default_dtype`.
-if dtype is None:
-dtype = torch.float
-dtype = to_dtype(dtype)
+dtype = to_dtype(maybe_get_default_dtype(dtype))

# For now we default to "cpu",
# however, we should add a default device or rely on `torch.get_default_device`.
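
With randn and empty no longer hard-coding float32 when dtype is None, jitted factory calls should follow the ambient default dtype end to end. A hypothetical sketch (not part of the commit's test suite):

import torch
import thunder

def make():
    return torch.randn(2, 2)

torch.set_default_dtype(torch.float64)
try:
    assert thunder.jit(make)().dtype == torch.float64  # previously float32
finally:
    torch.set_default_dtype(torch.float32)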