Merge branch 'main' into tom/drop-old-style-distributed
t-vi authored Jan 14, 2025
2 parents a5ea9b3 + 780407d commit e36a686
Showing 11 changed files with 370 additions and 45 deletions.
16 changes: 15 additions & 1 deletion thunder/core/jit_ext.py
@@ -260,6 +260,18 @@ def proxify(self, value: WrappedValue) -> Any:
# TensorProxy attributes should be considered derived quantities, so we flag TensorProxies here
value.provenance.ext_flag |= EXT_FLAG_IS_TENSOR_PROXY

# We assume that a Parameter's underlying storage won't be changed.
# We tag Parameter's proxy with `STATIC_MEMORY_LOCATION` tag so that
# other transforms (eg. CUDAGraph) can use this information.
# We tag this here (rather than in unpack_parameter_or_buffer_or_submodule below) because
# thunderfx does not properly see the module, but does see that we are dealing with a parameter.
if isinstance(uvalue, torch.nn.Parameter):
# NOTE - We tag `p_orig` because in the distributed scenario it is
# the proxy for the Parameter on device.
# In the `jit(ddp(model))` or `jit(fsdp(model))` scenario,
# `p` is the proxy for the synced parameter.
p_orig.tags.add(ProxyTag.STATIC_MEMORY_LOCATION)

if p is not uvalue:
value.register_proxy(p)
# TODO: other caching modes
@@ -1514,7 +1526,9 @@ def unpack_parameter_or_buffer_or_submodule(provenance, *, new_output=False):
name = ".".join(name)
if typ == "_parameters":
bsym = prims.unpack_parameter.bind(root_module, name, output=output)
output.tags.add(ProxyTag.STATIC_MEMORY_LOCATION)
assert (
ProxyTag.STATIC_MEMORY_LOCATION in output.tags
), "Parameter was not tagged with STATIC_MEMORY_LOCATION"
elif typ == "_buffers":
bsym = prims.unpack_buffer.bind(root_module, name, output=output)
output.tags.add(ProxyTag.STATIC_MEMORY_LOCATION)
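The tag added above is meant to be consumed by later transforms. A minimal sketch of such a consumer, assuming only what the diff shows (proxies carry a `tags` set and `ProxyTag.STATIC_MEMORY_LOCATION` exists); the helper name and the import location of `ProxyTag` are assumptions:

```python
from thunder.core.proxies import ProxyTag, TensorProxy  # import path assumed


def has_static_storage(proxy: TensorProxy) -> bool:
    # Hypothetical helper: a transform such as CUDAGraph could use this check
    # to treat the proxy's underlying storage as stable across iterations.
    return ProxyTag.STATIC_MEMORY_LOCATION in proxy.tags
```
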
58 changes: 56 additions & 2 deletions thunder/core/rematerialization.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass, replace
from functools import partial
from itertools import chain, product, takewhile
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Dict
from collections.abc import Callable
from collections.abc import Sequence
from collections import defaultdict
@@ -12,7 +12,7 @@
from thunder.core import prims, utils
from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface
from thunder.core.prims import PrimIDs
from thunder.core.proxies import TensorProxy, variableify, NumberProxy
from thunder.core.proxies import TensorProxy, variableify, NumberProxy, CollectionProxy, Proxy
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import has_tags
from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
@@ -418,6 +418,15 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr
assert all(x.sym.id in (distPrimIDs.WAIT, wait_prim_impl.id) for x in waits)
wait_outputs = tuple(chain.from_iterable((y for y in x.flat_proxy_outs) for x in waits))

new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map = (
match_fw_and_bw_saved_for_bw_proxies(fw_trace, bw_trace)
)

wait_outputs = tuple(
new_required_for_backward_fw_to_bw_map[a.name] if a.name in new_required_for_backward_fw_to_bw_map else a
for a in wait_outputs
)

visited_wait_output = set()
# map the output of the original waitop to the output of the new waitop
wait_output_replacement_map = {}
@@ -508,10 +517,49 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr

new_fw_trace = from_trace(fw_trace)
new_fw_trace.bound_symbols = list(fw_trace.bound_symbols)

new_required_for_backward = tuple(
new_required_for_backward_bw_to_fw_map[a.name] if a.name in new_required_for_backward_bw_to_fw_map else a
for a in new_required_for_backward
)
_update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)
return new_fw_trace, new_bw_trace


def match_fw_and_bw_saved_for_bw_proxies(
fw_trace: TraceCtx, bw_trace: TraceCtx
) -> tuple[dict[str, Proxy], dict[str, Proxy]]:
"""Outputs required for backward may have different names between forward and backward.
Args:
fw_trace: TraceCtx: Forward trace.
bw_trace: TraceCtx: Backward trace.
Returns:
new_required_for_bakward_fw_to_bw_map: Dict[str, Proxy]: mapping of bw names to forward proxies
"""

old_saved_for_backward_fw = (*fw_trace.bound_symbols[-1].args[1][0], *fw_trace.bound_symbols[-1].args[1][1])
old_saved_for_backward_bw = []
for bsym in bw_trace.bound_symbols:
if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE:
flattened_args = tree_flatten(bw_trace.args[1])[0]
proxy_names = {y.name for y in flattened_args if isinstance(y, ProxyInterface)}
if all(
not isinstance(out, CollectionProxy) and out.name not in proxy_names
for out in bsym.flat_outs
if out is not None
):
old_saved_for_backward_bw += bsym.flat_outs
assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
new_required_for_backward_bw_to_fw_map = {
x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
}
new_required_for_backward_fw_to_bw_map = {
y.name: x for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
}
return new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map


def rematerialize(trace: TraceCtx) -> TraceCtx:
"""Rematerialize the trace.
@@ -686,6 +734,12 @@ def add_to_swapmap(p):
bsym for bsym in joint_extrace.bound_symbols[: len(fw_trace.bound_symbols) - 1] if bsym.sym.id != PrimIDs.DEL
)
new_fw_trace.bound_symbols.append(replace(fw_trace.bound_symbols[-1], args=fw_trace.bound_symbols[-1].args))

_, new_required_for_backward_bw_to_fw_map = match_fw_and_bw_saved_for_bw_proxies(fw_trace, bw_trace)
new_required_for_backward = tuple(
new_required_for_backward_bw_to_fw_map[a.name] if a.name in new_required_for_backward_bw_to_fw_map else a
for a in new_required_for_backward
)
_update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)

# prims.python_return was updated and now DCE can remove the unused
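The name-based remapping used above (and repeated in `rematerialize` and in `recompute_saved_for_backward`) is a plain dictionary lookup keyed on proxy names. A toy illustration with stand-in objects rather than real Thunder proxies (all names below are made up):

```python
class FakeProxy:
    """Stand-in for a Thunder proxy; only the `name` attribute matters here."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"FakeProxy({self.name!r})"


# Backward-name -> forward-proxy map, as match_fw_and_bw_saved_for_bw_proxies would produce.
bw_to_fw = {"t5_bw": FakeProxy("t5_fw")}

required_for_backward = (FakeProxy("t5_bw"), FakeProxy("t7"))

# Same pattern as in the diff: swap in the forward proxy when a mapping exists.
remapped = tuple(bw_to_fw[a.name] if a.name in bw_to_fw else a for a in required_for_backward)
print(remapped)  # (FakeProxy('t5_fw'), FakeProxy('t7'))
```
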
6 changes: 6 additions & 0 deletions thunder/core/trace.py
@@ -2,6 +2,7 @@
import os
from contextvars import ContextVar
from contextlib import contextmanager
import pathlib
from typing import Optional, Any, Tuple, Type, Dict, List, Union
from collections.abc import Callable
from collections.abc import Sequence, Hashable
@@ -521,6 +522,11 @@ def python_callable(self, *, global_dicts: None | dict = None, **kwargs: Any) ->
def __repr__(self) -> str:
return self.python(print_depth=-1)

def save_trace(self, filename: str | os.PathLike) -> None:
filename = pathlib.Path(filename)
with open(filename, "w") as f:
f.write(str(self))


# Constructs a new trace by shallow copying parts of an existing trace
# NOTE Bound symbols and provenance are not copied
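A usage sketch for the new `save_trace` method, mirroring the test added in this commit (the output path is arbitrary):

```python
import torch

import thunder


def fn(x):
    return x + 1


jfn = thunder.jit(fn)
jfn(torch.randn(3))

# Write the last computed execution trace to disk as a readable Python file.
thunder.last_traces(jfn)[-1].save_trace("fwd_trace.py")
```
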
6 changes: 3 additions & 3 deletions thunder/core/trace_interpreter.py
@@ -271,11 +271,11 @@ def add_to_swap_map(self, old, new):
old = old.replace(shape=new._shape)

if isinstance(new, VJPDual):
self.swap_map[variableify(new.primal)] = old
self.swap_map[variableify(old.primal)] = new
new.primal = old
else:
assert isinstance(new, ProxyInterface), (old, new)
self.swap_map[variableify(new)] = old
self.swap_map[variableify(old)] = new

def do_swap(self, v):
if isinstance(v, VJPDual):
@@ -320,6 +320,7 @@ def process_args(self, *args, **kwargs):
def __call__(self):
with tracectx(self.new_trace):
self.unprocessed_bsyms = self.trace.bound_symbols[:]
self.swap_map = {}

while self.unprocessed_bsyms:
bsym = self.unprocessed_bsyms.pop(0)
@@ -338,7 +339,6 @@ def __call__(self):
assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms"

if self.replacement_result is not self.NULL:
self.swap_map = {}

# TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round
if not self.new_bsyms:
21 changes: 20 additions & 1 deletion thunder/core/transforms.py
@@ -3252,6 +3252,22 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
# args will be added from unpack_trivial
have_in_backward = saved_tensors | saved_nontensors

from thunder.core.rematerialization import match_fw_and_bw_saved_for_bw_proxies

new_required_for_backward_fw_to_bw_map, new_required_for_backward_bw_to_fw_map = (
match_fw_and_bw_saved_for_bw_proxies(fwd_trace, bwd_trace)
)
all_recomputable_proxies = all_recomputable_proxies.union(
OrderedSet(
(
variableify(new_required_for_backward_fw_to_bw_map[unvariableify(a).name])
if unvariableify(a).name in new_required_for_backward_fw_to_bw_map
else a
)
for a in all_recomputable_proxies
)
)

def compute_proxy_from_producer(p):
vp = variableify(p)
if vp in have_in_backward:
@@ -3263,7 +3279,10 @@ def compute_proxy_from_producer(p):
else:
saved_nontensors.add(vp)
return
producer_bsym = proxy_names_to_producers[p.name]
if p.name not in proxy_names_to_producers:
producer_bsym = proxy_names_to_producers[new_required_for_backward_bw_to_fw_map[p.name].name]
else:
producer_bsym = proxy_names_to_producers[p.name]
for p in producer_bsym.flat_proxy_args:
compute_proxy_from_producer(p)
for o in producer_bsym.flat_proxy_outs:
18 changes: 2 additions & 16 deletions thunder/executors/passes.py
@@ -29,22 +29,6 @@
# Transforms a trace by determining which execution transforms to call given the list of executors in priority order
# This pass tries to preserve the original trace and proxies.
def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx:
start_time_ns = time.perf_counter_ns()

swapmap: dict[Variable, Proxy] = {}

def update_swapmap(o: Any, no: Any) -> None:
if isinstance(o, Proxy):
check(
isinstance(no, Proxy),
lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}",
)

vo = variableify(o)
vno = variableify(no)
if vo == vno:
return
swapmap[vno] = o

# This processes the bsyms to map symbols to operator executors:
# - if a bsym has a python impl, that will be called, so we can keep it.
@@ -104,6 +88,8 @@ def process_bsym(self, bsym):
### OUTPUTS to map
self.add_unprocessed_bsyms(bsym.subsymbols[:])

start_time_ns = time.perf_counter_ns()

extrace, _ = OpExProcessor(trace)()

end_time_ns = time.perf_counter_ns()
40 changes: 27 additions & 13 deletions thunder/tests/opinfos.py
@@ -25,7 +25,7 @@
from thunder.core.symbol import Symbol
import thunder.executors as executors
from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
from thunder.tests.make_tensor import make_tensor
from thunder.tests.make_tensor import make_tensor, make_tensor_like
import thunder.tests.bf16
import thunder.torch as ltorch

@@ -60,16 +60,6 @@ def round_remainder(x, y):
return x - torch.round(x / y) * y


def push_away_from_singularities(x, singularity_fn, eps):
"""This function takes a tensor and moves individual values away
from singularities in `eps` increments, until they are further than
`eps` away from them. The `singularity_fn` returns the (signed)
distance from `x` to the nearest singularity."""
x_dist = singularity_fn(x)
x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
return torch.where((x_dist <= 0) & (x_dist > -eps), x_ - eps, x_)


# Randomly select a fraction of the elements in a tensor and set them to specified value
def replace_random_percentage(a: torch.Tensor, value: Number, percentage: float) -> torch.Tensor:
flat = torch.flatten(a.detach().clone())
@@ -208,10 +198,24 @@ def _to(x):
args, kwargs = tree_map(_to, self.args), tree_map(_to, self.kwargs)
return SampleInput(*args, **kwargs)

def remove_singularities(self, singularity_fn, eps):
def remove_singularities(self, op, eps):

singularity_fn = op.singularity_fn_producer(self)
if singularity_fn is None:
return self

def _push_away_from_singularities(x, dist_fn, eps):
"""This function takes a tensor and moves individual values away
from singularities in `eps` increments, until they are further than
`eps` away from them. The `dist_fn` returns the (signed)
distance from `x` to the nearest singularity."""
x_dist = dist_fn(x)
x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
return torch.where((x_dist < 0) & (x_dist > -eps), x_ - eps, x_)

def _remove_singularities(x):
if isinstance(x, torch.Tensor) and datatypes.is_float_dtype(datatypes.to_dtype(x)):
return push_away_from_singularities(x, singularity_fn, eps)
return _push_away_from_singularities(x, singularity_fn, eps)

return x

@@ -2195,18 +2199,28 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
)
elementwise_binary_ops.append(lt_opinfo)


def min_max_singularity_fn_producer(sample):
a, b = sample.args
if a.shape == b.shape or b.shape == ():
return lambda x: x - b if x is a else make_tensor_like(x, low=1)
return lambda x: x - a if x is b else make_tensor_like(x, low=1)


maximum_opinfo = OpInfo(
clang.maximum,
sample_input_generator=partial(elementwise_binary_generator, no_rhs_numbers=True),
torch_reference=torch.maximum,
supports_grad=True,
singularity_fn_producer=min_max_singularity_fn_producer,
)
elementwise_binary_ops.append(maximum_opinfo)

minimum_opinfo = OpInfo(
clang.minimum,
sample_input_generator=partial(elementwise_binary_generator, no_rhs_numbers=True),
torch_reference=torch.minimum,
singularity_fn_producer=min_max_singularity_fn_producer,
)
elementwise_binary_ops.append(minimum_opinfo)

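For maximum/minimum the gradient is ambiguous exactly where the two operands tie, so the producer above reports the signed distance `x - other` from that singularity; `remove_singularities` then nudges any value closer than `eps` away from a tie. A small standalone illustration of the nudge rule in plain PyTorch (toy tensors, not the OpInfo machinery):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.5, 3.0])

eps = 0.1
dist = a - b  # signed distance of `a` from the maximum/minimum singularity (ties with `b`)

# Same push-away rule as _push_away_from_singularities above.
a_pushed = torch.where((dist >= 0) & (dist < eps), a + eps, a)
a_pushed = torch.where((dist < 0) & (dist > -eps), a_pushed - eps, a_pushed)
print(a_pushed)  # tensor([1.1000, 2.0000, 3.1000]) -- the ties at indices 0 and 2 were moved
```
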
29 changes: 29 additions & 0 deletions thunder/tests/test_core.py
@@ -1,4 +1,6 @@
import operator
import os
import tempfile
import traceback
from functools import partial, reduce
from itertools import product
@@ -3125,3 +3127,30 @@ def test_proxy_same_name():
t = TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32, requires_grad=True)
with pytest.raises(RuntimeError, match="already used"):
t2 = TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32, requires_grad=True)


def test_save_trace():
def fn(x):
return x + 1

jfn = thunder.jit(fn)
jfn(
torch.rand(
3,
)
)

fwd_trace = thunder.last_traces(jfn)[-1]

with tempfile.TemporaryDirectory() as tmp_dir:
trace_name = os.path.join(tmp_dir, "tmp_trace.py")
fwd_trace.save_trace(trace_name)

with open(trace_name) as f:
trace_contents = f.readlines()

# Verify we find a few expected things in the
# saved trace.
trace_contents = "".join(trace_contents)
assert ".add" in trace_contents
assert "@torch.no_grad" in trace_contents