Skip to content

Commit f160111

Browse files
authored
convert thunder types to torch types (#792)
1 parent 08d8347 commit f160111

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

thunder/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
_create_callable,
4848
trace,
4949
transform_for_execution,
50+
transform_to_torch_types,
5051
)
5152
import thunder.extend as extend
5253
from thunder.extend import Executor, add_default_executor
@@ -76,6 +77,7 @@
7677
import torch as pytorch
7778

7879
import thunder.clang as clang
80+
from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map
7981

8082
# Imports executors (to populate default executors and make them accessible)
8183
import thunder.executors.pythonex
@@ -624,6 +626,7 @@ def get_computation_and_inputs(*args, **kwargs):
624626
backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
625627
backward_traces.append(backward_trc)
626628

629+
computation_trc = transform_to_torch_types(computation_trc)
627630
comp = computation_trc.python_callable()
628631

629632
if backward_trc is not None:

thunder/common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import torch as torch
4646
import numpy as np
47+
import thunder
4748

4849
#
4950
# Datastructures for compiled functions
@@ -858,3 +859,20 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]:
858859
_fn._using_grad_transform = _using_grad_transform
859860

860861
return _fn
862+
863+
864+
def transform_to_torch_types(trace: TraceCtx):
865+
# convert the thunder types to torch types if any
866+
def map_to_torch(x: Any) -> Any:
867+
if isinstance(x, thunder.dtypes.dtype):
868+
return thunder.dtypes.to_torch_dtype(x)
869+
elif isinstance(x, thunder.devices.Device):
870+
return thunder.devices.to_torch_device(x)
871+
return x
872+
873+
last = trace.bound_symbols[-1]
874+
assert last.sym.id == prims.PrimIDs.RETURN
875+
new_args = tree_map(map_to_torch, last.args)
876+
new_bsym = prims.python_return.bind(*new_args, output=())
877+
trace.bound_symbols[-1] = new_bsym
878+
return trace

thunder/tests/test_core.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def foo(a, dev=dev):
390390
x, y = lc_result
391391

392392
assert x == 1
393-
assert y == dev
393+
assert y == thunder.devices.to_torch_device(dev)
394394

395395

396396
@instantiate(dtypes=(thunder.float32,))
@@ -2969,12 +2969,13 @@ def fn(x):
29692969
jfn = thunder.jit(fn)
29702970
actual_dtype = jfn(x)
29712971

2972-
assert actual_dtype == thunder.dtypes.float32
2972+
assert fn(x) == jfn(x)
2973+
assert actual_dtype == torch.float32
29732974

29742975
# Check with a different default dtype.
29752976
with set_default_dtype_ctx(torch.float16):
29762977
actual_dtype = jfn(x)
2977-
assert actual_dtype == thunder.dtypes.float16
2978+
assert actual_dtype == torch.float16
29782979

29792980
assert thunder.cache_misses(jfn) == 2
29802981

@@ -3002,10 +3003,12 @@ def fn():
30023003
return torch.arange(start=1, end=2, step=0.5).dtype
30033004

30043005
jfn = thunder.jit(fn)
3005-
assert jfn() == thunder.dtypes.float32
3006+
assert fn() == jfn()
3007+
assert jfn() == torch.float32
30063008

30073009
def fn():
30083010
return torch.arange(start=1, end=3, step=1).dtype
30093011

30103012
jfn = thunder.jit(fn)
3011-
assert jfn() == thunder.dtypes.int64
3013+
assert fn() == jfn()
3014+
assert jfn() == torch.int64

0 commit comments

Comments
 (0)