refactor recomputation to work with tags #1615

Merged: 4 commits, Jan 8, 2025
8 changes: 7 additions & 1 deletion thunder/core/trace_interpreter.py
@@ -6,7 +6,7 @@
from thunder.core.trace import VariableInterface, from_trace, tracectx
from thunder.core.baseutils import ProxyInterface, TensorProxyInterface
from thunder.core.utils import safe_map_flat, sequencify
from thunder.core.proxies import variableify
from thunder.core.proxies import variableify, ProxyTag
from thunder.core.transform_common import VJPDual


@@ -183,6 +183,12 @@ def do_swap(v):

for new_bsym in new_bsyms:
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
for o in new_bsym.flat_proxy_outs:
if variableify(o) not in swap_map:
# when we decompose to compute the forward/backward, we mark intermediates to be recomputed in the backward pass.
# Typically our decompositions are for things that will then be fused together.
# We could refine this heuristic to exclude "expensive" operations.
o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
new_trace.bound_symbols.append(
new_bsym.from_bsym_swap_proxies(swap_map).from_bsym(
source_filename=bsym.source_filename, source_positions=bsym.source_positions
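
The hunk above tags every intermediate produced while re-interpreting a decomposition with `ProxyTag.RECOMPUTE_IN_BACKWARD`, so a later pass can rebuild those values in the backward trace instead of saving them. A minimal sketch of that pattern, using stand-in `Proxy`/`ProxyTag` classes rather than thunder's real ones:

```python
from enum import Enum, auto


class ProxyTag(Enum):
    RECOMPUTE_IN_BACKWARD = auto()


class Proxy:
    def __init__(self, name: str):
        self.name = name
        self.tags: set[ProxyTag] = set()


def tag_decomposition_outputs(outputs: list[Proxy], swap_map: dict[str, Proxy]) -> None:
    # Mirror the loop in trace_interpreter.py: only proxies newly created by
    # the decomposition (not swapped-in existing ones) get the tag.
    for o in outputs:
        if o.name not in swap_map:
            o.tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)


outs = [Proxy("t0"), Proxy("t1")]
tag_decomposition_outputs(outs, swap_map={})
assert all(ProxyTag.RECOMPUTE_IN_BACKWARD in o.tags for o in outs)
```
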
137 changes: 98 additions & 39 deletions thunder/core/transforms.py
@@ -38,7 +38,7 @@
from thunder.core.compile_data import get_compile_data, get_compile_option
from thunder.core.langctxs import langctx, Languages
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol, has_tags
from thunder.core.trace import TraceCtx as Trace
from thunder.core.trace import VariableInterface as Variable
from thunder.core.trace import (
@@ -59,6 +59,7 @@
unzip2,
const_as,
sequencify,
OrderedSet,
ProxyDict,
find_producer_symbols,
)
@@ -92,6 +93,7 @@


TraceTag.register_tag("AUGMENTED_FORWARD")
ProxyTag.register_tag("RECOMPUTE_IN_BACKWARD")


# TODO This should be a partial of thunder.trace, but that would cause a circular import
@@ -2991,6 +2993,7 @@ def unpacking_fn(saved_for_backward, cotangents):
)
)
backward_trace.bound_symbols = list((*unpacking_trace.bound_symbols[:-1], *backward_trace_bsyms_without_unpacking))
backward_trace.scopes[0] = backward_trace.bound_symbols


def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> ForwardBackwardTraces:
@@ -3150,9 +3153,15 @@ def backward_fn(saved_for_backward, cotangents):
forward_trace.set_provenance(TraceProvenance("Augmented forward pass"))
backward_trace.set_provenance(TraceProvenance("Backward pass"))

remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
"recomputation_policy",
"A callable that accepts a set of variables and returns a set of the variables that are allowed to be recomputed from the forward in the backward trace. The compile option `enable_saved_for_backward_recomputation` needs to be true for this policy to take effect.",
)
enable_saved_for_backward_recomputation: None | bool = get_compile_option(
"enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
)
if enable_saved_for_backward_recomputation is None or remat_policy:
enable_saved_for_backward_recomputation = True
if enable_saved_for_backward_recomputation:
forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)

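Both options surfaced here are ordinary compile options, so they can be passed as keyword arguments to `thunder.jit` (the new test below does exactly that for `enable_saved_for_backward_recomputation`). A hedged sketch of wiring up a custom `recomputation_policy`; the policy body is hypothetical, and a real one would filter the candidate variables by, e.g., the cost of their producing ops:

```python
import torch
import thunder


def recompute_everything(candidates):
    # Hypothetical policy: allow every saved-for-backward candidate to be
    # recomputed. Per the option's docstring, it receives a set of variables
    # and returns the subset allowed to be recomputed in the backward trace.
    return set(candidates)


model = torch.nn.Linear(8, 8)
jmodel = thunder.jit(
    model,
    enable_saved_for_backward_recomputation=True,
    recomputation_policy=recompute_everything,
)
out = jmodel(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```
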
@@ -3171,71 +3180,121 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr

start_time_ns = time.perf_counter_ns()

saved_for_bw = get_saved_for_backward_tensors(fwd_trace)
fwd_trace_args = {variableify(j) for j in fwd_trace.args}
old_saved_for_bwd = {variableify(j) for j in saved_for_bw}
cd = get_compile_data()
have_nvfuser = any([ex.name == "nvfuser" for ex in cd.executors_list]) if cd is not None else False
if have_nvfuser:
from thunder.core.rematerialization import replace_uniform

all_rematerializable = old_saved_for_bwd - fwd_trace_args
fwd_trace = replace_uniform(fwd_trace)

saved_for_bw = get_saved_for_backward_tensors(fwd_trace)
fwd_trace_args = OrderedSet(variableify(j) for j in fwd_trace.args)
old_saved_for_bwd = OrderedSet(variableify(j) for j in saved_for_bw)

remat_policy: None | Callable[[set[Variable]], set[Variable]] = get_compile_option(
"recomputation_policy",
"A callable that accepts a set of variables and returns a set of the variables that are allowed to be recomputed from the forward in the backward trace. The compile option `enable_saved_for_backward_recomputation` needs to be true for this policy to take effect.",
)

enable_saved_for_backward_recomputation: None | bool = get_compile_option(
"enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
)

if remat_policy:
rematerializable = remat_policy(all_rematerializable)
else:
rematerializable = all_rematerializable
for v in remat_policy(old_saved_for_bwd - fwd_trace_args):
unvariableify(v).tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)
elif enable_saved_for_backward_recomputation:
for v in old_saved_for_bwd - fwd_trace_args:
unvariableify(v).tags.add(ProxyTag.RECOMPUTE_IN_BACKWARD)

producers = find_producer_symbols(fwd_trace, tuple(unvariableify(i) for i in rematerializable), fwd_trace.args)
all_proxies = fwd_trace_args.copy()
all_recomputable_proxies = OrderedSet()

required_fw_args = fwd_trace_args & old_saved_for_bwd
recomputed_tensors_from_producers = set()
for prod in producers:
for prod_arg in prod.flat_args:
prod_arg = variableify(prod_arg)
if prod_arg in fwd_trace_args:
required_fw_args.add(prod_arg)
for prod_out in prod.flat_outs:
recomputed_tensors_from_producers.add(variableify(prod_out))
proxy_names_to_producers = {}
for bsym in fwd_trace.bound_symbols:
for o in bsym.flat_proxy_outs:
vo = variableify(o)

required_saved_for_bwd = all_rematerializable - rematerializable - recomputed_tensors_from_producers
new_saved_for_backward = tuple(unvariableify(i) for i in required_fw_args | required_saved_for_bwd)
if vo in all_proxies:
continue

new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()
new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1]))
new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)
proxy_names_to_producers[o.name] = bsym
all_proxies.add(vo)
if ProxyTag.RECOMPUTE_IN_BACKWARD in o.tags and not has_tags(bsym, {prims.OpTags.RANDOM_OP}):
all_recomputable_proxies.add(vo)

new_bwd_trace = from_trace(bwd_trace)
# In cases where C0 name is carried from previous trace it must be removed
# as the proxy needs to register with that specific name to follow the backward
# trace standard signature.
new_bwd_trace.names.discard("C0")
rematerializable = old_saved_for_bwd & all_recomputable_proxies

with tracectx(new_bwd_trace):
unpack_args = (CollectionProxy(new_saved_for_backward, name="C0"), len(new_saved_for_backward))
if not rematerializable:
return fwd_trace, bwd_trace

# Here we make sure that the signature of the backward trace is the same as the one we expect.
# This part of the trace is the unpacking of the tuple passed from the forward trace,
# more specifically, C0 unpacks into the saved for backward tensors and C1 into the cotangents
# used to compute the vector-Jacobian product.
# more specifically, C0 unpacks into the saved for backward tensors and C1 into the nontensors
# the cotangents used to compute the vector-Jacobian product are a separate argument.
assert bwd_trace.bound_symbols[2].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[2].args[0].name == "saved_for_backward"
assert bwd_trace.bound_symbols[4].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[4].args[0].name == "C0"
assert bwd_trace.bound_symbols[5].sym.id == prims.PrimIDs.UNPACK_SEQUENCE
assert bwd_trace.bound_symbols[5].args[0].name == "C1"
p_saved_for_backward = bwd_trace.bound_symbols[2].args[0]
p_c0 = bwd_trace.bound_symbols[4].args[0]
p_c1 = bwd_trace.bound_symbols[5].args[0]

saved_tensors = fwd_trace_args & old_saved_for_bwd
saved_nontensors = OrderedSet(variableify(p) for p in p_c1.coll)

new_bwd_trace = from_trace(bwd_trace)

# args will be added from unpack_trivial
have_in_backward = saved_tensors | saved_nontensors

def compute_proxy_from_producer(p):
vp = variableify(p)
if vp in have_in_backward:
return
if vp not in all_recomputable_proxies:
have_in_backward.add(vp)
if isinstance(p, TensorProxy):
saved_tensors.add(vp)
else:
saved_nontensors.add(vp)
return
producer_bsym = proxy_names_to_producers[p.name]
for p in producer_bsym.flat_proxy_args:
compute_proxy_from_producer(p)
for o in producer_bsym.flat_proxy_outs:
have_in_backward.add(variableify(o))
new_bwd_trace.bound_symbols.append(producer_bsym)

for idx, bsym in enumerate(bwd_trace.bound_symbols):
if idx == 4:
new_unpack = prims.unpack_sequence.bind(*unpack_args, output=new_saved_for_backward)
new_bwd_trace.bound_symbols.append(new_unpack)
elif idx == 6:
new_bwd_trace.bound_symbols.extend(producers)
if idx in {4, 5}:
# handled later
new_bwd_trace.bound_symbols.append(bsym)
else:
for p in bsym.flat_proxy_args:
compute_proxy_from_producer(p)
for o in bsym.flat_proxy_outs:
have_in_backward.add(variableify(o))
new_bwd_trace.bound_symbols.append(bsym)

new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]]
new_fwd_trace = from_trace(fwd_trace)
new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy()

new_c0 = tuple(unvariableify(i) for i in saved_tensors)
new_c1 = tuple(unvariableify(i) for i in saved_nontensors)

new_return_args = (fwd_trace.output[0], (new_c0, new_c1))
new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None)

p_saved_for_backward.coll = (new_c0, new_c1)
p_c0.coll = new_c0
p_c1.coll = new_c1

new_bwd_trace.bound_symbols[4] = prims.unpack_sequence.bind(p_c0, len(new_c0), output=new_c0)
new_bwd_trace.bound_symbols[5] = prims.unpack_sequence.bind(p_c1, len(new_c1), output=new_c1)
new_bwd_trace.args = [(new_c0, new_c1), *bwd_trace.args[1:]]

elapsed_time_ns = time.perf_counter_ns() - start_time_ns
new_bwd_trace.set_provenance(
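
The core of the new `recompute_saved_for_backward` is the recursive `compute_proxy_from_producer` walk: any proxy the backward needs is either kept as saved-for-backward (if it is not tagged recomputable) or has its producer spliced into the backward trace once the producer's own inputs are satisfied. A standalone sketch of that walk, with bound symbols reduced to tuples of proxy names (a simplification, not thunder's real data structures):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Bsym:
    args: tuple[str, ...]  # names of input proxies
    outs: tuple[str, ...]  # names of output proxies


def rebuild_in_backward(needed, producers, recomputable, saved):
    """Return producer bsyms to splice into the backward trace, in order.

    `saved` is updated in place with proxies that must still be shipped
    from the forward because they are not tagged as recomputable.
    """
    emitted = []
    have = set(saved)

    def visit(name):
        if name in have:
            return
        if name not in recomputable:
            have.add(name)
            saved.add(name)  # cannot recompute: keep saving it instead
            return
        bsym = producers[name]
        for arg in bsym.args:  # depth-first: satisfy inputs before the op
            visit(arg)
        have.update(bsym.outs)
        emitted.append(bsym)

    for name in needed:
        visit(name)
    return emitted


# t1 = add(x, x); t2 = mul(t1, t1). Both tagged recomputable, so the
# backward only saves the forward input x and replays both producers.
producers = {"t1": Bsym(("x",), ("t1",)), "t2": Bsym(("t1",), ("t2",))}
saved = set()
order = rebuild_in_backward(["t2"], producers, {"t1", "t2"}, saved)
assert saved == {"x"} and [b.outs[0] for b in order] == ["t1", "t2"]
```
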
7 changes: 6 additions & 1 deletion thunder/executors/torch_autograd.py
@@ -4,6 +4,7 @@
import torch

import thunder.core.utils as utils
from thunder.core.compile_data import get_compile_option
from thunder.core.prims import PrimIDs
from thunder.core.proxies import TensorProxy, variableify
from thunder.core.pytree import tree_flatten
@@ -349,7 +350,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
)
bw_traces.append(bw_extrace)

fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
use_rematerialization: None | bool = get_compile_option(
"use_forward_backward_rematerialization", "use rematerialization of saved for backward values in fusions"
)
if use_rematerialization:
fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
fw_traces.append(fw_extrace)
bw_traces.append(bw_extrace)

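This makes the fusion-level rematerialization pass opt-in: with the option unset, `rematerialize_forward_and_backward` is skipped. A minimal sketch of opting back in, assuming the option is passed as a `thunder.jit` keyword argument like the other compile options in this PR:

```python
import torch
import thunder

model = torch.nn.Linear(16, 16)
jmodel = thunder.jit(model, use_forward_backward_rematerialization=True)
jmodel(torch.randn(4, 16, requires_grad=True)).sum().backward()
```
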
4 changes: 1 addition & 3 deletions thunder/tests/distributed/test_fsdp.py
@@ -116,9 +116,7 @@ def test_rematerialize_all_gather(self):
x = torch.ones((2, 12), device=device)
cm(x).mean().backward()

fwd_trc = [
t for t in thunder.last_traces(cm) if getattr(t.get_provenance(), "pss", "") == "Augmented forward pass"
][0]
fwd_trc = [t for t in thunder.last_traces(cm) if thunder.core.trace.TraceTag.AUGMENTED_FORWARD in t.tags][0]
bwd_trc = thunder.last_backward_traces(cm)[0]
from thunder.core.rematerialization import rematerialize_all_gather

8 changes: 4 additions & 4 deletions thunder/tests/test_examine_memory.py
@@ -113,7 +113,7 @@ def test_nanogpt_block():

# Actual memory usage may vary depending on hardware and cuBLAS settings.
# We are checking the estimated memory against a fixed value for consistency.
assert max_mem_fw[0] == 381754368
assert sum(max_mem_fw[1].values()) == 375462912
assert max_mem_bw[0] == 437292032
assert sum(max_mem_bw[1].values()) == 40934400
assert max_mem_fw[0] == 262183936
assert sum(max_mem_fw[1].values()) == 135306240
assert max_mem_bw[0] == 484833280
assert sum(max_mem_bw[1].values()) == 169915392
29 changes: 29 additions & 0 deletions thunder/tests/test_grad.py
@@ -27,6 +27,7 @@
run_snippet,
assert_closer,
IN_CI,
NVFUSER_AVAILABLE,
requiresCUDA,
version_between,
)
@@ -1886,3 +1887,31 @@ def func(x):

torch.testing.assert_close(actual, expected)
torch.testing.assert_close(actual_gr, expected_gr)


@pytest.mark.parametrize("device", ("cuda", "cpu"))
def test_backward_recomputation_decomposed_ops(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

def fn(a):
return torch.nn.functional.gelu(a)

jfn = thunder.jit(fn, enable_saved_for_backward_recomputation=False)
jfn2 = thunder.jit(fn, enable_saved_for_backward_recomputation=True)
Comment on lines +1900 to +1901 (Collaborator): What is the default value now?

a = torch.randn(2, 2, device=device, requires_grad=True)
res = jfn(a)
res2 = jfn2(a)
assert len(res.grad_fn.next_functions[0][0].saved_tensors) == 3 # should be decomposed
assert len(res2.grad_fn.next_functions[0][0].saved_tensors) == 1

if NVFUSER_AVAILABLE and device == "cuda":
# check everything is fused
assert {bsym.sym.name for bsym in thunder.last_backward_traces(jfn2)[-1].bound_symbols} == {
"nvFusion0",
"clear_mutable_collection",
"python_return",
"python_del",
"unpack_sequence",
"unpack_trivial",
}
3 changes: 2 additions & 1 deletion thunder/tests/test_transforms.py
@@ -381,7 +381,7 @@ def forward(self, x):
fwd_trace = None

for trace in thunder.last_traces(jmodel):
if str(trace.get_provenance()) == "# Constructed by Augmented forward pass":
if thunder.core.trace.TraceTag.AUGMENTED_FORWARD in trace.tags:
fwd_trace = trace
break

@@ -433,6 +433,7 @@ def forward(self, x):
filter(lambda x: isinstance(x.output, TensorProxy), new_bwd.bound_symbols[6:]),
)
)

# check that all the fwd are recomputed
for rematerializable in all_rematerializable:
assert rematerializable in bwd_bsym_out