Skip to content

Commit

Permalink
handle PrimIDs.RETURN earlier
Browse files: browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 6, 2024
1 parent bb5754a commit 4cd7a3d
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,37 @@ def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing,

def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map)
if updated_bsym.sym.id == prims.PrimIDs.RETURN:
unflatten_fake_tensor_result = tree_unflatten(
self.fx_computation_trace_result,
self.spec_of_fx_computation_trace_result,
)
outputs: dict[str, Any] = updated_bsym.args[0] # {"output": ..., "flat_args": ...}
utils.check_type(outputs, dict)
utils.check(
isinstance(outputs, dict) and len(outputs) == 2 and ("output", "flat_args") == tuple(outputs.keys()),
lambda: fr"{outputs=} does not conform to the format of \{'output': ..., 'flat_args': [...]\}",
)
seq_outs = utils.sequencify(outputs["output"])
seq_fake_ret = utils.sequencify(unflatten_fake_tensor_result["output"])
utils.check(
len(seq_outs) == len(seq_fake_ret),
lambda: f"{outputs['output']=}, {unflatten_fake_tensor_result['output']=}",
)

bsyms: list[BoundSymbol] = []
for proxy_output, fx_output in zip(seq_outs, seq_fake_ret):
if not isinstance(proxy_output, SubclassTensorProxy):
continue
tensor_attrs, metadata = proxy_output.__tensor_flatten__()
tensors = [getattr(proxy_output, name) for name in tensor_attrs]
bsyms.append(
prims.unflatten_tensor_subclass.bind(
type(fx_output), dict(zip(tensor_attrs, tensors)), metadata, output=proxy_output
)
)
return [*bsyms, updated_bsym]

if not any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args):
if bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR:
subclass_proxy = updated_bsym.flat_proxy_outs[0]
Expand Down Expand Up @@ -483,37 +514,6 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
exception_type=NotImplementedError,
)

if updated_bsym.sym.id == prims.PrimIDs.RETURN:
unflatten_fake_tensor_result = tree_unflatten(
self.fx_computation_trace_result,
self.spec_of_fx_computation_trace_result,
)
outputs: dict[str, Any] = updated_bsym.args[0] # {"output": ..., "flat_args": ...}
utils.check_type(outputs, dict)
utils.check(
isinstance(outputs, dict) and len(outputs) == 2 and ("output", "flat_args") == tuple(outputs.keys()),
lambda: fr"{outputs=} does not conform to the format of \{'output': ..., 'flat_args': [...]\}",
)
seq_outs = utils.sequencify(outputs["output"])
seq_fake_ret = utils.sequencify(unflatten_fake_tensor_result["output"])
utils.check(
len(seq_outs) == len(seq_fake_ret),
lambda: f"{outputs['output']=}, {unflatten_fake_tensor_result['output']=}",
)

bsyms: list[BoundSymbol] = []
for proxy_output, fx_output in zip(seq_outs, seq_fake_ret):
if not isinstance(proxy_output, SubclassTensorProxy):
continue
tensor_attrs, metadata = proxy_output.__tensor_flatten__()
tensors = [getattr(proxy_output, name) for name in tensor_attrs]
bsyms.append(
prims.unflatten_tensor_subclass.bind(
type(fx_output), dict(zip(tensor_attrs, tensors)), metadata, output=proxy_output
)
)
return [*bsyms, updated_bsym]

trace = trace_from_bsym_or_bsyms(updated_bsym)
fx, sequencified_cosmeticized_out, orig_output, _ = self.convert_trace_to_fx_graph_and_get_fake_result(trace)
utils.check(
Expand Down

0 comments on commit 4cd7a3d

Please sign in to comment.