[fix] properly propagate swapped proxies in TraceSubstitutionProcessor #1632

Merged: 16 commits, Jan 14, 2025
58 changes: 56 additions & 2 deletions thunder/core/rematerialization.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass, replace
from functools import partial
from itertools import chain, product, takewhile
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Dict
from collections.abc import Callable
from collections.abc import Sequence
from collections import defaultdict
@@ -12,7 +12,7 @@
from thunder.core import prims, utils
from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface
from thunder.core.prims import PrimIDs
from thunder.core.proxies import TensorProxy, variableify, NumberProxy
from thunder.core.proxies import TensorProxy, variableify, NumberProxy, CollectionProxy, Proxy
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import has_tags
from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
@@ -418,6 +418,15 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr
assert all(x.sym.id in (distPrimIDs.WAIT, wait_prim_impl.id) for x in waits)
wait_outputs = tuple(chain.from_iterable((y for y in x.flat_proxy_outs) for x in waits))

new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map = (
match_fw_and_bw_saved_for_bw_proxies(fw_trace, bw_trace)
)

wait_outputs = tuple(
new_required_for_backward_fw_to_bw_map[a.name] if a.name in new_required_for_backward_fw_to_bw_map else a
for a in wait_outputs
)

visited_wait_output = set()
# map the output of the original waitop to the output of the new waitop
wait_output_replacement_map = {}
@@ -508,10 +517,49 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr

new_fw_trace = from_trace(fw_trace)
new_fw_trace.bound_symbols = list(fw_trace.bound_symbols)

new_required_for_backward = tuple(
new_required_for_backward_bw_to_fw_map[a.name] if a.name in new_required_for_backward_bw_to_fw_map else a
for a in new_required_for_backward
)
_update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)
return new_fw_trace, new_bw_trace


def match_fw_and_bw_saved_for_bw_proxies(
fw_trace: TraceCtx, bw_trace: TraceCtx
) -> tuple[dict[str, Proxy], dict[str, Proxy]]:
"""Outputs required for backward may have different names between forward and backward.
Args:
fw_trace: TraceCtx: Forward trace.
bw_trace: TraceCtx: Backward trace.

Returns:
new_required_for_backward_fw_to_bw_map: Dict[str, Proxy]: mapping of forward proxy names to backward proxies
new_required_for_backward_bw_to_fw_map: Dict[str, Proxy]: mapping of backward proxy names to forward proxies
"""

old_saved_for_backward_fw = (*fw_trace.bound_symbols[-1].args[1][0], *fw_trace.bound_symbols[-1].args[1][1])
old_saved_for_backward_bw = []
for bsym in bw_trace.bound_symbols:
if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE:
flattened_args = tree_flatten(bw_trace.args[1])[0]
proxy_names = {y.name for y in flattened_args if isinstance(y, ProxyInterface)}
if all(
not isinstance(out, CollectionProxy) and out.name not in proxy_names
for out in bsym.flat_outs
if out is not None
):
old_saved_for_backward_bw += bsym.flat_outs
assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
new_required_for_backward_bw_to_fw_map = {
x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
}
new_required_for_backward_fw_to_bw_map = {
y.name: x for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
}
return new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map


def rematerialize(trace: TraceCtx) -> TraceCtx:
"""Rematerialize the trace.

@@ -686,6 +734,12 @@ def add_to_swapmap(p):
bsym for bsym in joint_extrace.bound_symbols[: len(fw_trace.bound_symbols) - 1] if bsym.sym.id != PrimIDs.DEL
)
new_fw_trace.bound_symbols.append(replace(fw_trace.bound_symbols[-1], args=fw_trace.bound_symbols[-1].args))

_, new_required_for_backward_bw_to_fw_map = match_fw_and_bw_saved_for_bw_proxies(fw_trace, bw_trace)
new_required_for_backward = tuple(
new_required_for_backward_bw_to_fw_map[a.name] if a.name in new_required_for_backward_bw_to_fw_map else a
for a in new_required_for_backward
)
_update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)

# prims.python_return was updated and now DCE can remove the unused
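
A minimal sketch of the positional matching that match_fw_and_bw_saved_for_bw_proxies performs, using toy stand-in objects rather than Thunder's TraceCtx/Proxy types (SimpleProxy and the proxy names below are hypothetical): the saved-for-backward values are paired by position, and the two dictionaries translate names in either direction.

# Toy stand-ins for Thunder's proxies; SimpleProxy and the names below are
# hypothetical and only illustrate how the two dictionaries are built.
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleProxy:
    name: str


# Saved-for-backward values as the forward trace names them ...
fw_saved = (SimpleProxy("t16"), SimpleProxy("t4"))
# ... and the positionally corresponding names used in the backward trace.
bw_saved = (SimpleProxy("t0"), SimpleProxy("t4"))

# Forward name -> backward proxy, and backward name -> forward proxy.
fw_to_bw = {f.name: b for b, f in zip(bw_saved, fw_saved)}
bw_to_fw = {b.name: f for b, f in zip(bw_saved, fw_saved)}

assert fw_to_bw["t16"].name == "t0"
assert bw_to_fw["t0"].name == "t16"
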
6 changes: 3 additions & 3 deletions thunder/core/trace_interpreter.py
@@ -271,11 +271,11 @@ def add_to_swap_map(self, old, new):
old = old.replace(shape=new._shape)

if isinstance(new, VJPDual):
self.swap_map[variableify(new.primal)] = old
self.swap_map[variableify(old.primal)] = new
new.primal = old
else:
assert isinstance(new, ProxyInterface), (old, new)
self.swap_map[variableify(new)] = old
self.swap_map[variableify(old)] = new

def do_swap(self, v):
if isinstance(v, VJPDual):
@@ -320,6 +320,7 @@ def process_args(self, *args, **kwargs):
def __call__(self):
with tracectx(self.new_trace):
self.unprocessed_bsyms = self.trace.bound_symbols[:]
self.swap_map = {}

while self.unprocessed_bsyms:
bsym = self.unprocessed_bsyms.pop(0)
@@ -338,7 +339,6 @@ def __call__(self):
assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms"

if self.replacement_result is not self.NULL:
self.swap_map = {}

# TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round
if not self.new_bsyms:
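
A rough, self-contained sketch of the behavior this change targets (plain dicts instead of Thunder's actual classes): the swap map now records old -> new replacements and survives across bound symbols, so once a proxy is swapped, every later consumer resolves the old name to its replacement.

# A rough sketch (not Thunder's proxies) of the propagation a persistent
# old -> new swap map enables across bound symbols.
swap_map: dict[str, str] = {}


def record_swap(old_name: str, new_name: str) -> None:
    # Map the *old* proxy to its replacement, mirroring the old -> new
    # direction used in add_to_swap_map after this change.
    swap_map[old_name] = new_name


def resolve(name: str) -> str:
    # Later bound symbols look their inputs up through the same map, so a
    # swap made earlier is visible to every downstream consumer.
    return swap_map.get(name, name)


record_swap("t0", "t16")       # an execution transform produced a replacement for t0
assert resolve("t0") == "t16"  # subsequent consumers now see t16
assert resolve("z") == "z"     # untouched proxies pass through unchanged
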
21 changes: 20 additions & 1 deletion thunder/core/transforms.py
@@ -3252,6 +3252,22 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
# args will be added from unpack_trivial
have_in_backward = saved_tensors | saved_nontensors

from thunder.core.rematerialization import match_fw_and_bw_saved_for_bw_proxies

new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map = (
match_fw_and_bw_saved_for_bw_proxies(fwd_trace, bwd_trace)
)
all_recomputable_proxies = all_recomputable_proxies.union(
OrderedSet(
(
variableify(new_required_for_backward_fw_to_bw_map[unvariableify(a).name])
if unvariableify(a).name in new_required_for_backward_fw_to_bw_map
else a
)
for a in all_recomputable_proxies
)
)

def compute_proxy_from_producer(p):
vp = variableify(p)
if vp in have_in_backward:
@@ -3263,7 +3279,10 @@ def compute_proxy_from_producer(p):
else:
saved_nontensors.add(vp)
return
producer_bsym = proxy_names_to_producers[p.name]
if p.name not in proxy_names_to_producers:
producer_bsym = proxy_names_to_producers[new_required_for_backward_bw_to_fw_map[p.name].name]
else:
producer_bsym = proxy_names_to_producers[p.name]
for p in producer_bsym.flat_proxy_args:
compute_proxy_from_producer(p)
for o in producer_bsym.flat_proxy_outs:
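
A toy illustration of the fallback added to compute_proxy_from_producer, with made-up names: a proxy that is only known under its backward-trace name is translated back to its forward-trace name before looking up its producer.

# Made-up names; this only illustrates the producer-lookup fallback above.
proxy_names_to_producers = {"t16": "bsym_that_produces_t16"}
bw_to_fw_names = {"t0": "t16"}  # backward-trace name -> forward-trace name


def find_producer(name: str) -> str:
    # If the backward-trace name has no recorded producer, fall back to the
    # matching forward-trace name, as the change above does.
    if name not in proxy_names_to_producers:
        return proxy_names_to_producers[bw_to_fw_names[name]]
    return proxy_names_to_producers[name]


assert find_producer("t16") == "bsym_that_produces_t16"
assert find_producer("t0") == "bsym_that_produces_t16"
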
18 changes: 2 additions & 16 deletions thunder/executors/passes.py
@@ -29,22 +29,6 @@
# Transforms a trace by determining which execution transforms to call given the list of executors in priority order
# This pass tries to preserve the original trace and proxies.
def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx:
start_time_ns = time.perf_counter_ns()

swapmap: dict[Variable, Proxy] = {}

def update_swapmap(o: Any, no: Any) -> None:
if isinstance(o, Proxy):
check(
isinstance(no, Proxy),
lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}",
)

vo = variableify(o)
vno = variableify(no)
if vo == vno:
return
swapmap[vno] = o

# This processes the bsyms to map symbols to operator executors:
# - if a bsym has a python impl, that will be called, so we can keep it.
@@ -104,6 +88,8 @@ def process_bsym(self, bsym):
### OUTPUTS to map
self.add_unprocessed_bsyms(bsym.subsymbols[:])

start_time_ns = time.perf_counter_ns()

extrace, _ = OpExProcessor(trace)()

end_time_ns = time.perf_counter_ns()
4 changes: 2 additions & 2 deletions thunder/tests/test_nvfuser.py
@@ -317,8 +317,8 @@ def func(w, x, y, z):
nvf_0 = fusion_bsyms[0]
nvf_1 = fusion_bsyms[1]

assert [t.name for t in tree_flatten(nvf_0.args)[0]] == ["t0", "z"]
assert [t.name for t in tree_flatten(nvf_1.args)[0]] == ["t0", "w", "t4"]
assert [t.name for t in tree_flatten(nvf_0.args)[0]] == ["t16", "z"]
assert [t.name for t in tree_flatten(nvf_1.args)[0]] == ["t16", "w", "t4"]
assert len(nvf_0.subsymbols) == 4
assert len(nvf_1.subsymbols) == 6
assert [t.name for t in tree_flatten(nvf_0.output)[0]] == ["t4"]
163 changes: 163 additions & 0 deletions thunder/tests/test_transforms.py
@@ -646,3 +646,166 @@ def test_disable_params_and_buffer_check():
)

assert len(check_bsyms) == 1 # We only have the check for input.


def test_buffer_dtype_casting():
import torch.nn as nn
import itertools

from typing import Any, Optional, Tuple, Union, List

class CastBuffers(thunder.core.transform_common.Transform):
def __init__(self):
self.cast_states = {}

def transform_module(self, model: thunder.ThunderModule):
self.thunder_module = model
for n, b in model._model.named_buffers():
qb = b.to(torch.bfloat16)
self.cast_states[n] = {
"dtype": b.dtype,
"shape": tuple(b.shape),
"qb.dtype": qb.dtype,
"qb.shape": tuple(qb.shape),
}
model._overrides_buffers[n] = qb

def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
tm = self.thunder_module
from thunder.core.trace import tracectx

checks = thunder.transforms.utils.get_checks(prologue_trace)

prologue_proxy_map = {
get_param_bsym.output.name: dict(
shape=self.cast_states[model_weight_name]["qb.shape"],
dtype=thunder.dtypes.to_dtype(self.cast_states[model_weight_name]["qb.dtype"]),
)
for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
if model_weight_name in self.cast_states
}

# here we switch the prologue_trace to a copy with new metadata
prologue_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
prologue_trace, prologue_proxy_map
)

checks = thunder.transforms.utils.get_checks(prologue_trace)
for n, qs in self.cast_states.items():
check, get_param = checks[n]
# check has args: tensor, shape, device, dtype, requires_grad
proxy, _, device, _, requires_grad = check.args
check.args = (
proxy,
qs["qb.shape"],
device,
qs["qb.dtype"],
False,
)

computation_proxy_map = {
csym.name: dict(
shape=psym.shape,
dtype=psym.dtype,
)
for psym, csym in zip(prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args)
if psym.shape != csym.shape or psym.dtype != csym.dtype
}

new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
computation_trace, computation_proxy_map
)

producers, consumers = thunder.core.utils.producers_and_consumers(new_computation_trace)

bound_symbols = new_computation_trace.bound_symbols
new_computation_trace.bound_symbols = []

new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]

computation_proxy_map = {}
new_bound_symbols = []
for bsym in bound_symbols:
if bsym.sym == thunder.torch.to and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial:
inp = bsym.args[0]
args = (inp, inp.dtype, *bsym.args[2:])
computation_proxy_map[bsym.output.name] = dict(shape=inp.shape, dtype=inp.dtype)
assert (
len(bsym.subsymbols) == 1 and bsym.subsymbols[0].sym == thunder.core.prims.convert_element_type
)
subsymbols = [bsym.subsymbols[0].from_bsym(args=(inp, inp.dtype))]
new_bound_symbols.append(bsym.from_bsym(args=args, subsymbols=subsymbols))
else:
new_bound_symbols.append(bsym.from_bsym())

new_computation_trace.bound_symbols = new_bound_symbols

new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
new_computation_trace, computation_proxy_map
)

new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("Dtype Convert"))
return prologue_trace, new_computation_trace, epilogue_trace

class cast(nn.Module):
def __init__(
self,
k_shape: tuple[int, int, int, int],
v_shape: tuple[int, int, int, int],
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)

def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# move the buffer to the activation dtype for when AMP is used
self.k = self.k.to(k.dtype)
self.v = self.v.to(v.dtype)
# update the cache
return self.k, self.v

# BUG: issue: 1637
class ParentModule(nn.Module):
def __init__(
self,
k_shape: tuple[int, int, int, int],
v_shape: tuple[int, int, int, int],
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)

def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.cast_module(k, v)

with torch.device("cpu"):
k_shape = (2, 3, 4, 5)
v_shape = (2, 3, 4, 5)
device = torch.device("cpu")
dtype = torch.float32
model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)

k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
cast_jit = thunder.jit(
model,
transforms=[
CastBuffers(),
],
)
output_k, output_v = cast_jit(k, v)

def check_dtypes(bsym):
for a in itertools.chain(bsym.flat_args, bsym.flat_outs):
if isinstance(a, thunder.TensorProxy):
assert a.dtype == thunder.dtypes.bfloat16
for sbsym in bsym.subsymbols:
check_dtypes(sbsym)

for tr in thunder.last_traces(cast_jit):
if str(tr.get_provenance()) == "# Constructed by Dtype Convert":
for bsym in tr.bound_symbols:
check_dtypes(bsym)