Rename rematerialization of saved for backward symbols #1367

Closed · wants to merge 4 commits
49 changes: 27 additions & 22 deletions thunder/core/transforms.py
@@ -19,6 +19,7 @@
import thunder.core.utils as utils
from thunder.core import dtypes, prims
from thunder.core.devices import cpu, Device
from thunder.core.trace import VariableInterface
from thunder.core.trace_interpreter import (
interpret_trace as eval_trace,
interpret_trace_to_trace,
@@ -3126,7 +3127,6 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
Returns:
tuple[Trace, Trace]: A tuple containing the new forward and backward traces.
"""

start_time_ns = time.perf_counter_ns()

saved_for_bw = get_saved_for_backward_tensors(fwd_trace)
@@ -3147,33 +3147,37 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr

producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)

required_fw_args = fwd_trace_args & old_saved_for_bwd
recomputed_tensors_from_producers = set()
for prod in producers:
for prod_arg in prod.flat_args:
prod_arg = variableify(prod_arg)
if prod_arg in fwd_trace_args:
required_fw_args.add(prod_arg)
for prod_out in prod.flat_outs:
recomputed_tensors_from_producers.add(variableify(prod_out))

required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers
new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd)

new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1]))
new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)

new_bwd_trace = from_trace(bwd_trace)
# If the name C0 was carried over from a previous trace it must be removed,
# since the proxy needs to be registered under that specific name to follow
# the standard backward trace signature.
new_bwd_trace.names.discard("C0")

swap_map: dict[VariableInterface, TensorProxy] = {}

with tracectx(new_bwd_trace):
required_fw_args = fwd_trace_args & old_saved_for_bwd
recomputed_tensors_from_producers = set()

for prod in producers:
for prod_arg in prod.flat_args:
prod_arg = variableify(prod_arg)
if prod_arg in fwd_trace_args:
required_fw_args.add(prod_arg)
for prod_out in prod.flat_outs:
if isinstance(prod_out, TensorProxy):
swap_map[variableify(prod_out)] = prod_out.replace_name(f"remat_of_{prod_out.name}")
recomputed_tensors_from_producers.add(variableify(prod_out))

required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers
new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd)
unpack_args = (CollectionProxy(new_saved_for_backward, name="C0"), len(new_saved_for_backward))

new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1]))
new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)

# Here we make sure that the signature of the backward trace matches the one we expect.
# This part of the trace unpacks the tuple passed from the forward trace;
# more specifically, C0 unpacks into the saved-for-backward tensors and C1 into the cotangents
@@ -3188,10 +3192,11 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward)
new_bwd_trace.bound_symbols.append(new_unpack)
elif idx == 6:
new_bwd_trace.bound_symbols.extend(producers)
new_bwd_trace.bound_symbols.append(bsym)
for p in producers:
new_bwd_trace.bound_symbols.append(p.from_bsym_swap_proxies(swap_map=swap_map))
new_bwd_trace.bound_symbols.append(bsym.from_bsym_swap_proxies(swap_map=swap_map))
else:
new_bwd_trace.bound_symbols.append(bsym)
new_bwd_trace.bound_symbols.append(bsym.from_bsym_swap_proxies(swap_map=swap_map))

new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]]

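For context, the core of this change is the swap map: every output of a rematerialized producer is re-registered under a name with a remat_of_ prefix, and from_bsym_swap_proxies rewrites each bound symbol so it consumes and produces the renamed proxies. Below is a minimal, runnable sketch of that mechanism; the Var, Proxy, and BoundSym classes are hypothetical stand-ins for thunder's variableify wrapper, TensorProxy, and BoundSymbol, not the real API.

from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
    # Hypothetical stand-in for thunder's variableify() wrapper: a hashable proxy identity.
    name: str

@dataclass
class Proxy:
    # Hypothetical stand-in for TensorProxy; only the name matters for the rename.
    name: str

    def replace_name(self, new_name: str) -> "Proxy":
        return Proxy(new_name)

@dataclass
class BoundSym:
    # Hypothetical stand-in for BoundSymbol: an op with proxy args and outputs.
    op: str
    args: tuple
    outs: tuple

    def from_bsym_swap_proxies(self, swap_map: dict) -> "BoundSym":
        # Swap any arg or output whose identity appears in the map.
        swap = lambda p: swap_map.get(Var(p.name), p)
        return BoundSym(self.op, tuple(map(swap, self.args)), tuple(map(swap, self.outs)))

# Rename every recomputed producer output with a "remat_of_" prefix, as the
# transform does, so it cannot collide with a name already registered in the
# backward trace.
t0, t1 = Proxy("t0"), Proxy("t1")
producers = [BoundSym("exp", (t0,), (t1,)), BoundSym("mul", (t0, t1), (Proxy("t2"),))]
swap_map = {Var(o.name): o.replace_name(f"remat_of_{o.name}") for b in producers for o in b.outs}

remat = [b.from_bsym_swap_proxies(swap_map) for b in producers]
assert remat[1].args[1].name == "remat_of_t1"  # consumers see the renamed proxy
assert remat[1].outs[0].name == "remat_of_t2"

Renaming through a swap map keeps the original forward-trace names free for the saved-for-backward tensors that are still passed in, while the recomputed copies live under distinct names inside the backward trace.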
23 changes: 18 additions & 5 deletions thunder/tests/test_transforms.py
@@ -378,6 +378,11 @@ def forward(self, x):
cd = thunder.compile_data(jmodel)
cs = thunder.compile_stats(jmodel)

from copy import copy

fwd_names = copy(fwd_trace.names)
bwd_names = copy(bwd_trace.names)

# Do not recompute any
cd.compile_options["recomputation_policy"] = lambda x: set()
with compile_data_and_stats(cd, cs):
@@ -408,6 +413,11 @@ def forward(self, x):
old_saved_for_bwd = {variableify(j) for j in saved_for_bw}

all_rematerializable = old_saved_for_bwd - fwd_trace_args
all_rematerializable_names = {f"remat_of_{unvariableify(x).name}" for x in all_rematerializable}

# Reset names to avoid conflicts
fwd_trace.names = copy(fwd_names)
bwd_trace.names = copy(bwd_names)

cd.compile_options["recomputation_policy"] = lambda x: x
with compile_data_and_stats(cd, cs):
@@ -416,27 +426,30 @@ def forward(self, x):
# List the outputs after the unpacks
bwd_bsym_out = set(
map(
lambda x: variableify(x.output),
lambda x: x.output.name,
filter(lambda x: isinstance(x.output, TensorProxy), new_bwd.bound_symbols[6:]),
)
)
# check that all the fwd are recomputed
for rematerializable in all_rematerializable:
for rematerializable in all_rematerializable_names:
assert rematerializable in bwd_bsym_out

# Reset names to avoid conflicts
fwd_trace.names = copy(fwd_names)
bwd_trace.names = copy(bwd_names)

# Recompute only one tensor
cd.compile_options["recomputation_policy"] = lambda x: set(filter(lambda i: unvariableify(i).name == "t7", x))
t7 = set(filter(lambda x: unvariableify(x).name == "t7", all_rematerializable))
with compile_data_and_stats(cd, cs):
_, new_bwd = recompute_saved_for_backward(fwd_trace, bwd_trace)

bwd_bsym_out = set(
map(
lambda x: variableify(x.output),
lambda x: x.output.name,
filter(lambda x: isinstance(x.output, TensorProxy), new_bwd.bound_symbols[6:]),
)
)
assert "t7" not in bwd_bsym_out, "Unexpected tensor rematerialized in the backward."
assert "remat_of_t7" in bwd_bsym_out, "Expected t7 to be rematerialized in the backward."


def test_lora_transform_linear():
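The name snapshots in the test exist because a trace registers each proxy name at most once; running recompute_saved_for_backward a second time on the same traces would try to re-register the remat_of_ names (and C0) and collide. A small sketch of that invariant and the reset pattern follows; the Trace class here is a hypothetical reduction of a thunder trace to its set of registered names, not the real class.

from copy import copy

class Trace:
    # Hypothetical reduction of a thunder trace to its set of registered names.
    def __init__(self):
        self.names: set[str] = set()

    def register(self, name: str) -> str:
        # A name may be registered only once per trace; this is the collision
        # the test works around by restoring a pre-transform snapshot.
        assert name not in self.names, f"name {name!r} already taken"
        self.names.add(name)
        return name

trace = Trace()
snapshot = copy(trace.names)   # taken before the first transform run

trace.register("remat_of_t7")  # first run claims the rematerialization name

trace.names = copy(snapshot)   # reset, exactly as the test does between runs
trace.register("remat_of_t7")  # the second run can claim the same name again
print("no collision after reset")

Snapshotting with copy (rather than aliasing the live set) matters: if the transform mutates trace.names in place, an aliased snapshot would silently track the mutation and the reset would be a no-op.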