flatten Function.apply of converter
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Nov 7, 2024
1 parent 526ca4b commit 3fa8e2d
Showing 3 changed files with 52 additions and 21 deletions.
thunder/core/prims.py: 4 additions & 1 deletion
@@ -3550,7 +3550,10 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy:
 view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


-def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
+def shallow_copy_meta(a: TensorProxy | SubclassTensorProxy, /) -> TensorProxy:
+    if isinstance(a, SubclassTensorProxy):
+        # SubclassTensorProxy(like=...) would not copy some attrs such as `_tensors` while replace does.
+        return a.replace()
     return TensorProxy(like=a)


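The comment in the new branch is the key point: constructing a proxy with `like=` only carries generic tensor metadata, so subclass-only state such as `_tensors` is dropped, while `replace()` copies every field. Below is a minimal, self-contained sketch of that distinction; the toy dataclasses are hypothetical stand-ins, not Thunder's real `TensorProxy`/`SubclassTensorProxy`.

```python
# Simplified, hypothetical illustration (not Thunder's actual classes) of why
# `a.replace()` is preferred over `TensorProxy(like=a)` for subclass proxies:
# a `like=`-style constructor only copies generic tensor metadata, so
# subclass-only attributes such as a `_tensors` list would be lost.
from dataclasses import dataclass, field, replace


@dataclass
class ToyTensorProxy:
    shape: tuple
    dtype: str


@dataclass
class ToySubclassTensorProxy(ToyTensorProxy):
    _tensors: list = field(default_factory=list)  # subclass-only state


def shallow_copy_like(a: ToyTensorProxy) -> ToyTensorProxy:
    # Mimics `TensorProxy(like=a)`: only generic metadata survives.
    return type(a)(shape=a.shape, dtype=a.dtype)


def shallow_copy_replace(a: ToySubclassTensorProxy) -> ToySubclassTensorProxy:
    # Mimics `a.replace()`: every field, including `_tensors`, is carried over.
    return replace(a)


p = ToySubclassTensorProxy(shape=(2, 2), dtype="float32", _tensors=["data", "scale"])
assert shallow_copy_like(p)._tensors == []                     # subclass state dropped
assert shallow_copy_replace(p)._tensors == ["data", "scale"]   # subclass state kept
```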
thunder/tests/test_tensor_subclass.py: 17 additions & 15 deletions
@@ -198,21 +198,21 @@ def f(x: ScaleTensorSubclass, y: ScaleTensorSubclass) -> torch.Tensor:

     dtype = torch.float32
     shape = (2, 2)
-    # x = ScaleTensorSubclass(
-    #     make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
-    #     make_tensor((), device=device, dtype=dtype),
-    # )
-    # y = ScaleTensorSubclass(
-    #     make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
-    #     make_tensor((), device=device, dtype=dtype),
-    # )
-    #
-    # expected = f(x, y)
-    # actual = jitted(x, y)
-    # assert type(expected) is type(actual)
-    # torch.testing.assert_close(expected, actual)
-    # if requires_grad:
-    #     actual.mean().backward()
+    x = ScaleTensorSubclass(
+        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
+        make_tensor((), device=device, dtype=dtype),
+    )
+    y = ScaleTensorSubclass(
+        make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad),
+        make_tensor((), device=device, dtype=dtype),
+    )
+
+    expected = f(x, y)
+    actual = jitted(x, y)
+    assert type(expected) is type(actual)
+    torch.testing.assert_close(expected, actual)
+    if requires_grad:
+        actual.mean().backward()

     def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         y = EncapsulateXandScale.apply(data, scale)
@@ -232,3 +232,5 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     actual = jitted(x, data, scale)
     assert type(expected) is type(actual)
     torch.testing.assert_close(expected, actual)
+    if requires_grad:
+        actual.mean().backward()
thunder/transforms/tensor_subclasses.py: 31 additions & 5 deletions
@@ -158,16 +158,38 @@ def proxy_fake_tensor(t: torch.Tensor | FakeTensor) -> ProxyInterface:


 def trace_from_bsym_or_bsyms(bsym_or_bsyms: BoundSymbol | Sequence[BoundSymbol]) -> TraceCtx:
+    from thunder.core.compile_data import get_compile_data
+
+    cd = get_compile_data()
+    ad_hoc_executor = None
+    if cd is not None:
+        from thunder.extend import AdHocExecutor
+
+        executors_list = list(filter(lambda t: isinstance(t, AdHocExecutor), cd.executors_list))
+        if executors_list:
+            ad_hoc_executor = executors_list[0]
+
     bsyms = utils.sequencify(bsym_or_bsyms)
+    trace_args = bsyms[0].flat_proxy_args
+    trace_name = bsyms[0].sym.name
+
+    if ad_hoc_executor is not None and ad_hoc_executor._implmap:
+        tmp_bsyms = []
+        for bsym in bsyms:
+            if ad_hoc_executor.can_execute(bsym) and bsym.subsymbols:
+                tmp_bsyms.extend(bsym.subsymbols)
+            else:
+                tmp_bsyms.append(bsym)
+        bsyms = tmp_bsyms
+
     trace = TraceCtx()
     trace.bound_symbols.extend(bsyms)
-    trace.args = bsyms[0].flat_proxy_args
+    trace.args = trace_args
     with tracectx(trace):
         prims.python_return(bsyms[-1].output)
     with tracectx(trace):
         # note(crcrpar): Give prefix `tmp` to avoid infinite recursion due to the same name
-        trace._siginfo = SigInfo.from_name_and_args(f"tmp_{bsyms[0].sym.name}", trace.args)
+        trace._siginfo = SigInfo.from_name_and_args(f"tmp_{trace_name}", trace.args)
     return trace
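The new code in `trace_from_bsym_or_bsyms` is where the commit title applies: when an ad hoc executor has registered implementations (for example, for a converted `autograd.Function.apply`), a composite bound symbol it can execute is replaced by its subsymbols before the temporary trace is built. Below is a minimal stand-in sketch of that flattening step; the toy types are hypothetical, not Thunder's real `BoundSymbol`/`AdHocExecutor`.

```python
# Minimal stand-in sketch (hypothetical types, not Thunder's real classes) of the
# flattening step above: when the ad hoc executor has registered implementations,
# a composite bound symbol it can execute is expanded one level into its
# subsymbols so the temporary trace only contains flat ops.
from dataclasses import dataclass, field


@dataclass
class ToyBoundSymbol:
    name: str
    subsymbols: list = field(default_factory=list)


class ToyAdHocExecutor:
    def __init__(self, implmap: dict):
        self._implmap = implmap

    def can_execute(self, bsym: ToyBoundSymbol) -> bool:
        return bsym.name in self._implmap


def flatten_executable_composites(bsyms, ad_hoc_executor):
    if ad_hoc_executor is None or not ad_hoc_executor._implmap:
        return bsyms
    flat = []
    for bsym in bsyms:
        # A composite the executor claims (e.g. a converted Function.apply)
        # is replaced by its subsymbols; everything else is kept as-is.
        if ad_hoc_executor.can_execute(bsym) and bsym.subsymbols:
            flat.extend(bsym.subsymbols)
        else:
            flat.append(bsym)
    return flat


ex = ToyAdHocExecutor({"encapsulate_apply": object()})
composite = ToyBoundSymbol("encapsulate_apply", [ToyBoundSymbol("mul"), ToyBoundSymbol("add")])
print([b.name for b in flatten_executable_composites([composite], ex)])  # ['mul', 'add']
```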


@@ -371,7 +393,7 @@ def transform_out(out: torch.Tensor) -> OutputWrapperForFxTracing:
             output = OutputWrapperForFxTracing(out, None)
         return output

-    extrace = transform_for_execution(trace, [get_executor("torch")])
+    extrace = transform_for_execution(trace, [get_executor("torch"), get_executor("ad_hoc")])
     f = extrace.python_callable(include_decorators=False)

     def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing, ...]:
@@ -420,8 +442,12 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
         if not self.subclass_proxy_to_flatten or True:
             return [updated_bsym]

-        is_subclass_ctor = bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR
-        if not is_subclass_ctor and not any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args):
+        is_bsym_of_subclass_ctor = bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR
+        returns_subclass = any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_outs)
+        no_subclass_args = all(not isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args)
+        is_unpack = bsym.sym.id in {prims.PrimIDs.UNPACK_TRIVIAL, prims.PrimIDs.UNPACK_SEQUENCE}
+        is_subclass_ctor = is_bsym_of_subclass_ctor or (no_subclass_args and returns_subclass and not is_unpack)
+        if not is_subclass_ctor and no_subclass_args:
             return [updated_bsym]

         utils.check(
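The widened check above says: besides the explicit `TENSOR_SUBCLASS_CTOR` prim, a bound symbol whose inputs contain no subclass proxies but whose outputs do, and which is not an unpack prim, is now also treated as a subclass constructor, which covers flattened `Function.apply` calls. A simplified, hypothetical restatement with caller-supplied predicates instead of Thunder's real proxy and prim types:

```python
# Hypothetical helper (not Thunder's API) restating the widened constructor check:
# a bound symbol with no subclass-proxy inputs that produces a subclass proxy,
# and is not an unpack prim, counts as a subclass constructor.
def looks_like_subclass_ctor(sym_id, proxy_args, proxy_outs, *, is_subclass, ctor_id, unpack_ids):
    is_bsym_of_subclass_ctor = sym_id == ctor_id
    returns_subclass = any(is_subclass(o) for o in proxy_outs)
    no_subclass_args = all(not is_subclass(a) for a in proxy_args)
    is_unpack = sym_id in unpack_ids
    return is_bsym_of_subclass_ctor or (no_subclass_args and returns_subclass and not is_unpack)


# Example: a flattened subclass-producing call that takes plain tensors
# ("data", "scale") and returns a subclass output now counts as a constructor.
is_subclass = lambda name: name.startswith("subclass_")
print(looks_like_subclass_ctor(
    "scale_tensor_subclass",            # not the ctor prim itself
    ["data", "scale"],                  # plain-tensor inputs
    ["subclass_out"],                   # subclass output
    is_subclass=is_subclass,
    ctor_id="TENSOR_SUBCLASS_CTOR",
    unpack_ids={"UNPACK_TRIVIAL", "UNPACK_SEQUENCE"},
))  # -> True
```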
