From cf066e900d85d1fd2d5c141720bc2d7cfcd164f5 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 27 Nov 2024 16:39:36 +0900 Subject: [PATCH] simplify subclass output handling Signed-off-by: Masaki Kozuki --- thunder/transforms/tensor_subclasses.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py index f1d8e76a83..0538f5bc06 100644 --- a/thunder/transforms/tensor_subclasses.py +++ b/thunder/transforms/tensor_subclasses.py @@ -379,20 +379,11 @@ def translate_fx_graph_into_bsym( f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}" ), ) - with tracectx(self.computation_trace): - new_subclass = orig_output.replace() - for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies): - setattr(new_subclass, name, value) - bsyms.append( - prims.unflatten_tensor_subclass.bind( - new_subclass._subclass_type, - dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)), - dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)), - output=new_subclass, - ) - ) + if [variableify(t) for t in orig_output._tensors] != [variableify(t) for t in new_tensor_proxies]: + orig_output._tensors = new_tensor_proxies + for name, tensor in zip(orig_output._tensor_attr_names, new_tensor_proxies): + setattr(orig_output, name, tensor) - self.swap_map[variableify(orig_output)] = new_subclass else: non_none_args = [n for n in node_of_output.args[0] if n is not None] utils.check(len(non_none_args) == 1, lambda: f"{node_of_output.args = }")