diff --git a/thunder/__init__.py b/thunder/__init__.py index 3ac277ed4f..b169304509 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -635,6 +635,7 @@ def get_computation_and_inputs(*args, **kwargs): _tensor_subclass_transform_applied = True if not _tensor_subclass_transform_applied: computation_trc = flatten_tensor_subclasses(computation_trc) + computation_traces.append(computation_trc) if backward_trc is None: from thunder.executors.passes import transform_for_execution as transform_for_execution_pass