Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use transform for execution to get torch_compile executable #1500

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion thunder/core/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ def __reduce__(self): # for pickling
raise ValueError("Cannot serialize a symbol without a module and executor.")

if self.executor is None:
assert getattr(sys.modules[self.module.__name__], self.name, None) is self
assert (
getattr(sys.modules[self.module.__name__], self.name, None) is self
), f"{self.module.__name__}.{self.name} is not {self}"
else:
assert thunder.get_executor(self.executor.name).opmap.get(self.name) is self

Expand Down
181 changes: 179 additions & 2 deletions thunder/core/trace_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any

from thunder.core import prims
Expand Down Expand Up @@ -45,9 +46,9 @@ def read(x: VariableInterface | Any) -> Any:
def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
if not isinstance(v, VariableInterface):
return
# Duplicates are allowed and overwritten
if v.name in env:
if allow_duplicates:
# Duplicates are allowed and not overwritten
return
raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
env[v.name] = val
Expand Down Expand Up @@ -104,9 +105,9 @@ def read(x: VariableInterface | Any) -> Any:
def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
if not isinstance(v, VariableInterface):
return
# Duplicates are allowed and overwritten
if v.name in env:
if allow_duplicates:
# Duplicates are allowed and not overwritten
return
raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
env[v.name] = val
Expand Down Expand Up @@ -203,3 +204,179 @@ def do_swap(v):
return new_trace, tree_map(read, trace.output), env

return new_trace, tree_map(read, trace.output)


class TraceSubstitutionProcessor:
    """This processes a trace in an interpretation-style way by looping over the bound symbols.
    This processing aims to preserve as much information on the proxies as possible.

    Args:
        trace: trace to process
        *args: arguments to process the trace with
        **kwargs: keyword arguments to process the trace with

    The user is expected to subclass the processor and implement process_bsym with the help of
    add_unprocessed_bsyms (useful eg for using subsymbols to compute a symbol), add_processed_bsyms,
    and add_bsyms_from_function.

    Calling the instantiated object initiates the processing and returns
    the new trace and a mapping of the outputs.

    See the OpExProcessor in thunder.executors.passes._transform_for_operator_executor_execution for an example of subclassing.
    """

    # Sentinel distinguishing "no replacement result was set" from a legitimate None result.
    NULL = object()

    def __init__(self, trace, *args, **kwargs):
        # env maps variable names to their current (possibly substituted) values.
        self.env = {}
        self.trace = trace
        self.new_trace = from_trace(self.trace)
        # Set to True by process_args; gates the skip-list shortcut in __call__.
        self.have_processed_args = False

    def read(self, x: VariableInterface | Any) -> Any:
        # Resolve a variable to its current value; non-variables pass through unchanged.
        if isinstance(x, VariableInterface):
            return self.env[x.name]
        else:
            return x

    def write(self, v: VariableInterface | Any, val: Any, allow_duplicates=True) -> None:
        # Record the value of a variable in the environment. Non-variables are ignored.
        if not isinstance(v, VariableInterface):
            return
        if v.name in self.env:
            if allow_duplicates:
                # Duplicates are allowed and not overwritten
                return
            raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
        self.env[v.name] = val

    def add_to_swap_map(self, old, new):
        """Register that proxy ``old`` should be replaced by ``new`` in subsequent bsyms.

        If ``new`` is already a known input (something returned the input), the mapping is
        inverted so the old output maps onto the existing proxy instead.
        """
        if old is new:
            return
        if isinstance(old, ProxyInterface):
            if isinstance(new, ProxyInterface) and variableify(new) in self.env:
                # the new isn't new, but something returned the input
                # this means we need to map the old to the new
                old, new = new, old
            elif isinstance(old, TensorProxyInterface):
                # should we have a fix shapes pass? the sharding
                # (FSDP, tensor parallel) transforms do "break" shape metadata
                self.new_trace.names.remove(old.name)  # taken by the .replace proxy
                if isinstance(new, VJPDual):
                    old = old.replace(shape=new.primal._shape)
                else:
                    old = old.replace(shape=new._shape)

            if isinstance(new, VJPDual):
                self.swap_map[variableify(new.primal)] = old
                new.primal = old
            else:
                assert isinstance(new, ProxyInterface), (old, new)
                self.swap_map[variableify(new)] = old

    def do_swap(self, v):
        # Apply the swap map to a value; recurses into VJPDual primal/residuals.
        if isinstance(v, VJPDual):
            v.primal = tree_map(self.do_swap, v.primal)
            v.residuals = tree_map(self.do_swap, v.residuals)
            return v
        if not isinstance(v, ProxyInterface):
            return v
        return self.swap_map.get(variableify(v), v)

    def add_unprocessed_bsyms(self, bsyms):
        # Prepend bsyms to the work queue so they are processed next
        # (useful eg for expanding a symbol into its subsymbols).
        self.unprocessed_bsyms[:0] = bsyms

    def add_bsyms_from_function(self, fn, /, *args, **kwargs):
        # Run fn under a fresh scope, capture the bsyms it records, and use its
        # return value as the replacement result for the current bsym.
        self.new_trace.push_scope([])
        result = fn(*args, **kwargs)
        self.new_bsyms += self.new_trace.pop_scope()
        self.set_result(result)
        return result

    def add_processed_bsyms(self, bsyms):
        # Append already-processed bsyms to be emitted into the new trace.
        self.new_bsyms += bsyms

    def set_result(self, result):
        # Declare the value standing in for the current bsym's output.
        self.replacement_result = result

    def process_bsym(self, bsym):
        raise NotImplementedError("This needs to be implemented in subclasses")

    def process_args(self, *args, **kwargs):
        """Seed the environment by mapping the trace's inputs onto the given arguments."""
        self.have_processed_args = True
        with tracectx(self.new_trace):
            self.swap_map = {}

            safe_map_flat(self.add_to_swap_map, list(self.trace.args), list(args))
            safe_map_flat(self.add_to_swap_map, list(self.trace.kwargs.values()), list(kwargs.values()))
            args, kwargs = tree_map(self.do_swap, (args, kwargs))

            safe_map_flat(self.write, list(self.trace.args), list(args))
            safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values()))

    def __call__(self):
        """Process the trace and return the new trace and the mapped outputs."""
        with tracectx(self.new_trace):
            self.unprocessed_bsyms = self.trace.bound_symbols[:]

            while self.unprocessed_bsyms:
                bsym = self.unprocessed_bsyms.pop(0)

                if self.have_processed_args and bsym.sym.id in trace_interpreter_skip_list:
                    self.new_trace.bound_symbols.append(bsym.from_bsym())
                    continue

                args = tree_map(self.read, bsym.args)
                kwargs = tree_map(self.read, bsym.kwargs)

                # this should be prettier
                self.replacement_result = self.NULL
                self.new_bsyms = []

                self.process_bsym(bsym)

                if self.new_bsyms:
                    assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms"

                if self.replacement_result is not self.NULL:
                    self.swap_map = {}

                    # TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round
                    if not self.new_bsyms:
                        # empty result means we want to swap references to the old
                        # result to the new result (which will be one of the args)
                        safe_map_flat(
                            self.add_to_swap_map,
                            list(sequencify(self.replacement_result)),
                            list(sequencify(bsym.output)),
                        )
                    else:
                        safe_map_flat(
                            self.add_to_swap_map,
                            list(sequencify(bsym.output)),
                            list(sequencify(self.replacement_result)),
                        )

                    ### replace bsyms

                    for new_bsym in self.new_bsyms:
                        # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
                        self.new_trace.bound_symbols.append(
                            new_bsym.from_bsym_swap_proxies(self.swap_map).from_bsym(
                                source_filename=bsym.source_filename, source_positions=bsym.source_positions
                            )
                        )

                    result = tree_map(self.do_swap, self.replacement_result)

                    # we need to allow duplicates here because the re-interpretation is not necessarily DCEed when subsymbols are flattened into the trace after re-execution.
                    try:
                        safe_map_flat(
                            partial(self.write, allow_duplicates=True),
                            list(sequencify(bsym.output)),
                            list(sequencify(result)),
                        )
                    except AssertionError as e:
                        # NOTE: the original message referenced an undefined name (prim_func),
                        # which would have raised a NameError while formatting the error.
                        raise RuntimeError(
                            f"Error while assigning the processing result to the output of the original symbol {bsym}."
                            " This is likely due to a mismatch in the number of outputs."
                            f" The original symbol has {len(sequencify(bsym.output))} outputs and the processed result has {len(sequencify(result))} outputs."
                        ) from e

        return self.new_trace, tree_map(self.read, self.trace.output)
31 changes: 29 additions & 2 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import thunder
import thunder.core.prims as prims
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.baseutils import BoundSymbolInterface, NumberProxyInterface
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify
from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
Expand Down Expand Up @@ -111,6 +111,29 @@ def check(inp, log_str):
check(copy_to_out, "output")


def remove_duplicate_number_proxies(bsyms: Sequence[BoundSymbol]) -> list[BoundSymbol]:
    """This removes duplicate number proxies when they are returned multiple times.
    The remaining DCE pass does not see them (because they often are in a tuple?).
    In particular, proxies may be extracted multiple times when using the thunder.jit's
    symbolic constraints mode.
    """
    seen_names: set[str] = set()

    def dedupe(proxy):
        # Non-number-proxy leaves pass through untouched.
        if not isinstance(proxy, NumberProxyInterface):
            return proxy
        if proxy.name in seen_names:
            # Repeat occurrence: substitute the concrete value so it is no longer a duplicate proxy.
            return proxy.value
        seen_names.add(proxy.name)
        return proxy

    return [bsym.from_bsym(output=tree_map(dedupe, bsym.output)) for bsym in bsyms]


# TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
# improve performance
# Runs a Dead Code Elimination (DCE) pass
Expand Down Expand Up @@ -174,7 +197,11 @@ def _helper(x):
needed_proxies.add(variableify(x))

dcetrace = from_trace(trace)
dcetrace.bound_symbols = list(reversed(dced))
dced_bound_symbols = list(reversed(dced))
# duplicate number proxies happen with the symbolic shapes and are
# not covered by the above (due to being in tuples?).
dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
dcetrace.bound_symbols = dced_bound_symbols

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
Expand Down
21 changes: 20 additions & 1 deletion thunder/core/vjp_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import wraps
from inspect import Parameter, Signature
from itertools import chain
Expand Down Expand Up @@ -229,3 +229,22 @@ def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]:
lambda: "All saved tensors must be TensorProxy or None",
)
return tuple(saved_tensors)


def set_saved_for_backward_tensors(trace: TraceCtx, saved_tensors: Sequence[TensorProxy]):
    """
    Given a trace, replace the tensors that are saved for backward with the given ones.

    The trace's final ``python_return`` bound symbol is rewritten in place so that
    its saved-for-backward tensor tuple becomes ``saved_tensors``.

    Args:
        trace: The trace to set saved tensors for.
        saved_tensors: proxies for the tensors to save.
    """
    utils.check(
        all(isinstance(t, TensorProxy) or t is None for t in saved_tensors),
        lambda: "All saved tensors must be TensorProxy or None",
    )
    # The return statement is expected to be the last bound symbol of the trace.
    ret_node = trace.bound_symbols.pop(-1)
    assert ret_node.sym == prims.python_return
    output = ret_node.args
    # args layout (from the visible indexing): (forward_output, (saved_tensors, *rest), *rest)
    # — only the saved-tensors slot is replaced; everything else is preserved.
    output = (output[0], (tuple(saved_tensors), *output[1][1:]), *output[2:])
    trace.bound_symbols.append(ret_node.from_bsym(args=output))
Loading
Loading