From 5e4b00d62020db20f49371e86793aa6acd709de2 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 6 Nov 2024 15:31:44 +0900 Subject: [PATCH] handle `PrimIDs.RETURN` earlier Signed-off-by: Masaki Kozuki --- thunder/transforms/tensor_subclasses.py | 62 ++++++++++++------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py index 677e6f67ec..979bdf31e5 100644 --- a/thunder/transforms/tensor_subclasses.py +++ b/thunder/transforms/tensor_subclasses.py @@ -450,6 +450,37 @@ def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing, def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]: updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map) + if updated_bsym.sym.id == prims.PrimIDs.RETURN: + unflatten_fake_tensor_result = tree_unflatten( + self.fx_computation_trace_result, + self.spec_of_fx_computation_trace_result, + ) + outputs: dict[str, Any] = updated_bsym.args[0] # {"output": ..., "flat_args": ...} + utils.check_type(outputs, dict) + utils.check( + isinstance(outputs, dict) and len(outputs) == 2 and ("output", "flat_args") == tuple(outputs.keys()), + lambda: fr"{outputs=} does not conform to the format of \{'output': ..., 'flat_args': [...]\}", + ) + seq_outs = utils.sequencify(outputs["output"]) + seq_fake_ret = utils.sequencify(unflatten_fake_tensor_result["output"]) + utils.check( + len(seq_outs) == len(seq_fake_ret), + lambda: f"{outputs['output']=}, {unflatten_fake_tensor_result['output']=}", + ) + + bsyms: list[BoundSymbol] = [] + for proxy_output, fx_output in zip(seq_outs, seq_fake_ret): + if not isinstance(proxy_output, SubclassTensorProxy): + continue + tensor_attrs, metadata = proxy_output.__tensor_flatten__() + tensors = [getattr(proxy_output, name) for name in tensor_attrs] + bsyms.append( + prims.unflatten_tensor_subclass.bind( + type(fx_output), dict(zip(tensor_attrs, tensors)), metadata, output=proxy_output + ) + ) + return [*bsyms, updated_bsym] + if not any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args): if bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR: subclass_proxy = updated_bsym.flat_proxy_outs[0] @@ -483,37 +514,6 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]: exception_type=NotImplementedError, ) - if updated_bsym.sym.id == prims.PrimIDs.RETURN: - unflatten_fake_tensor_result = tree_unflatten( - self.fx_computation_trace_result, - self.spec_of_fx_computation_trace_result, - ) - outputs: dict[str, Any] = updated_bsym.args[0] # {"output": ..., "flat_args": ...} - utils.check_type(outputs, dict) - utils.check( - isinstance(outputs, dict) and len(outputs) == 2 and ("output", "flat_args") == tuple(outputs.keys()), - lambda: fr"{outputs=} does not conform to the format of \{'output': ..., 'flat_args': [...]\}", - ) - seq_outs = utils.sequencify(outputs["output"]) - seq_fake_ret = utils.sequencify(unflatten_fake_tensor_result["output"]) - utils.check( - len(seq_outs) == len(seq_fake_ret), - lambda: f"{outputs['output']=}, {unflatten_fake_tensor_result['output']=}", - ) - - bsyms: list[BoundSymbol] = [] - for proxy_output, fx_output in zip(seq_outs, seq_fake_ret): - if not isinstance(proxy_output, SubclassTensorProxy): - continue - tensor_attrs, metadata = proxy_output.__tensor_flatten__() - tensors = [getattr(proxy_output, name) for name in tensor_attrs] - bsyms.append( - prims.unflatten_tensor_subclass.bind( - type(fx_output), dict(zip(tensor_attrs, tensors)), metadata, output=proxy_output - ) - ) - return [*bsyms, updated_bsym] - trace = trace_from_bsym_or_bsyms(updated_bsym) fx, sequencified_cosmeticized_out, orig_output, _ = self.convert_trace_to_fx_graph_and_get_fake_result(trace) utils.check(