Skip to content

Commit

Permalink
simplify subclass_proxy_to_flatten
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 6, 2024
1 parent 6043450 commit bb5754a
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,18 @@ def __post_init__(self) -> None:
self.flat_trace_args, self.flat_trace_args_spec = tree_flatten(
(self.computation_trace.args, self.computation_trace.kwargs)
)
for arg in self.flat_trace_args:
self.maybe_update_subclass_type_dict(arg)
if isinstance(arg, SubclassTensorProxy):
self.subclass_proxy_to_flatten.add(variableify(arg))

# TODO(crcrpar): From my perspective, this check is rather for the sake of faster compilation.
# There could be a computation graph where none of the inputs are subclass while
# that graph call subclass creation inside of it.
self.requires_desugarring = any(isinstance(t, SubclassTensorProxy) for t in self.flat_trace_args)
if not self.requires_desugarring:
return

for arg in self.flat_trace_args:
self.maybe_update_subclass_type_dict(arg)

(
self.fx_computation_trace,
self.computation_trace_output,
Expand All @@ -226,9 +228,6 @@ def __post_init__(self) -> None:
) = self.convert_trace_to_fx_graph_and_get_fake_result(
self.computation_trace,
)
self.subclass_proxy_to_flatten: set[Variable] = {
variableify(a) for a in filter(lambda t: isinstance(t, SubclassTensorProxy), self.flat_trace_args)
}

def maybe_update_subclass_type_dict(self, proxy_arg: ProxyInterface) -> None:
if not isinstance(proxy_arg, SubclassTensorProxy):
Expand Down

0 comments on commit bb5754a

Please sign in to comment.