From f1601112d4d13b44a2b6eedc6f85de0526275c72 Mon Sep 17 00:00:00 2001
From: Kaeun Kim <51257208+k223kim@users.noreply.github.com>
Date: Thu, 18 Jul 2024 17:32:40 +0900
Subject: [PATCH] convert thunder types to torch types (#792)

---
 thunder/__init__.py        |  3 +++
 thunder/common.py          | 18 ++++++++++++++++++
 thunder/tests/test_core.py | 13 ++++++++-----
 3 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index 3ed8e47a1e..b344af60b7 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -47,6 +47,7 @@
     _create_callable,
     trace,
     transform_for_execution,
+    transform_to_torch_types,
 )
 import thunder.extend as extend
 from thunder.extend import Executor, add_default_executor
@@ -76,6 +77,7 @@
 import torch as pytorch
 
 import thunder.clang as clang
+from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map
 
 # Imports executors (to populate default executors and make them accessible)
 import thunder.executors.pythonex
@@ -624,6 +626,7 @@ def get_computation_and_inputs(*args, **kwargs):
                 backward_trc = cudagraphex.fusion_pass(backward_trc, num_static_inputs=len(backward_trc.args[0][0]))
             backward_traces.append(backward_trc)
 
+        computation_trc = transform_to_torch_types(computation_trc)
         comp = computation_trc.python_callable()
 
         if backward_trc is not None:
diff --git a/thunder/common.py b/thunder/common.py
index e6ce1068ca..b6188d0643 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -44,6 +44,7 @@
 import torch as torch
 import numpy as np
 
+import thunder
 
 #
 # Datastructures for compiled functions
@@ -858,3 +859,20 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]:
     _fn._using_grad_transform = _using_grad_transform
 
     return _fn
+
+
+def transform_to_torch_types(trace: TraceCtx):
+    # Convert any Thunder dtypes and devices among the return values to their torch equivalents
+    def map_to_torch(x: Any) -> Any:
+        if isinstance(x, thunder.dtypes.dtype):
+            return thunder.dtypes.to_torch_dtype(x)
+        elif isinstance(x, thunder.devices.Device):
+            return thunder.devices.to_torch_device(x)
+        return x
+
+    last = trace.bound_symbols[-1]
+    assert last.sym.id == prims.PrimIDs.RETURN
+    new_args = tree_map(map_to_torch, last.args)
+    new_bsym = prims.python_return.bind(*new_args, output=())
+    trace.bound_symbols[-1] = new_bsym
+    return trace
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index 23fe696013..cf81cf93c2 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -390,7 +390,7 @@ def foo(a, dev=dev):
 
     x, y = lc_result
     assert x == 1
-    assert y == dev
+    assert y == thunder.devices.to_torch_device(dev)
 
 
 @instantiate(dtypes=(thunder.float32,))
@@ -2969,12 +2969,13 @@ def fn(x):
 
     jfn = thunder.jit(fn)
    actual_dtype = jfn(x)
-    assert actual_dtype == thunder.dtypes.float32
+    assert fn(x) == jfn(x)
+    assert actual_dtype == torch.float32
 
     # Check with a different default dtype.
     with set_default_dtype_ctx(torch.float16):
         actual_dtype = jfn(x)
-        assert actual_dtype == thunder.dtypes.float16
+        assert actual_dtype == torch.float16
 
     assert thunder.cache_misses(jfn) == 2
 
@@ -3002,10 +3003,12 @@ def fn():
         return torch.arange(start=1, end=2, step=0.5).dtype
 
     jfn = thunder.jit(fn)
-    assert jfn() == thunder.dtypes.float32
+    assert fn() == jfn()
+    assert jfn() == torch.float32
 
     def fn():
         return torch.arange(start=1, end=3, step=1).dtype
 
     jfn = thunder.jit(fn)
-    assert jfn() == thunder.dtypes.int64
+    assert fn() == jfn()
+    assert jfn() == torch.int64
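
The user-visible effect of this patch, sketched with the same thunder.jit entry point the updated tests use (fn and x below are illustrative names, not part of the patch): a jitted function that returns a dtype or device now hands back the torch objects rather than Thunder's internal wrappers, so its output compares equal to the eager function's.

import torch
import thunder

def fn(x):
    return x.dtype, x.device

jfn = thunder.jit(fn)
x = torch.randn(3)                 # a CPU float32 tensor
dt, dev = jfn(x)

assert dt == torch.float32         # previously thunder.dtypes.float32
assert dev == torch.device("cpu")  # previously a thunder.devices.Device
assert (dt, dev) == fn(x)          # jitted and eager outputs now agree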
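
The core of transform_to_torch_types is a tree_map over the RETURN prim's arguments: only leaves that are Thunder types get rewritten, and arbitrary nesting is preserved. A self-contained sketch of that pattern, using torch.utils._pytree.tree_map as a stand-in for thunder.core.pytree.tree_map and a toy FakeDtype class in place of thunder.dtypes.dtype (both substitutions are mine, for illustration only):

import torch
from torch.utils._pytree import tree_map  # stand-in for thunder.core.pytree.tree_map

class FakeDtype:
    # Toy stand-in for thunder.dtypes.dtype
    def __init__(self, torch_dtype):
        self.torch_dtype = torch_dtype

def map_to_torch(x):
    # Rewrite only the leaves that are (fake) Thunder types; pass everything else through
    if isinstance(x, FakeDtype):
        return x.torch_dtype
    return x

ret_args = ({"dtype": FakeDtype(torch.float32), "count": 1}, "unchanged")
print(tree_map(map_to_torch, ret_args))
# -> ({'dtype': torch.float32, 'count': 1}, 'unchanged')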
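
If it helps reviewers, the rewrite should also be observable in the trace itself. A hedged sketch, assuming thunder.last_traces exposes the post-transform computation trace as its last element (I have not verified where in the list the transformed trace lands):

import torch
import thunder

def fn(x):
    return x.dtype

jfn = thunder.jit(fn)
jfn(torch.randn(2))

# The final bound symbol of the last trace is the RETURN prim that
# transform_to_torch_types rebinds; its args should now show torch.float32.
trc = thunder.last_traces(jfn)[-1]
print(trc.bound_symbols[-1])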