diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py index 1b59a695ba..677e6f67ec 100644 --- a/thunder/transforms/tensor_subclasses.py +++ b/thunder/transforms/tensor_subclasses.py @@ -208,6 +208,11 @@ def __post_init__(self) -> None: self.flat_trace_args, self.flat_trace_args_spec = tree_flatten( (self.computation_trace.args, self.computation_trace.kwargs) ) + for arg in self.flat_trace_args: + self.maybe_update_subclass_type_dict(arg) + if isinstance(arg, SubclassTensorProxy): + self.subclass_proxy_to_flatten.add(variableify(arg)) + # TODO(crcrpar): From my perspective, this check is rather for the sake of faster compilation. # There could be a computation graph where none of the inputs are subclass while # that graph call subclass creation inside of it. @@ -215,9 +220,6 @@ def __post_init__(self) -> None: if not self.requires_desugarring: return - for arg in self.flat_trace_args: - self.maybe_update_subclass_type_dict(arg) - ( self.fx_computation_trace, self.computation_trace_output, @@ -226,9 +228,6 @@ def __post_init__(self) -> None: ) = self.convert_trace_to_fx_graph_and_get_fake_result( self.computation_trace, ) - self.subclass_proxy_to_flatten: set[Variable] = { - variableify(a) for a in filter(lambda t: isinstance(t, SubclassTensorProxy), self.flat_trace_args) - } def maybe_update_subclass_type_dict(self, proxy_arg: ProxyInterface) -> None: if not isinstance(proxy_arg, SubclassTensorProxy):