use transform for execution to get torch_compile executable (#1500)
t-vi authored and riccardofelluga committed Dec 5, 2024
1 parent 16f45ea commit 7a839b6
Showing 10 changed files with 352 additions and 119 deletions.
4 changes: 3 additions & 1 deletion thunder/core/symbol.py
@@ -231,7 +231,9 @@ def __reduce__(self): # for pickling
raise ValueError("Cannot serialize a symbol without a module and executor.")

if self.executor is None:
assert getattr(sys.modules[self.module.__name__], self.name, None) is self
assert (
getattr(sys.modules[self.module.__name__], self.name, None) is self
), f"{self.module.__name__}.{self.name} is not {self}"
else:
assert thunder.get_executor(self.executor.name).opmap.get(self.name) is self

181 changes: 179 additions & 2 deletions thunder/core/trace_interpreter.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import Any

from thunder.core import prims
@@ -45,9 +46,9 @@ def read(x: VariableInterface | Any) -> Any:
def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
if not isinstance(v, VariableInterface):
return
# Duplicates are allowed and overwritten
if v.name in env:
if allow_duplicates:
# Duplicates are allowed and not overwritten
return
raise ValueError(f"Variable {v.name} is being overwritten; this is not allowed")
env[v.name] = val
@@ -104,9 +105,9 @@ def read(x: VariableInterface | Any) -> Any:
def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
if not isinstance(v, VariableInterface):
return
# Duplicates are allowed and overwritten
if v.name in env:
if allow_duplicates:
# Duplicates are allowed and not overwritten
return
raise ValueError(f"Variable {v.name} is being overwritten; this is not allowed")
env[v.name] = val
@@ -203,3 +204,179 @@ def do_swap(v):
return new_trace, tree_map(read, trace.output), env

return new_trace, tree_map(read, trace.output)


class TraceSubstitutionProcessor:
"""This processes a trace in an interpretation-style way by looping over the bound symbols.
This processing aims to preserve as much information on the proxies as possible.
Args:
trace: trace to process
*args: arguments to process the trace with
**kwargs: keyword arguments to process the trace with
The user is expected to subclass the trace and implement process_bsym with the help of add_unprocessed_bsyms (useful eg for using subsymbols to compute a symbol), add_processed_bsyms, and add_bsyms_from_function.
Calling the instantiated object initiates the processing and returns
the new trace and a mapping of the outputs.
See the OpExProcessor in thunder.executors.passes._transform_for_operator_executor_execution for an example of subclassing.
"""

NULL = object()

def __init__(self, trace, *args, **kwargs):
self.env = {}
self.trace = trace
self.new_trace = from_trace(self.trace)
self.have_processed_args = False

def read(self, x: VariableInterface | Any) -> Any:
if isinstance(x, VariableInterface):
return self.env[x.name]
else:
return x

def write(self, v: VariableInterface | Any, val: Any, allow_duplicates=True) -> None:
if not isinstance(v, VariableInterface):
return
if v.name in self.env:
if allow_duplicates:
# Duplicates are allowed and not overwritten
return
raise ValueError(f"Variable {v.name} is being overwritten; this is not allowed")
self.env[v.name] = val

def add_to_swap_map(self, old, new):
if old is new:
return
if isinstance(old, ProxyInterface):
if isinstance(new, ProxyInterface) and variableify(new) in self.env:
# the new isn't new, but something returned the input
# this means we need to map the old to the new
old, new = new, old
elif isinstance(old, TensorProxyInterface):
# should we have a fix shapes pass? the sharding
# (FSDP, tensor parallel) transforms do "break" shape metadata
self.new_trace.names.remove(old.name) # taken by the .replace proxy
if isinstance(new, VJPDual):
old = old.replace(shape=new.primal._shape)
else:
old = old.replace(shape=new._shape)

if isinstance(new, VJPDual):
self.swap_map[variableify(new.primal)] = old
new.primal = old
else:
assert isinstance(new, ProxyInterface), (old, new)
self.swap_map[variableify(new)] = old

def do_swap(self, v):
if isinstance(v, VJPDual):
v.primal = tree_map(self.do_swap, v.primal)
v.residuals = tree_map(self.do_swap, v.residuals)
return v
if not isinstance(v, ProxyInterface):
return v
return self.swap_map.get(variableify(v), v)

def add_unprocessed_bsyms(self, bsyms):
self.unprocessed_bsyms[:0] = bsyms

def add_bsyms_from_function(self, fn, /, *args, **kwargs):
self.new_trace.push_scope([])
result = fn(*args, **kwargs)
self.new_bsyms += self.new_trace.pop_scope()
self.set_result(result)
return result

def add_processed_bsyms(self, bsyms):
self.new_bsyms += bsyms

def set_result(self, result):
self.replacement_result = result

def process_bsym(self, bsym):
raise NotImplementedError("This needs to be implemented in subclasses")

def process_args(self, *args, **kwargs):
self.have_processed_args = True
with tracectx(self.new_trace):
self.swap_map = {}

safe_map_flat(self.add_to_swap_map, list(self.trace.args), list(args))
safe_map_flat(self.add_to_swap_map, list(self.trace.kwargs.values()), list(kwargs.values()))
args, kwargs = tree_map(self.do_swap, (args, kwargs))

safe_map_flat(self.write, list(self.trace.args), list(args))
safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values()))

def __call__(self):
with tracectx(self.new_trace):
self.unprocessed_bsyms = self.trace.bound_symbols[:]

while self.unprocessed_bsyms:
bsym = self.unprocessed_bsyms.pop(0)

if self.have_processed_args and bsym.sym.id in trace_interpreter_skip_list:
self.new_trace.bound_symbols.append(bsym.from_bsym())
continue

args = tree_map(self.read, bsym.args)
kwargs = tree_map(self.read, bsym.kwargs)

# this should be prettier
self.replacement_result = self.NULL
self.new_bsyms = []

self.process_bsym(bsym)

if self.new_bsyms:
assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms"

if self.replacement_result is not self.NULL:
self.swap_map = {}

# TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round
if not self.new_bsyms:
# no new bsyms means we want to swap references to the old
# output over to the replacement result (which will be one of the args)
safe_map_flat(
self.add_to_swap_map,
list(sequencify(self.replacement_result)),
list(sequencify(bsym.output)),
)
else:
safe_map_flat(
self.add_to_swap_map,
list(sequencify(bsym.output)),
list(sequencify(self.replacement_result)),
)

### replace bsyms

for new_bsym in self.new_bsyms:
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
self.new_trace.bound_symbols.append(
new_bsym.from_bsym_swap_proxies(self.swap_map).from_bsym(
source_filename=bsym.source_filename, source_positions=bsym.source_positions
)
)

result = tree_map(self.do_swap, self.replacement_result)

# we need to allow duplicates here because the re-interpretation is not necessarily DCEed when subsymbols are flattened into the trace after re-execution.
try:
safe_map_flat(
partial(self.write, allow_duplicates=True),
list(sequencify(bsym.output)),
list(sequencify(result)),
)
except AssertionError as e:
raise RuntimeError(
f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {bsym}."
" This is likely due to a mismatch in the number of outputs."
f" The original symbol has {len(bsym.output)} outputs and the dispatched function has {len(sequencify(result))} outputs."
) from e

return self.new_trace, tree_map(self.read, self.trace.output)
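
A minimal, hypothetical usage sketch (not part of this commit): a subclass that re-emits every bound symbol unchanged, exercising the add_processed_bsyms/set_result protocol described in the docstring above.

from thunder.core.trace_interpreter import TraceSubstitutionProcessor

class IdentityProcessor(TraceSubstitutionProcessor):
    def process_bsym(self, bsym):
        # Re-emit the bound symbol as-is and declare its original output as the result.
        self.add_processed_bsyms([bsym.from_bsym()])
        self.set_result(bsym.output)

# Usage, assuming `trace` is an existing TraceCtx:
# new_trace, outputs = IdentityProcessor(trace)()
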
31 changes: 29 additions & 2 deletions thunder/core/transform_common.py
@@ -9,7 +9,7 @@

import thunder
import thunder.core.prims as prims
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.baseutils import BoundSymbolInterface, NumberProxyInterface
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify
from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
@@ -111,6 +111,29 @@ def check(inp, log_str):
check(copy_to_out, "output")


def remove_duplicate_number_proxies(bsyms: Sequence[BoundSymbol]) -> list[BoundSymbol]:
"""This removes duplicate number proxies when they are returned multiple times.
The remaining DCE pass does not see them (because they often are in a tuple?).
In particular, proxies may be extracted multiple times when using the thunder.jit's
symbolic constraints mode.
"""
seen = set()

def keep_or_swap(p):
if not isinstance(p, NumberProxyInterface):
return p
if p.name in seen:
return p.value # don't make it a duplicate
seen.add(p.name)
return p

new_bsyms = []
for bsym in bsyms:
output = tree_map(keep_or_swap, bsym.output)
new_bsyms.append(bsym.from_bsym(output=output))
return new_bsyms


# TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
# improve performance
# Runs a Dead Code Elimination (DCE) pass
@@ -174,7 +197,11 @@ def _helper(x):
needed_proxies.add(variableify(x))

dcetrace = from_trace(trace)
dcetrace.bound_symbols = list(reversed(dced))
dced_bound_symbols = list(reversed(dced))
# duplicate number proxies appear with symbolic shapes and are not
# removed by the DCE above (likely because they sit inside tuples).
dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
dcetrace.bound_symbols = dced_bound_symbols

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
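
A hedged usage sketch (hypothetical, not from the commit) of the new pass on its own; in the commit it is wired into dce() right after the reversed, DCE'd bound symbols are collected.

from thunder.core.transform_common import remove_duplicate_number_proxies

# `extrace` is assumed to be an existing TraceCtx whose bound symbols may return
# the same NumberProxy more than once (e.g. under symbolic constraints).
extrace.bound_symbols = remove_duplicate_number_proxies(extrace.bound_symbols)
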
21 changes: 20 additions & 1 deletion thunder/core/vjp_utils.py
@@ -1,5 +1,5 @@
import inspect
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import wraps
from inspect import Parameter, Signature
from itertools import chain
@@ -229,3 +229,22 @@ def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]:
lambda: "All saved tensors must be TensorProxy or None",
)
return tuple(saved_tensors)


def set_saved_for_backward_tensors(trace: TraceCtx, saved_tensors: Sequence[TensorProxy]):
"""
Given a trace, set the tensors that are saved for backward in the trace.
Args:
trace: The trace to set saved tensors for.
saved_tensors: proxies for the tensors to save.
"""
utils.check(
all(isinstance(t, TensorProxy) or t is None for t in saved_tensors),
lambda: "All saved tensors must be TensorProxy or None",
)
ret_node = trace.bound_symbols.pop(-1)
assert ret_node.sym == prims.python_return
output = ret_node.args
output = (output[0], (tuple(saved_tensors), *output[1][1:]), *output[2:])
trace.bound_symbols.append(ret_node.from_bsym(args=output))
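
A hedged round-trip sketch (hypothetical, not from the commit), assuming the forward-trace return structure of ((outputs), (saved_tensors, saved_other), ...) that the args rewrite above relies on.

from thunder.core.vjp_utils import get_saved_for_backward_tensors, set_saved_for_backward_tensors

# `fw_trace` is assumed to be an existing forward TraceCtx.
saved = get_saved_for_backward_tensors(fw_trace)        # read the currently saved tensors
set_saved_for_backward_tensors(fw_trace, list(saved))   # write back a (possibly modified) sequence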