Skip to content

Commit b82f59c

Browse files
authored
use transform for execution to get torch_compile executable (#1500)
1 parent a3cdbc4 commit b82f59c

File tree

10 files changed

+352
-119
lines changed

10 files changed

+352
-119
lines changed

thunder/core/symbol.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ def __reduce__(self): # for pickling
231231
raise ValueError("Cannot serialize a symbol without a module and executor.")
232232

233233
if self.executor is None:
234-
assert getattr(sys.modules[self.module.__name__], self.name, None) is self
234+
assert (
235+
getattr(sys.modules[self.module.__name__], self.name, None) is self
236+
), f"{self.module.__name__}.{self.name} is not {self}"
235237
else:
236238
assert thunder.get_executor(self.executor.name).opmap.get(self.name) is self
237239

thunder/core/trace_interpreter.py

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
from typing import Any
23

34
from thunder.core import prims
@@ -45,9 +46,9 @@ def read(x: VariableInterface | Any) -> Any:
4546
def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
4647
if not isinstance(v, VariableInterface):
4748
return
48-
# Duplicates are allowed and overwritten
4949
if v.name in env:
5050
if allow_duplicates:
51+
# Duplicates are allowed and not overwritten
5152
return
5253
raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
5354
env[v.name] = val
@@ -104,9 +105,9 @@ def read(x: VariableInterface | Any) -> Any:
104105
def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None:
105106
if not isinstance(v, VariableInterface):
106107
return
107-
# Duplicates are allowed and overwritten
108108
if v.name in env:
109109
if allow_duplicates:
110+
# Duplicates are allowed and not overwritten
110111
return
111112
raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
112113
env[v.name] = val
@@ -203,3 +204,179 @@ def do_swap(v):
203204
return new_trace, tree_map(read, trace.output), env
204205

205206
return new_trace, tree_map(read, trace.output)
207+
208+
209+
class TraceSubstitutionProcessor:
210+
"""This processes a trace in an interpretation-style way by looping over the bound symbols.
211+
This processing aims to preserve as much information on the proxies as possible.
212+
213+
Args:
214+
trace: trace to process
215+
*args: arguments to process the trace with
216+
**kwargs: keyword arguments to process the trace with
217+
218+
The user is expected to subclass the trace and implement process_bsym with the help of add_unprocessed_bsyms (useful eg for using subsymbols to compute a symbol), add_processed_bsyms, and add_bsyms_from_function.
219+
220+
Calling the instantiated object initiates the processing and returns
221+
the new trace and a mapping of the outputs.
222+
223+
See the OpExProcessor in thunder.executors.passes._transform_for_operator_executor_execution for an example of subclassing.
224+
"""
225+
226+
NULL = object()
227+
228+
def __init__(self, trace, *args, **kwargs):
229+
self.env = {}
230+
self.trace = trace
231+
self.new_trace = from_trace(self.trace)
232+
self.have_processed_args = False
233+
234+
def read(self, x: VariableInterface | Any) -> Any:
235+
if isinstance(x, VariableInterface):
236+
return self.env[x.name]
237+
else:
238+
return x
239+
240+
def write(self, v: VariableInterface | Any, val: Any, allow_duplicates=True) -> None:
241+
if not isinstance(v, VariableInterface):
242+
return
243+
if v.name in self.env:
244+
if allow_duplicates:
245+
# Duplicates are allowed and not overwritten
246+
return
247+
raise ValueError(f"Variable {v.name} is being overwritten this is not allowed")
248+
self.env[v.name] = val
249+
250+
def add_to_swap_map(self, old, new):
251+
if old is new:
252+
return
253+
if isinstance(old, ProxyInterface):
254+
if isinstance(new, ProxyInterface) and variableify(new) in self.env:
255+
# the new isn't new, but something returned the input
256+
# this means we need to map the old to the new
257+
old, new = new, old
258+
elif isinstance(old, TensorProxyInterface):
259+
# should we have a fix shapes pass? the sharding
260+
# (FSDP, tensor parallel) transforms do "break" shape metadata
261+
self.new_trace.names.remove(old.name) # taken by the .replace proxy
262+
if isinstance(new, VJPDual):
263+
old = old.replace(shape=new.primal._shape)
264+
else:
265+
old = old.replace(shape=new._shape)
266+
267+
if isinstance(new, VJPDual):
268+
self.swap_map[variableify(new.primal)] = old
269+
new.primal = old
270+
else:
271+
assert isinstance(new, ProxyInterface), (old, new)
272+
self.swap_map[variableify(new)] = old
273+
274+
def do_swap(self, v):
275+
if isinstance(v, VJPDual):
276+
v.primal = tree_map(self.do_swap, v.primal)
277+
v.residuals = tree_map(self.do_swap, v.residuals)
278+
return v
279+
if not isinstance(v, ProxyInterface):
280+
return v
281+
return self.swap_map.get(variableify(v), v)
282+
283+
def add_unprocessed_bsyms(self, bsyms):
284+
self.unprocessed_bsyms[:0] = bsyms
285+
286+
def add_bsyms_from_function(self, fn, /, *args, **kwargs):
287+
self.new_trace.push_scope([])
288+
result = fn(*args, **kwargs)
289+
self.new_bsyms += self.new_trace.pop_scope()
290+
self.set_result(result)
291+
return result
292+
293+
def add_processed_bsyms(self, bsyms):
294+
self.new_bsyms += bsyms
295+
296+
def set_result(self, result):
297+
self.replacement_result = result
298+
299+
def process_bsym(self, bsym):
300+
raise NotImplementedError("This needs to be implemented in subclasses")
301+
302+
def process_args(self, *args, **kwargs):
303+
self.have_processed_args = True
304+
with tracectx(self.new_trace):
305+
self.swap_map = {}
306+
307+
safe_map_flat(self.add_to_swap_map, list(self.trace.args), list(args))
308+
safe_map_flat(self.add_to_swap_map, list(self.trace.kwargs.values()), list(kwargs.values()))
309+
args, kwargs = tree_map(self.do_swap, (args, kwargs))
310+
311+
safe_map_flat(self.write, list(self.trace.args), list(args))
312+
safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values()))
313+
314+
def __call__(self):
315+
with tracectx(self.new_trace):
316+
self.unprocessed_bsyms = self.trace.bound_symbols[:]
317+
318+
while self.unprocessed_bsyms:
319+
bsym = self.unprocessed_bsyms.pop(0)
320+
321+
if self.have_processed_args and bsym.sym.id in trace_interpreter_skip_list:
322+
self.new_trace.bound_symbols.append(bsym.from_bsym())
323+
continue
324+
325+
args = tree_map(self.read, bsym.args)
326+
kwargs = tree_map(self.read, bsym.kwargs)
327+
328+
# this should be prettier
329+
self.replacement_result = self.NULL
330+
self.new_bsyms = []
331+
332+
self.process_bsym(bsym)
333+
334+
if self.new_bsyms:
335+
assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms"
336+
337+
if self.replacement_result is not self.NULL:
338+
self.swap_map = {}
339+
340+
# TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round
341+
if not self.new_bsyms:
342+
# empty result means we want to swap references to the old
343+
# result to the new result (which will be one of the args)
344+
safe_map_flat(
345+
self.add_to_swap_map,
346+
list(sequencify(self.replacement_result)),
347+
list(sequencify(bsym.output)),
348+
)
349+
else:
350+
safe_map_flat(
351+
self.add_to_swap_map,
352+
list(sequencify(bsym.output)),
353+
list(sequencify(self.replacement_result)),
354+
)
355+
356+
### replace bsyms
357+
358+
for new_bsym in self.new_bsyms:
359+
# TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym?
360+
self.new_trace.bound_symbols.append(
361+
new_bsym.from_bsym_swap_proxies(self.swap_map).from_bsym(
362+
source_filename=bsym.source_filename, source_positions=bsym.source_positions
363+
)
364+
)
365+
366+
result = tree_map(self.do_swap, self.replacement_result)
367+
368+
# we need to allow duplicates here because the re-interpretation is not necessairly DCEed when subsymbols symbols are flattened into the trace after re-execution.
369+
try:
370+
safe_map_flat(
371+
partial(self.write, allow_duplicates=True),
372+
list(sequencify(bsym.output)),
373+
list(sequencify(result)),
374+
)
375+
except AssertionError as e:
376+
raise RuntimeError(
377+
f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {bsym}."
378+
" This is likely due to a mismatch in the number of outputs."
379+
f" The original symbol has {len(bsym.output)} outputs and the dispatched function has {len(sequencify(result))} outputs."
380+
) from e
381+
382+
return self.new_trace, tree_map(self.read, self.trace.output)

thunder/core/transform_common.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import thunder
1111
import thunder.core.prims as prims
12-
from thunder.core.baseutils import BoundSymbolInterface
12+
from thunder.core.baseutils import BoundSymbolInterface, NumberProxyInterface
1313
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify
1414
from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
1515
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
@@ -111,6 +111,29 @@ def check(inp, log_str):
111111
check(copy_to_out, "output")
112112

113113

114+
def remove_duplicate_number_proxies(bsyms: Sequence[BoundSymbol]) -> list[BoundSymbol]:
115+
"""This removes duplicate number proxies when they are returned multiple times.
116+
The remaining DCE pass does not see them (because they often are in a tuple?).
117+
In particular, proxies may be extracted multiple times when using the thunder.jit's
118+
symbolic constraints mode.
119+
"""
120+
seen = set()
121+
122+
def keep_or_swap(p):
123+
if not isinstance(p, NumberProxyInterface):
124+
return p
125+
if p.name in seen:
126+
return p.value # don't make it a duplicate
127+
seen.add(p.name)
128+
return p
129+
130+
new_bsyms = []
131+
for bsym in bsyms:
132+
output = tree_map(keep_or_swap, bsym.output)
133+
new_bsyms.append(bsym.from_bsym(output=output))
134+
return new_bsyms
135+
136+
114137
# TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
115138
# improve performance
116139
# Runs a Dead Code Elimination (DCE) pass
@@ -174,7 +197,11 @@ def _helper(x):
174197
needed_proxies.add(variableify(x))
175198

176199
dcetrace = from_trace(trace)
177-
dcetrace.bound_symbols = list(reversed(dced))
200+
dced_bound_symbols = list(reversed(dced))
201+
# duplicate number proxies happen with the symbolic shapes and are
202+
# not covered by the above (due to being in tuples?).
203+
dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
204+
dcetrace.bound_symbols = dced_bound_symbols
178205

179206
end_time_ns = time.perf_counter_ns()
180207
elapsed_time_ns = end_time_ns - start_time_ns

thunder/core/vjp_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from collections.abc import Callable
2+
from collections.abc import Callable, Sequence
33
from functools import wraps
44
from inspect import Parameter, Signature
55
from itertools import chain
@@ -229,3 +229,22 @@ def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]:
229229
lambda: "All saved tensors must be TensorProxy or None",
230230
)
231231
return tuple(saved_tensors)
232+
233+
234+
def set_saved_for_backward_tensors(trace: TraceCtx, saved_tensors: Sequence[TensorProxy]):
235+
"""
236+
Given a trace, return the tensors that are saved for backward in the trace.
237+
238+
Args:
239+
trace: The trace to set saved tensors for.
240+
saved_tensors: proxies for the tensors to save.
241+
"""
242+
utils.check(
243+
all(isinstance(t, TensorProxy) or t is None for t in saved_tensors),
244+
lambda: "All saved tensors must be TensorProxy or None",
245+
)
246+
ret_node = trace.bound_symbols.pop(-1)
247+
assert ret_node.sym == prims.python_return
248+
output = ret_node.args
249+
output = (output[0], (tuple(saved_tensors), *output[1][1:]), *output[2:])
250+
trace.bound_symbols.append(ret_node.from_bsym(args=output))

0 commit comments

Comments
 (0)