Skip to content

Commit

Permalink
simplify subclass output handling
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 9, 2024
1 parent 4675f8f commit cf066e9
Showing 1 changed file with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,20 +379,11 @@ def translate_fx_graph_into_bsym(
f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}"
),
)
with tracectx(self.computation_trace):
new_subclass = orig_output.replace()
for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies):
setattr(new_subclass, name, value)
bsyms.append(
prims.unflatten_tensor_subclass.bind(
new_subclass._subclass_type,
dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)),
dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)),
output=new_subclass,
)
)
if [variableify(t) for t in orig_output._tensors] != [variableify(t) for t in new_tensor_proxies]:
orig_output._tensors = new_tensor_proxies
for name, tensor in zip(orig_output._tensor_attr_names, new_tensor_proxies):
setattr(orig_output, name, tensor)

self.swap_map[variableify(orig_output)] = new_subclass
else:
non_none_args = [n for n in node_of_output.args[0] if n is not None]
utils.check(len(non_none_args) == 1, lambda: f"{node_of_output.args = }")
Expand Down

0 comments on commit cf066e9

Please sign in to comment.