Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable additional transforms for the PyTorch Autograd path #74

Merged
44 changes: 28 additions & 16 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from thunder.core.compile_data import compile_data_and_stats
from thunder.core.langctxs import LanguageContext
import thunder.core.langctxs as langctxs
from thunder.core.baseutils import run_once
from thunder.core.baseutils import run_once, check
from thunder.core.proxies import (
Proxy,
TensorProxy,
Expand Down Expand Up @@ -563,24 +563,37 @@ def get_computation_and_inputs(*args, **kwargs):
# thunder_backward may recursively call compile and wraps the result in a
# torch.autograd.Function to support embedding of Thunder-compiled
# functions in torch's Autograd

# Currently split_forward_backward also includes
# transform_for_execution and various sorting of symbols,
# applying transform_for_execution after this would be
# breaking the order of operations
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward
extrace = computation_trc
extraces = []
check(
additional_transforms,
lambda: "Specifying additional_transforms is not supported with PyTorch Autograd integration",
)
else:
cs.last_computation_transformation_start = time.time_ns()

## EPILOGUE and TRANSFORMS should not mix...
# applies transforms
for transform in additional_transforms:
computation_trc = transform(computation_trc, executors_list=cd.executors_list)
computation_traces.append(computation_trc)

with langctxs.langctx(cd.langctx):
extraces = transform_for_execution(
computation_trc,
executors_list=cd.executors_list,
)
extrace = extraces[-1]
cs.last_computation_transformation_stop = time.time_ns()

cs.last_computation_transformation_start = time.time_ns()

## EPILOGUE and TRANSFORMS should not mix...
# applies transforms
for transform in additional_transforms:
computation_trc = transform(computation_trc, executors_list=cd.executors_list)
computation_traces.append(computation_trc)

with langctxs.langctx(cd.langctx):
extraces = transform_for_execution(
computation_trc,
executors_list=cd.executors_list,
)
extrace = extraces[-1]
comp = extrace.python_callable()

if backward_trc is not None:
Expand All @@ -595,7 +608,6 @@ def get_computation_and_inputs(*args, **kwargs):
if cd.cache_option is not CACHE_OPTIONS.NO_CACHING:
cs.interpreter_cache.append(cache_entry)

cs.last_computation_transformation_stop = time.time_ns()
cs.last_traces += extraces
cs.last_prologue_traces = [prologue_trc] + protraces
cs.last_prologue = pro
Expand Down
Loading