diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py
index d32ba0d4a9..9d3a4adfa8 100644
--- a/thunder/transforms/tensor_subclasses.py
+++ b/thunder/transforms/tensor_subclasses.py
@@ -187,40 +187,19 @@ class DesugarTensorSubclass:
     computation_trace: TraceCtx
     swap_map: dict[Variable, ProxyInterface] = field(init=False, default_factory=dict)
     fake_tensor_mode: FakeTensorMode = field(init=False, default_factory=FakeTensorMode)
-    fx_computation_trace: GraphModule = field(init=False, default=None)
-    computation_trace_output: tuple[OutputWrapperForFxTracing, ...] = field(init=False, default=None)
-    fx_computation_trace_result: tuple[torch.Tensor, ...] = field(init=False, default=None)
-    spec_of_fx_computation_trace_result: PyTreeSpec = field(init=False, default=None)
     flat_trace_args: Sequence[ProxyInterface] = field(init=False, default=None)
     flat_trace_args_spec: Any = field(init=False, default=None)
-    requires_desugarring: bool = field(init=False, default=False)
     subclass_proxy_to_flatten: set[Variable] = field(init=False, default_factory=set)
+    bsym_to_new_outputs: dict[BoundSymbol, list[TensorProxy]] = field(init=False, default_factory=dict)

     def __post_init__(self) -> None:
         self.flat_trace_args, self.flat_trace_args_spec = tree_flatten(
             (self.computation_trace.args, self.computation_trace.kwargs)
         )
         for arg in self.flat_trace_args:
-            # self.maybe_update_subclass_type_dict(arg)
             if isinstance(arg, SubclassTensorProxy):
                 self.subclass_proxy_to_flatten.add(variableify(arg))
-        # TODO(crcrpar): From my perspective, this check is rather for the sake of faster compilation.
-        # There could be a computation graph where none of the inputs are subclass while
-        # that graph call subclass creation inside of it.
-        self.requires_desugarring = any(isinstance(t, SubclassTensorProxy) for t in self.flat_trace_args)
-        if not self.requires_desugarring:
-            return
-
-        (
-            self.fx_computation_trace,
-            self.computation_trace_output,
-            self.fx_computation_trace_result,
-            self.spec_of_fx_computation_trace_result,
-        ) = self.convert_trace_to_fx_graph_and_get_fake_result(
-            self.computation_trace,
-        )
-
     def _get_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
         return p._tensor_attr_names

@@ -230,7 +209,7 @@ def _get_non_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
     def translate_fx_graph_into_bsym(
         self,
         bsym: BoundSymbol,
-        fx: GraphModule,
+        fx_graph: GraphModule,
     ) -> BoundSymbol | tuple[BoundSymbol, ...]:
         import thunder.torch as ltorch

@@ -272,7 +251,7 @@ def translate_fx_graph_into_bsym(
         list_of_placeholder_node: list[Node] = []
         list_of_function_call_node: list[Node] = []
         node_of_output: Node
-        for node in fx.graph.nodes:
+        for node in fx_graph.graph.nodes:
             if node.op == PLACEHOLDER:
                 list_of_placeholder_node.append(node)
             if node.op == CALL_FUNCTION:
@@ -310,6 +289,7 @@ def translate_fx_graph_into_bsym(
         if is_subclass_ctor_bsym := bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR:
             utils.check_type(orig_output, SubclassTensorProxy)
         if isinstance(orig_output, SubclassTensorProxy):
+            # note(crcrpar): args[0] would be a list of tensors, and args[1] could be a list of non-tensors.
             args: list[Node] = node_of_output.args[0]
             new_tensor_proxies = []
             for a in args:
@@ -328,6 +308,15 @@ def translate_fx_graph_into_bsym(
             new_subclass = orig_output.replace()
             for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies):
                 setattr(new_subclass, name, value)
+            bsyms.append(
+                prims.unflatten_tensor_subclass.bind(
+                    new_subclass._subclass_type,
+                    dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)),
+                    dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)),
+                    output=new_subclass,
+                )
+            )
+
             self.swap_map[variableify(orig_output)] = new_subclass
         return bsyms

@@ -423,35 +412,8 @@ def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing,
     def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
         updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map)
         if updated_bsym.sym.id == prims.PrimIDs.RETURN:
-            unflatten_fake_tensor_result = tree_unflatten(
-                self.fx_computation_trace_result,
-                self.spec_of_fx_computation_trace_result,
-            )
-            outputs: dict[str, Any] = updated_bsym.args[0]  # {"output": ..., "flat_args": ...}
-            utils.check_type(outputs, dict)
-            utils.check(
-                isinstance(outputs, dict) and len(outputs) == 2 and ("output", "flat_args") == tuple(outputs.keys()),
-                lambda: fr"{outputs=} does not conform to the format of \{'output': ..., 'flat_args': [...]\}",
-            )
-            seq_outs = utils.sequencify(outputs["output"])
-            seq_fake_ret = utils.sequencify(unflatten_fake_tensor_result["output"])
-            utils.check(
-                len(seq_outs) == len(seq_fake_ret),
-                lambda: f"{outputs['output']=}, {unflatten_fake_tensor_result['output']=}",
-            )
-
-            bsyms: list[BoundSymbol] = []
-            for proxy_output, fx_output in zip(seq_outs, seq_fake_ret):
-                if not isinstance(proxy_output, SubclassTensorProxy):
-                    continue
-                tensor_attrs, metadata = proxy_output.__tensor_flatten__()
-                tensors = [getattr(proxy_output, name) for name in tensor_attrs]
-                bsyms.append(
-                    prims.unflatten_tensor_subclass.bind(
-                        type(fx_output), dict(zip(tensor_attrs, tensors)), metadata, output=proxy_output
-                    )
-                )
-            return [*bsyms, updated_bsym]
+            return [updated_bsym]

         is_subclass_ctor = bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR
         if not is_subclass_ctor and not any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args):
@@ -505,6 +467,8 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
             self.swap_map.update(dict(zip(sequence_out, utils.sequencify(out_proxy))))

         bsym_with_modified_output = updated_bsym.from_bsym_swap_proxies(self.swap_map)
+
+        self.bsym_to_new_outputs[bsym_with_modified_output] = bsym_with_modified_output
         return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx)


@@ -535,14 +499,15 @@ def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx:
         behavior is spelled out.
""" desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=computation_trace) - if not desugar_tensor_subclass.requires_desugarring: - return computation_trace updated_bsyms: list[BoundSymbol] = [] bsym: BoundSymbol for bsym in computation_trace.bound_symbols: maybe_desugared_bsyms = desugar_tensor_subclass(bsym) updated_bsyms.extend(maybe_desugared_bsyms) + if not desugar_tensor_subclass.subclass_proxy_to_flatten: + return computation_trace + computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace) computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms) computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared"))