Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove transform_trace_additionally #914

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,19 +659,6 @@ def get_computation_and_inputs(*args, **kwargs):
# by split_forward_backward

if backward_trc is None:
k223kim marked this conversation as resolved.
Show resolved Hide resolved
## EPILOGUE and TRANSFORMS should not mix...
# applies transforms
cs.last_computation_transformation_start = time.perf_counter_ns()
for transform in transforms:
new_computation_trc = transform.transform_trace_additionally(
computation_trc, executors_list=cd.executors_list
)
if new_computation_trc is not computation_trc:
# todo: deprecation
computation_trc = new_computation_trc
computation_traces.append(computation_trc)
cs.last_computation_transformation_stop = time.perf_counter_ns()

from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
from thunder.executors.passes import _transform_for_operator_executor_execution
from thunder.distributed.utils import maybe_sort_waits
Expand Down Expand Up @@ -767,7 +754,6 @@ def fn_(*args, **kwargs) -> Any:
cs.last_trace_host_execution_start = time.perf_counter_ns()

if cache_entry.vanilla_tensor_args:

if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps):
alias_tensor_indices = alias_tensor_indices_str
alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")}
Expand Down
9 changes: 0 additions & 9 deletions thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ class CompileStats:
last_prologue_transformation_stop (int):
last_prologue_execution_start (int):
last_prologue_execution_stop (int):
last_computation_transformation_start (int):
last_computation_transformation_stop (int):
last_computation_execution_start (int):
last_computation_execution_stop (int):
cache (dict):
Expand Down Expand Up @@ -121,8 +119,6 @@ def __init__(self):
self.last_prologue_transformation_stop: int = -1
self.last_prologue_execution_start: int = -1
self.last_prologue_execution_stop: int = -1
self.last_computation_transformation_start: int = -1
self.last_computation_transformation_stop: int = -1
self.last_computation_execution_start: int = -1
self.last_computation_execution_stop: int = -1

Expand Down Expand Up @@ -163,11 +159,6 @@ def last_prologue_execution_time(self, /) -> int:
stop: int = self.last_prologue_execution_stop
return self._time_template(start, stop, "prologue execution")

def last_computation_transformation_time(self, /) -> int:
    """Return the result of ``_time_template`` for the last computation transformation.

    Forwards the recorded ``last_computation_transformation_start`` and
    ``last_computation_transformation_stop`` values, together with the label
    ``"computation transformation"``, to ``self._time_template``.
    """
    return self._time_template(
        self.last_computation_transformation_start,
        self.last_computation_transformation_stop,
        "computation transformation",
    )

def last_computation_execution_time(self, /) -> int:
start: int = self.last_computation_execution_start
stop: int = self.last_computation_execution_stop
Expand Down
9 changes: 0 additions & 9 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,6 @@ def transform_state_dict_for_submodule(
"""
return state_dict

def transform_trace_additionally(self, computation_trace: Trace, **kwargs):
    """
    Hook for transforming the computation trace before the optimization pass.

    This base implementation returns ``computation_trace`` unchanged and ignores ``kwargs``.
    Note that this transform is only applicable if autograd is disabled.

    Please don't use this method in new implementations; we are working on removing it.
    Use transform_traces_pre_prologue instead.
    """
    return computation_trace

def transform_trace_post_optimization(self, computation_trace: Trace, **kwargs):
"""
transform_trace_post_optimization enables transforming computation trace after optimization pass.
Expand Down
19 changes: 13 additions & 6 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,6 @@ def grad(
cfn,
) -> Callable:
def grad(func):

@wraps(func)
def grad_func(*args, **kwargs):
_, grads = value_and_grad(func)(*args, **kwargs)
Expand All @@ -1440,16 +1439,24 @@ def grad_func(*args, **kwargs):
return grad_func

class _GradTransform(Transform):
    """Transform that swaps the computation trace for a trace of its gradient function."""

    def transform_traces_pre_prologue(
        self,
        prologue_trc: Trace,
        computation_trc: Trace,
        epilogue_trc: Trace | None,
        *,
        executors_list: Sequence[Any],
    ) -> tuple[Trace, Trace, Trace | None]:
        """Build a gradient trace from ``computation_trc``.

        Args:
            prologue_trc: Prologue trace; returned unchanged.
            computation_trc: Computation trace to differentiate.
            epilogue_trc: Epilogue trace (may be ``None``); returned unchanged.
            executors_list: Accepted for interface compatibility; not used here.

        Returns:
            The tuple ``(prologue_trc, gradtrc, epilogue_trc)``, where ``gradtrc`` is the
            trace constructed from ``grad`` applied to a callable that evaluates the
            dead-code-eliminated computation trace.
        """
        # Run DCE in the enclosing scope and bind it to a separate name.
        # Rebinding ``computation_trc`` inside ``python_callable`` would make that
        # name local to the closure and raise UnboundLocalError when it is called.
        dce_trc = dce(computation_trc)

        # Using computation_trc.python_callable() makes it impossible to retrace the
        # function because the python_callable uses python_ctx which replaces
        # symbol occurrences with its symbol._call_ctx function
        @wraps(computation_trc.python_callable())
        def python_callable(*args, **kwargs):
            return eval_trace(dce_trc, *args, **kwargs)

        gradtrc = construct_trace()(grad(python_callable), *computation_trc.args, **computation_trc.kwargs)
        return prologue_trc, gradtrc, epilogue_trc

cfn._using_grad_transform = True
_grad_transform = _GradTransform()
Expand Down
Loading