Skip to content

Commit

Permalink
recompute intermediates from decomposed symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Dec 17, 2024
1 parent 6d946b3 commit 69389bd
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion thunder/core/trace_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from thunder.core.trace import VariableInterface, from_trace, tracectx
from thunder.core.baseutils import ProxyInterface, TensorProxyInterface
from thunder.core.utils import safe_map_flat, sequencify
from thunder.core.proxies import variableify
from thunder.core.proxies import variableify, ProxyTag
from thunder.core.transform_common import VJPDual


Expand Down Expand Up @@ -183,6 +183,9 @@ def do_swap(v):

for new_bsym in new_bsyms:
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
for o in new_bsym.flat_proxy_outs:
if variableify(o) not in swap_map:
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
new_trace.bound_symbols.append(
new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
source_filename=bsym.source_filename, source_positions=bsym.source_positions
Expand Down

0 comments on commit 69389bd

Please sign in to comment.