Rename rematerialization of saved for backward symbols #1367

Closed
wants to merge 4 commits into from
16 changes: 12 additions & 4 deletions thunder/core/transforms.py
@@ -19,6 +19,7 @@
import thunder.core.utils as utils
from thunder.core import dtypes, prims
from thunder.core.devices import cpu, Device
from thunder.core.trace import VariableInterface
from thunder.core.trace_interpreter import (
    interpret_trace as eval_trace,
    interpret_trace_to_trace,
@@ -3127,7 +3128,6 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
    Returns:
        tuple[Trace, Trace]: A tuple containing the new forward and backward traces.
    """

    start_time_ns = time.perf_counter_ns()

    saved_for_bw = get_saved_for_backward_tensors(fwd_trace)
@@ -3148,6 +3148,9 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr

    producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)

    trace_tok = set_tracectx(bwd_trace)
Collaborator:
Please set and reset traces only with "try: finally:" blocks. If there's any error between the calls, the trace will not be reset.
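
A minimal sketch of that suggestion, reusing the names from the diff above (set_tracectx, reset_tracectx, producers, swap_map); this is the reviewer's proposed shape, not the PR's final code:

trace_tok = set_tracectx(bwd_trace)
try:
    for prod in producers:
        for prod_out in prod.flat_outs:
            if isinstance(prod_out, TensorProxy):
                # replace_name needs an active trace context to create the renamed proxy
                swap_map[variableify(prod_out)] = prod_out.replace_name(f"remat_for_{prod_out.name}")
finally:
    # always runs, so the previous trace context is restored even if renaming raises
    reset_tracectx(trace_tok)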

Collaborator:
Why do you set the input bwd_trace as the active trace? There are no Thunder operation calls between set and reset, and the input trace shouldn't be modified.

Collaborator:
would we not use with tracectx(bwd_trace)?
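
If the context-manager form is meant, a sketch could look like the following, assuming tracectx is the context-manager counterpart of set_tracectx/reset_tracectx in thunder.core.trace:

from thunder.core.trace import tracectx

with tracectx(bwd_trace):
    ...  # rename the proxies here; the previous context is restored on exit, even on error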

Collaborator (author):
For this I've taken inspiration from the code in the torch_autograd executor; in particular, these lines explain why the trace context needs to be set:

# [note: why setting trace ctx?]
# [`TensorProxy.replace_name`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L1221-L1223) calls
# [`tensorproxy`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L1506-L1520)
# which then calls `TensorProxy.__init__`. `TensorProxy.__init__` of course calls
# [` Proxy.__init__`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L81-L86).
# `Proxy`'s dunder init calls [`make_proxy_name`](https://github.com/Lightning-AI/lightning-thunder/blob/561b699/thunder/core/proxies.py#L81-L86)
# which depends on a tracectx.
trace_tok = set_tracectx(bwd_trace)

@IvanYashchuk Would an acceptable workaround be to create a new empty trace and use it as ctx?
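
A sketch of that workaround (hypothetical, not part of this PR), assuming TraceCtx() can be constructed without arguments; the throwaway trace only serves as a name registry, so bwd_trace is never installed as the active trace:

from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx

scratch = TraceCtx()  # empty trace used only so proxy names can be registered
trace_tok = set_tracectx(scratch)
try:
    renamed = prod_out.replace_name(f"remat_for_{prod_out.name}")
finally:
    reset_tracectx(trace_tok)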

Collaborator:
Yes, or allow creating Proxies with any name without an active tracectx. Maybe all that's needed is to return True if trc is None in this function:

def register_proxy_name(name: None | str = None):
    trc = get_tracectx()
    if name is not None and not trc.has_name(name):
        trc.add_name(name)
        return True
    return False
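
A sketch of that change (hypothetical, not merged code): with no active trace context, any name would be accepted without registration instead of failing on trc.has_name.

def register_proxy_name(name: None | str = None):
    trc = get_tracectx()
    if trc is None:
        # no active trace: accept the name without registering it
        return True
    if name is not None and not trc.has_name(name):
        trc.add_name(name)
        return True
    return False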

    swap_map: dict[VariableInterface, TensorProxy] = {}

    required_fw_args = fwd_trace_args & old_saved_for_bwd
    recomputed_tensors_from_producers = set()
    for prod in producers:
@@ -3156,8 +3159,12 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
            if prod_arg in fwd_trace_args:
                required_fw_args.add(prod_arg)
        for prod_out in prod.flat_outs:
            if isinstance(prod_out, TensorProxy):
                swap_map[variableify(prod_out)] = prod_out.replace_name(f"remat_for_{prod_out.name}")
            recomputed_tensors_from_producers.add(variableify(prod_out))

    reset_tracectx(trace_tok)

    required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers
    new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd)

@@ -3189,10 +3196,11 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
            new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward)
            new_bwd_trace.bound_symbols.append(new_unpack)
        elif idx == 6:
            new_bwd_trace.bound_symbols.extend(producers)
            new_bwd_trace.bound_symbols.append(bsym)
            for p in producers:
                new_bwd_trace.bound_symbols.append(p.from_bsym_swap_proxies(swap_map=swap_map))
            new_bwd_trace.bound_symbols.append(bsym.from_bsym_swap_proxies(swap_map=swap_map))
        else:
            new_bwd_trace.bound_symbols.append(bsym)
            new_bwd_trace.bound_symbols.append(bsym.from_bsym_swap_proxies(swap_map=swap_map))

    new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]]

23 changes: 18 additions & 5 deletions thunder/tests/test_transforms.py
@@ -378,6 +378,11 @@ def forward(self, x):
    cd = thunder.compile_data(jmodel)
    cs = thunder.compile_stats(jmodel)

    from copy import copy

    fwd_names = fwd_trace.names
    bwd_names = bwd_trace.names

    # Do not recompute any
    cd.compile_options["recomputation_policy"] = lambda x: set()
    with compile_data_and_stats(cd, cs):
@@ -408,6 +413,11 @@ def forward(self, x):
    old_saved_for_bwd = {variableify(j) for j in saved_for_bw}

    all_rematerializable = old_saved_for_bwd - fwd_trace_args
    all_rematerializable_names = {f"remat_for_{unvariableify(x).name}" for x in all_rematerializable}

    # Reset names to avoid conflicts
    fwd_trace.names = copy(fwd_names)
    bwd_trace.names = copy(bwd_names)

    cd.compile_options["recomputation_policy"] = lambda x: x
    with compile_data_and_stats(cd, cs):
@@ -416,27 +426,30 @@ def forward(self, x):
    # List the outputs after the unpacks
    bwd_bsym_out = set(
        map(
            lambda x: variableify(x.output),
            lambda x: x.output.name,
            filter(lambda x: isinstance(x.output, TensorProxy), new_bwd.bound_symbols[6:]),
        )
    )
    # check that all the fwd are recomputed
    for rematerializable in all_rematerializable:
    for rematerializable in all_rematerializable_names:
        assert rematerializable in bwd_bsym_out

    # Reset names to avoid conflicts
    fwd_trace.names = copy(fwd_names)
    bwd_trace.names = copy(bwd_names)

    # Recompute only one tensor
    cd.compile_options["recomputation_policy"] = lambda x: set(filter(lambda i: unvariableify(i).name == "t7", x))
    t7 = set(filter(lambda x: unvariableify(x).name == "t7", all_rematerializable))
    with compile_data_and_stats(cd, cs):
        _, new_bwd = recompute_saved_for_backward(fwd_trace, bwd_trace)

    bwd_bsym_out = set(
        map(
            lambda x: variableify(x.output),
            lambda x: x.output.name,
            filter(lambda x: isinstance(x.output, TensorProxy), new_bwd.bound_symbols[6:]),
        )
    )
    assert t7 not in bwd_bsym_out, "Unexpected tensor rematerialized in the backward."
    assert "remat_for_t7" in bwd_bsym_out, "Unexpected tensor rematerialized in the backward."


def test_lora_transform_linear():