diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py
index fa38c0c885..37ebd5a177 100644
--- a/thunder/core/rematerialization.py
+++ b/thunder/core/rematerialization.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, replace
 from functools import partial
 from itertools import chain, product, takewhile
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, Dict
 from collections.abc import Callable
 from collections.abc import Sequence
 from collections import defaultdict
@@ -12,7 +12,7 @@
 from thunder.core import prims, utils
 from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface
 from thunder.core.prims import PrimIDs
-from thunder.core.proxies import TensorProxy, variableify, NumberProxy
+from thunder.core.proxies import TensorProxy, variableify, NumberProxy, CollectionProxy, Proxy
 from thunder.core.pytree import tree_flatten, tree_unflatten
 from thunder.core.symbol import has_tags
 from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
@@ -418,6 +418,15 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr
     assert all(x.sym.id in (distPrimIDs.WAIT, wait_prim_impl.id) for x in waits)
     wait_outputs = tuple(chain.from_iterable((y for y in x.flat_proxy_outs) for x in waits))
 
+    new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map = (
+        match_fw_and_bw_saved_for_bw_proxies(fw_trace, bw_trace)
+    )
+
+    wait_outputs = tuple(
+        new_required_for_backward_fw_to_bw_map[a.name] if a.name in new_required_for_backward_fw_to_bw_map else a
+        for a in wait_outputs
+    )
+
     visited_wait_output = set()
     # map the output of the original waitop to the output of the new waitop
     wait_output_replacement_map = {}
@@ -508,10 +517,49 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr
     new_fw_trace = from_trace(fw_trace)
     new_fw_trace.bound_symbols = list(fw_trace.bound_symbols)
+
+    new_required_for_backward = tuple(
+        new_required_for_backward_bw_to_fw_map[a.name] if a.name in new_required_for_backward_bw_to_fw_map else a
+        for a in new_required_for_backward
+    )
     _update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)
 
     return new_fw_trace, new_bw_trace
 
 
+def match_fw_and_bw_saved_for_bw_proxies(
+    fw_trace: TraceCtx, bw_trace: TraceCtx
+) -> tuple[dict[str, Proxy], dict[str, Proxy]]:
+    """Match proxies saved for backward between the forward and backward traces, whose names may differ.
+    Args:
+        fw_trace: TraceCtx: Forward trace.
+        bw_trace: TraceCtx: Backward trace.
+
+    Returns:
+        Two mappings, returned as (fw_to_bw, bw_to_fw): forward proxy name -> backward proxy, and backward proxy name -> forward proxy.
+    """
+
+    old_saved_for_backward_fw = (*fw_trace.bound_symbols[-1].args[1][0], *fw_trace.bound_symbols[-1].args[1][1])
+    old_saved_for_backward_bw = []
+    for bsym in bw_trace.bound_symbols:
+        if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE:
+            flattened_args = tree_flatten(bw_trace.args[1])[0]
+            proxy_names = {y.name for y in flattened_args if isinstance(y, ProxyInterface)}
+            if all(
+                not isinstance(out, CollectionProxy) and out.name not in proxy_names
+                for out in bsym.flat_outs
+                if out is not None
+            ):
+                old_saved_for_backward_bw += bsym.flat_outs
+    assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
+    new_required_for_backward_bw_to_fw_map = {
+        x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
+    }
+    new_required_for_backward_fw_to_bw_map = {
+        y.name: x for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
+    }
+    return new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map
+
+
 def rematerialize(trace: TraceCtx) -> TraceCtx:
     """Rematerialize the trace.
 
@@ -686,6 +734,12 @@ def add_to_swapmap(p):
         bsym for bsym in joint_extrace.bound_symbols[: len(fw_trace.bound_symbols) - 1] if bsym.sym.id != PrimIDs.DEL
     )
     new_fw_trace.bound_symbols.append(replace(fw_trace.bound_symbols[-1], args=fw_trace.bound_symbols[-1].args))
+
+    _, new_required_for_backward_bw_to_fw_map = match_fw_and_bw_saved_for_bw_proxies(fw_trace, bw_trace)
+    new_required_for_backward = tuple(
+        new_required_for_backward_bw_to_fw_map[a.name] if a.name in new_required_for_backward_bw_to_fw_map else a
+        for a in new_required_for_backward
+    )
     _update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)
 
     # prims.python_return was updated and now DCE can remove the unused
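The helper pairs the forward and backward saved-for-backward sequences positionally and keys the resulting dicts by proxy name; callers then translate through a map with an identity fallback. A minimal sketch of that pattern, using a hypothetical `ToyProxy` stand-in and invented names rather than thunder's real proxies:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyProxy:
    # minimal stand-in for thunder's Proxy: only .name matters here
    name: str


# the same saved-for-backward values under their forward-trace names ...
saved_fw = (ToyProxy("t16"), ToyProxy("t4"))
# ... and under the names the backward trace unpacks them as
saved_bw = (ToyProxy("t0"), ToyProxy("t1"))

# pair positionally, as zip(old_saved_for_backward_bw, old_saved_for_backward_fw) does
fw_to_bw = {fw.name: bw for bw, fw in zip(saved_bw, saved_fw)}
bw_to_fw = {bw.name: fw for bw, fw in zip(saved_bw, saved_fw)}

# translate a proxy tuple, falling back to the proxy itself when the name is
# not in the map (dict.get is equivalent to the `x if k in m else x` form above)
wait_outputs = (ToyProxy("t16"), ToyProxy("z"))
assert [p.name for p in (fw_to_bw.get(p.name, p) for p in wait_outputs)] == ["t0", "z"]
```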
diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py
index 6bee20a165..491ba49233 100644
--- a/thunder/core/trace_interpreter.py
+++ b/thunder/core/trace_interpreter.py
@@ -271,11 +271,11 @@ def add_to_swap_map(self, old, new):
             old = old.replace(shape=new._shape)
 
         if isinstance(new, VJPDual):
-            self.swap_map[variableify(new.primal)] = old
+            self.swap_map[variableify(old.primal)] = new
             new.primal = old
         else:
             assert isinstance(new, ProxyInterface), (old, new)
-            self.swap_map[variableify(new)] = old
+            self.swap_map[variableify(old)] = new
 
     def do_swap(self, v):
         if isinstance(v, VJPDual):
@@ -320,6 +320,7 @@ def process_args(self, *args, **kwargs):
     def __call__(self):
         with tracectx(self.new_trace):
             self.unprocessed_bsyms = self.trace.bound_symbols[:]
+            self.swap_map = {}
 
             while self.unprocessed_bsyms:
                 bsym = self.unprocessed_bsyms.pop(0)
@@ -338,7 +339,6 @@ def __call__(self):
                 assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms"
 
                 if self.replacement_result is not self.NULL:
-                    self.swap_map = {}
 
                     # TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round
                     if not self.new_bsyms:
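The `trace_interpreter.py` changes reverse the swap map's direction (old proxy -> replacement) and let the map persist across bound symbols instead of being reset per result. A toy sketch of why both matter, with plain strings standing in for variableified proxies and invented names:

```python
# Plain strings stand in for variableified proxies; the names are invented.
swap_map = {}

def add_to_swap_map(old, new):
    # key on the *old* proxy so later occurrences get rewritten to the new
    # one; keying on the new proxy would rewrite replacements back to the
    # originals, which is the direction the patch fixes
    swap_map[old] = new

def do_swap(value):
    return swap_map.get(value, value)

add_to_swap_map("t0", "t16")

# because the map now lives for the whole trace walk, every later consumer
# of t0 sees t16 -- consistent with the renamed fusion inputs in test_nvfuser.py
assert [do_swap(a) for a in ("t0", "z")] == ["t16", "z"]
```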
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 15b7255bd6..9deed5fefd 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -3252,6 +3252,22 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
     # args will be added from unpack_trivial
     have_in_backward = saved_tensors | saved_nontensors
 
+    from thunder.core.rematerialization import match_fw_and_bw_saved_for_bw_proxies
+
+    new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map = (
+        match_fw_and_bw_saved_for_bw_proxies(fwd_trace, bwd_trace)
+    )
+    all_recomputable_proxies = all_recomputable_proxies.union(
+        OrderedSet(
+            (
+                variableify(new_required_for_backward_fw_to_bw_map[unvariableify(a).name])
+                if unvariableify(a).name in new_required_for_backward_fw_to_bw_map
+                else a
+            )
+            for a in all_recomputable_proxies
+        )
+    )
+
     def compute_proxy_from_producer(p):
         vp = variableify(p)
         if vp in have_in_backward:
@@ -3263,7 +3279,10 @@ def compute_proxy_from_producer(p):
             else:
                 saved_nontensors.add(vp)
             return
-        producer_bsym = proxy_names_to_producers[p.name]
+        if p.name not in proxy_names_to_producers:
+            producer_bsym = proxy_names_to_producers[new_required_for_backward_bw_to_fw_map[p.name].name]
+        else:
+            producer_bsym = proxy_names_to_producers[p.name]
         for p in producer_bsym.flat_proxy_args:
             compute_proxy_from_producer(p)
         for o in producer_bsym.flat_proxy_outs:
diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py
index f4bc3172b3..5cb3353d08 100644
--- a/thunder/executors/passes.py
+++ b/thunder/executors/passes.py
@@ -29,22 +29,6 @@
 # Transforms a trace by determining which execution transforms to call given the list of executors in priority order
 # This pass tries to preserve the original trace and proxies.
 def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx:
-    start_time_ns = time.perf_counter_ns()
-
-    swapmap: dict[Variable, Proxy] = {}
-
-    def update_swapmap(o: Any, no: Any) -> None:
-        if isinstance(o, Proxy):
-            check(
-                isinstance(no, Proxy),
-                lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}",
-            )
-
-            vo = variableify(o)
-            vno = variableify(no)
-            if vo == vno:
-                return
-            swapmap[vno] = o
 
     # This processes the bsyms to map symbols to operator executors:
     # - if a bsym has a python impl, that will be called, so we can keep it.
@@ -104,6 +88,8 @@ def process_bsym(self, bsym):
             ### OUTPUTS to map
             self.add_unprocessed_bsyms(bsym.subsymbols[:])
 
+    start_time_ns = time.perf_counter_ns()
+
     extrace, _ = OpExProcessor(trace)()
 
     end_time_ns = time.perf_counter_ns()
diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py
index 065dd79346..8ed41924f5 100644
--- a/thunder/tests/test_nvfuser.py
+++ b/thunder/tests/test_nvfuser.py
@@ -317,8 +317,8 @@ def func(w, x, y, z):
     nvf_0 = fusion_bsyms[0]
     nvf_1 = fusion_bsyms[1]
 
-    assert [t.name for t in tree_flatten(nvf_0.args)[0]] == ["t0", "z"]
-    assert [t.name for t in tree_flatten(nvf_1.args)[0]] == ["t0", "w", "t4"]
+    assert [t.name for t in tree_flatten(nvf_0.args)[0]] == ["t16", "z"]
+    assert [t.name for t in tree_flatten(nvf_1.args)[0]] == ["t16", "w", "t4"]
     assert len(nvf_0.subsymbols) == 4
     assert len(nvf_1.subsymbols) == 6
     assert [t.name for t in tree_flatten(nvf_0.output)[0]] == ["t4"]
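In `recompute_saved_for_backward`, producers are recorded under forward-trace names, so a lookup by a backward-trace alias can miss. A self-contained sketch of the new fallback, with hypothetical names (`t0` as the backward alias of the forward `t16`):

```python
from types import SimpleNamespace

# producers are keyed by forward-trace names; t0 is the hypothetical
# backward-trace alias of the forward proxy t16
proxy_names_to_producers = {"t16": "<bsym producing t16>"}
bw_to_fw_map = {"t0": SimpleNamespace(name="t16")}

def producer_for(name):
    # mirrors the new branch in compute_proxy_from_producer: look the name up
    # as-is, and on a miss retry under its forward-trace alias
    if name not in proxy_names_to_producers:
        return proxy_names_to_producers[bw_to_fw_map[name].name]
    return proxy_names_to_producers[name]

assert producer_for("t0") == producer_for("t16")
```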
diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py
index b10e0048d2..0751b282cb 100644
--- a/thunder/tests/test_transforms.py
+++ b/thunder/tests/test_transforms.py
@@ -646,3 +646,166 @@ def test_disable_params_and_buffer_check():
     )
 
     assert len(check_bsyms) == 1  # We only have the check for input.
+
+
+def test_buffer_dtype_casting():
+    import torch.nn as nn
+    import itertools
+
+    from typing import Any, Optional, Tuple, Union, List
+
+    class CastBuffers(thunder.core.transform_common.Transform):
+        def __init__(self):
+            self.cast_states = {}
+
+        def transform_module(self, model: thunder.ThunderModule):
+            self.thunder_module = model
+            for n, b in model._model.named_buffers():
+                qb = b.to(torch.bfloat16)
+                self.cast_states[n] = {
+                    "dtype": b.dtype,
+                    "shape": tuple(b.shape),
+                    "qb.dtype": qb.dtype,
+                    "qb.shape": tuple(qb.shape),
+                }
+                model._overrides_buffers[n] = qb
+
+        def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
+            tm = self.thunder_module
+            from thunder.core.trace import tracectx
+
+            checks = thunder.transforms.utils.get_checks(prologue_trace)
+
+            prologue_proxy_map = {
+                get_param_bsym.output.name: dict(
+                    shape=self.cast_states[model_weight_name]["qb.shape"],
+                    dtype=thunder.dtypes.to_dtype(self.cast_states[model_weight_name]["qb.dtype"]),
+                )
+                for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
+                if model_weight_name in self.cast_states
+            }
+
+            # here we switch the prologue_trace to a copy with new metadata
+            prologue_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                prologue_trace, prologue_proxy_map
+            )
+
+            checks = thunder.transforms.utils.get_checks(prologue_trace)
+            for n, qs in self.cast_states.items():
+                check, get_param = checks[n]
+                # check has args: tensor, shape, device, dtype, requires_grad
+                proxy, _, device, _, requires_grad = check.args
+                check.args = (
+                    proxy,
+                    qs["qb.shape"],
+                    device,
+                    qs["qb.dtype"],
+                    False,
+                )
+
+            computation_proxy_map = {
+                csym.name: dict(
+                    shape=psym.shape,
+                    dtype=psym.dtype,
+                )
+                for psym, csym in zip(prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args)
+                if psym.shape != csym.shape or psym.dtype != csym.dtype
+            }
+
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                computation_trace, computation_proxy_map
+            )
+
+            producers, consumers = thunder.core.utils.producers_and_consumers(new_computation_trace)
+
+            bound_symbols = new_computation_trace.bound_symbols
+            new_computation_trace.bound_symbols = []
+
+            new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]
+
+            computation_proxy_map = {}
+            new_bound_symbols = []
+            for bsym in bound_symbols:
+                if bsym.sym == thunder.torch.to and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial:
+                    inp = bsym.args[0]
+                    args = (inp, inp.dtype, *bsym.args[2:])
+                    computation_proxy_map[bsym.output.name] = dict(shape=inp.shape, dtype=inp.dtype)
+                    assert (
+                        len(bsym.subsymbols) == 1 and bsym.subsymbols[0].sym == thunder.core.prims.convert_element_type
+                    )
+                    subsymbols = [bsym.subsymbols[0].from_bsym(args=(inp, inp.dtype))]
+                    new_bound_symbols.append(bsym.from_bsym(args=args, subsymbols=subsymbols))
+                else:
+                    new_bound_symbols.append(bsym.from_bsym())
+
+            new_computation_trace.bound_symbols = new_bound_symbols
+
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                new_computation_trace, computation_proxy_map
+            )
+
+            new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("Dtype Convert"))
+            return prologue_trace, new_computation_trace, epilogue_trace
+
+    class cast(nn.Module):
+        def __init__(
+            self,
+            k_shape: tuple[int, int, int, int],
+            v_shape: tuple[int, int, int, int],
+            device: torch.device | None = None,
+            dtype: torch.dtype | None = None,
+        ) -> None:
+            super().__init__()
+            self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
+            self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
+
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+            # move the buffer to the activation dtype for when AMP is used
+            self.k = self.k.to(k.dtype)
+            self.v = self.v.to(v.dtype)
+            # update the cache
+            return self.k, self.v
+
+    # BUG: issue: 1637
+    class ParentModule(nn.Module):
+        def __init__(
+            self,
+            k_shape: tuple[int, int, int, int],
+            v_shape: tuple[int, int, int, int],
+            device: torch.device | None = None,
+            dtype: torch.dtype | None = None,
+        ):
+            super().__init__()
+            self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)
+
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+            return self.cast_module(k, v)
+
+    with torch.device("cpu"):
+        k_shape = (2, 3, 4, 5)
+        v_shape = (2, 3, 4, 5)
+        device = torch.device("cpu")
+        dtype = torch.float32
+        model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)
+
+    k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
+    v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
+    cast_jit = thunder.jit(
+        model,
+        transforms=[
+            CastBuffers(),
+        ],
+    )
+    output_k, output_v = cast_jit(k, v)
+
+    def check_dtypes(bsym):
+        for a in itertools.chain(bsym.flat_args, bsym.flat_outs):
+            if isinstance(a, thunder.TensorProxy):
+                assert a.dtype == thunder.dtypes.bfloat16
+        for sbsym in bsym.subsymbols:
+            check_dtypes(sbsym)
+
+    for tr in thunder.last_traces(cast_jit):
+        if str(tr.get_provenance()) == "# Constructed by Dtype Convert":
+            for bsym in tr.bound_symbols:
+                check_dtypes(bsym)
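For reference, a condensed sketch of the two-hook `Transform` pattern the new test exercises; the hooks and the buffer-override mechanism are the ones used above, while the proxy-metadata and prologue-check rewriting a working transform needs is deliberately elided:

```python
import torch
import thunder


class CastBuffersToBf16(thunder.core.transform_common.Transform):
    # condensed outline of the CastBuffers transform in the test above

    def transform_module(self, model: thunder.ThunderModule):
        # swap every buffer for a bfloat16 copy via the override mechanism
        for name, buf in model._model.named_buffers():
            model._overrides_buffers[name] = buf.to(torch.bfloat16)

    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
        # a complete transform must also rewrite proxy metadata and the
        # prologue's shape/dtype checks, as the full test above does
        return prologue_trace, computation_trace, epilogue_trace
```

As in the test, such a transform is applied via `thunder.jit(model, transforms=[CastBuffersToBf16()])`.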