convert thunder types to torch types (#792)
k223kim authored Jul 18, 2024
1 parent 08d8347 commit f160111
Showing 3 changed files with 29 additions and 5 deletions.
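In short: with this change, values returned from a thunder.jit-compiled callable are converted from thunder's internal dtype and device wrappers to the corresponding torch types, so eager and jitted results compare equal. A minimal sketch of the observable behavior, adapted from the tests in this commit:

    import torch
    import thunder

    def fn():
        return torch.arange(start=1, end=2, step=0.5).dtype

    jfn = thunder.jit(fn)
    assert fn() == jfn()           # both sides are now torch dtypes
    assert jfn() == torch.float32  # previously this was thunder.dtypes.float32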
3 changes: 3 additions & 0 deletions thunder/__init__.py
@@ -47,6 +47,7 @@
     _create_callable,
     trace,
     transform_for_execution,
+    transform_to_torch_types,
 )
 import thunder.extend as extend
 from thunder.extend import Executor, add_default_executor
@@ -76,6 +77,7 @@
 import torch as pytorch

 import thunder.clang as clang
+from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map

 # Imports executors (to populate default executors and make them accessible)
 import thunder.executors.pythonex
@@ -624,6 +626,7 @@ def get_computation_and_inputs(*args, **kwargs):
                 backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
                 backward_traces.append(backward_trc)

+        computation_trc = transform_to_torch_types(computation_trc)
         comp = computation_trc.python_callable()

         if backward_trc is not None:
18 changes: 18 additions & 0 deletions thunder/common.py
@@ -44,6 +44,7 @@

 import torch as torch
 import numpy as np
+import thunder

 #
 # Datastructures for compiled functions
@@ -858,3 +859,20 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]:
     _fn._using_grad_transform = _using_grad_transform

     return _fn
+
+
+def transform_to_torch_types(trace: TraceCtx):
+    # convert the thunder types to torch types if any
+    def map_to_torch(x: Any) -> Any:
+        if isinstance(x, thunder.dtypes.dtype):
+            return thunder.dtypes.to_torch_dtype(x)
+        elif isinstance(x, thunder.devices.Device):
+            return thunder.devices.to_torch_device(x)
+        return x
+
+    last = trace.bound_symbols[-1]
+    assert last.sym.id == prims.PrimIDs.RETURN
+    new_args = tree_map(map_to_torch, last.args)
+    new_bsym = prims.python_return.bind(*new_args, output=())
+    trace.bound_symbols[-1] = new_bsym
+    return trace
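The new helper rewrites only the trace's final RETURN bound symbol: tree_map walks the (possibly nested) return arguments, and map_to_torch swaps any thunder dtype or Device for its torch counterpart while passing everything else through unchanged. A brief sketch of the two conversions it relies on (assuming thunder.devices.Device is constructible from a device string; the dtype and device names are as used above):

    import torch
    import thunder

    # dtype wrapper -> torch.dtype
    assert thunder.dtypes.to_torch_dtype(thunder.dtypes.float32) == torch.float32
    # Device wrapper -> torch.device (string constructor assumed)
    assert thunder.devices.to_torch_device(thunder.devices.Device("cpu")) == torch.device("cpu")
    # anything else passes through, so e.g. (1, thunder.dtypes.float32) maps to (1, torch.float32)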
13 changes: 8 additions & 5 deletions thunder/tests/test_core.py
@@ -390,7 +390,7 @@ def foo(a, dev=dev):
     x, y = lc_result

     assert x == 1
-    assert y == dev
+    assert y == thunder.devices.to_torch_device(dev)


 @instantiate(dtypes=(thunder.float32,))
@@ -2969,12 +2969,13 @@ def fn(x):
     jfn = thunder.jit(fn)
     actual_dtype = jfn(x)

-    assert actual_dtype == thunder.dtypes.float32
+    assert fn(x) == jfn(x)
+    assert actual_dtype == torch.float32

     # Check with a different default dtype.
     with set_default_dtype_ctx(torch.float16):
         actual_dtype = jfn(x)
-        assert actual_dtype == thunder.dtypes.float16
+        assert actual_dtype == torch.float16

     assert thunder.cache_misses(jfn) == 2
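The float16 branch exercises torch's default dtype: set_default_dtype_ctx is presumably a test helper around torch.set_default_dtype, under which an arange with floating-point arguments produces the current default dtype. A plain-torch sketch of that underlying behavior:

    import torch

    torch.set_default_dtype(torch.float16)
    try:
        # float arguments make arange use the default dtype
        assert torch.arange(start=1, end=2, step=0.5).dtype == torch.float16
    finally:
        torch.set_default_dtype(torch.float32)  # restore the stock default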

@@ -3002,10 +3003,12 @@ def fn():
         return torch.arange(start=1, end=2, step=0.5).dtype

     jfn = thunder.jit(fn)
-    assert jfn() == thunder.dtypes.float32
+    assert fn() == jfn()
+    assert jfn() == torch.float32

     def fn():
         return torch.arange(start=1, end=3, step=1).dtype

     jfn = thunder.jit(fn)
-    assert jfn() == thunder.dtypes.int64
+    assert fn() == jfn()
+    assert jfn() == torch.int64
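For reference, the expected values here follow torch's arange dtype-inference rules, independent of thunder: any floating-point argument yields the default dtype (float32 out of the box), while all-integer arguments yield int64:

    import torch

    assert torch.arange(start=1, end=2, step=0.5).dtype == torch.float32  # float step -> default dtype
    assert torch.arange(start=1, end=3, step=1).dtype == torch.int64      # all-int args -> int64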
