diff --git a/thunder/__init__.py b/thunder/__init__.py
index 79e0c44a59..78e69d33cb 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -659,19 +659,6 @@ def get_computation_and_inputs(*args, **kwargs):
             # by split_forward_backward
             if backward_trc is None:
-                ## EPILOGUE and TRANSFORMS should not mix...
-                # applies transforms
-                cs.last_computation_transformation_start = time.perf_counter_ns()
-                for transform in transforms:
-                    new_computation_trc = transform.transform_trace_additionally(
-                        computation_trc, executors_list=cd.executors_list
-                    )
-                    if new_computation_trc is not computation_trc:
-                        # todo: deprecation
-                        computation_trc = new_computation_trc
-                        computation_traces.append(computation_trc)
-                cs.last_computation_transformation_stop = time.perf_counter_ns()
-
                 from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
                 from thunder.executors.passes import _transform_for_operator_executor_execution
                 from thunder.distributed.utils import maybe_sort_waits
 
@@ -767,7 +754,6 @@ def fn_(*args, **kwargs) -> Any:
         cs.last_trace_host_execution_start = time.perf_counter_ns()
 
         if cache_entry.vanilla_tensor_args:
-
             if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps):
                 alias_tensor_indices = alias_tensor_indices_str
                 alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")}
diff --git a/thunder/common.py b/thunder/common.py
index e0fc05439b..74b0fa0e0a 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -83,8 +83,6 @@ class CompileStats:
         last_prologue_transformation_stop (int):
         last_prologue_execution_start (int):
         last_prologue_execution_stop (int):
-        last_computation_transformation_start (int):
-        last_computation_transformation_stop (int):
         last_computation_execution_start (int):
         last_computation_execution_stop (int):
         cache (dict):
@@ -121,8 +119,6 @@ def __init__(self):
         self.last_prologue_transformation_stop: int = -1
         self.last_prologue_execution_start: int = -1
         self.last_prologue_execution_stop: int = -1
-        self.last_computation_transformation_start: int = -1
-        self.last_computation_transformation_stop: int = -1
         self.last_computation_execution_start: int = -1
         self.last_computation_execution_stop: int = -1
 
@@ -163,11 +159,6 @@ def last_prologue_execution_time(self, /) -> int:
         stop: int = self.last_prologue_execution_stop
         return self._time_template(start, stop, "prologue execution")
 
-    def last_computation_transformation_time(self, /) -> int:
-        start: int = self.last_computation_transformation_start
-        stop: int = self.last_computation_transformation_stop
-        return self._time_template(start, stop, "computation transformation")
-
     def last_computation_execution_time(self, /) -> int:
         start: int = self.last_computation_execution_start
         stop: int = self.last_computation_execution_stop
diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py
index 1750c7cfe7..d322d406d0 100644
--- a/thunder/core/transform_common.py
+++ b/thunder/core/transform_common.py
@@ -363,15 +363,6 @@ def transform_state_dict_for_submodule(
         """
         return state_dict
 
-    def transform_trace_additionally(self, computation_trace: Trace, **kwargs):
-        """
-        transform_trace_additionally enables transforming the computation trace before optimization pass.
-        Note that this transform is only applicable if autograd is disabled.
-
-        Please don't use this method in new implementations, we are working on removing it. Use transform_traces_pre_prologue instead.
-        """
-        return computation_trace
-
     def transform_trace_post_optimization(self, computation_trace: Trace, **kwargs):
         """
         transform_trace_post_optimization enables transforming computation trace after optimization pass.
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 4f6c07f687..18cf1adcbc 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1429,7 +1429,6 @@ def grad(
     cfn,
 ) -> Callable:
     def grad(func):
-
         @wraps(func)
         def grad_func(*args, **kwargs):
             _, grads = value_and_grad(func)(*args, **kwargs)
@@ -1440,16 +1439,25 @@ def grad_func(*args, **kwargs):
             return grad_func
 
     class _GradTransform(Transform):
-        def transform_trace_additionally(self, trc: Trace, *, executors_list: Sequence[Any]) -> Trace:
+        def transform_traces_pre_prologue(
+            self,
+            prologue_trc: Trace,
+            computation_trc: Trace,
+            epilogue_trc: Trace | None,
+            *,
+            executors_list: Sequence[Any],
+        ) -> Trace:
             # Using trc.python_callable() makes it impossible to retrace the
             # function because the python_callable uses python_ctx which replaces
             # symbol occurrences with its symbol._call_ctx function
-            @wraps(trc.python_callable())
+            computation_trc = dce(computation_trc)
+
+            @wraps(computation_trc.python_callable())
             def python_callable(*args, **kwargs):
-                return eval_trace(trc, *args, **kwargs)
+                return eval_trace(computation_trc, *args, **kwargs)
 
-            gradtrc = construct_trace()(grad(python_callable), *trc.args, **trc.kwargs)
-            return gradtrc
+            gradtrc = construct_trace()(grad(python_callable), *computation_trc.args, **computation_trc.kwargs)
+            return prologue_trc, gradtrc, epilogue_trc
 
     cfn._using_grad_transform = True
     _grad_transform = _GradTransform()