Skip to content

Commit

Permalink
address review comments, thank you Mike!
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Dec 3, 2024
1 parent 2654229 commit 4f798aa
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 30 deletions.
22 changes: 17 additions & 5 deletions thunder/core/trace_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,22 @@ def do_swap(v):


class TraceSubstitutionProcessor:
"""This processes a trace in an interpretation-style way by looping over the bound symbols.
This processing aims to preserve as much information on the proxies as possible.
Args:
trace: trace to process
*args: arguments to process the trace with
**kwargs: keyword arguments to process the trace with
The user is expected to subclass this class and implement process_bsym with the help of add_unprocessed_bsyms (useful e.g. for using subsymbols to compute a symbol), add_processed_bsyms, and add_bsyms_from_function.
Calling the instantiated object initiates the processing and returns
the new trace and a mapping of the outputs.
See the OpExProcessor in thunder.executors.passes._transform_for_operator_executor_execution for an example of subclassing.
"""

NULL = object()

def __init__(self, trace, *args, **kwargs):
Expand Down Expand Up @@ -267,16 +283,14 @@ def do_swap(self, v):
def add_unprocessed_bsyms(self, bsyms):
self.unprocessed_bsyms[:0] = bsyms

def bsyms_from_function(self, fn, /, *args, **kwargs):
def add_bsyms_from_function(self, fn, /, *args, **kwargs):
self.new_trace.push_scope([])
result = fn(*args, **kwargs)
self.new_bsyms += self.new_trace.pop_scope()
self.set_result(result)
return result

def add_processed_bsyms(self, bsyms):

### replacements of inputs!
self.new_bsyms += bsyms

def set_result(self, result):
Expand All @@ -298,8 +312,6 @@ def process_args(self, *args, **kwargs):
safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values()))

def __call__(self):
# if not self.have_processed_args and self.trace.args is not None:
# self.process_args(*self.args, **self.kwargs)
with tracectx(self.new_trace):
self.unprocessed_bsyms = self.trace.bound_symbols[:]

Expand Down
13 changes: 10 additions & 3 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ def check(inp, log_str):
check(copy_to_out, "output")


def remove_duplicate_number_proxies(bsyms):
def remove_duplicate_number_proxies(bsyms: Sequence[BoundSymbol]) -> list[BoundSymbol]:
"""This removes duplicate number proxies when they are returned multiple times.
The remaining DCE pass does not see them (because they often are in a tuple?).
In particular, proxies may be extracted multiple times when using thunder.jit's
symbolic constraints mode.
"""
seen = set()

def keep_or_swap(p):
Expand Down Expand Up @@ -193,8 +198,10 @@ def _helper(x):

dcetrace = from_trace(trace)
dced_bound_symbols = list(reversed(dced))
dced_bound_symbols = dced_bound_symbols
dcetrace.bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
# duplicate number proxies happen with the symbolic shapes and are
# not covered by the above (due to being in tuples?).
dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
dcetrace.bound_symbols = dced_bound_symbols

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
Expand Down
33 changes: 14 additions & 19 deletions thunder/executors/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@

# Transforms a trace by determining which execution transforms to call given the list of executors in priority order
# This pass tries to preserve the original trace and proxies.
# Implementation Steps -
# 1. The trace is updated with `visitor_transform` with `visit_helper_` (where executors try to claim the symbols). Note that this replaces the output proxies in the trace.
# 2. `visit_helper_` also creates a swapmap from the new symbols back to old one.
# 3. After the `visitor_transform`, it iterates over the updated trace and puts back the old proxies.
def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx:
start_time_ns = time.perf_counter_ns()

Expand All @@ -50,22 +46,21 @@ def update_swapmap(o: Any, no: Any) -> None:
return
swapmap[vno] = o

def preserve_bsym(bsym: BoundSymbol) -> Any:
trace = get_tracectx()
trace.scopes[-1].append(bsym)
for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args):
trace.names.add(p.name)
return bsym.output

# TODO Consider using an enum for this function's return values
# Tries to find an executor for the BoundSymbol
# If the BoundSymbol already has an executor then None is returned
# If the executor has an execution transform, it's called and True is returned
# If no executor can execute the BoundSymbol, False is returned

# This processes the bsyms to map symbols to operator executors:
# - if a bsym has a python impl, that will be called, so we can keep it.
# - in the order of the executor list
# - if the executor defines an execution transform, call that to
# create symbols for the trace,
# - for operator executors, if we have an implmap entry for the symbol,
# execute that
# - for fusion executors, check if the symbol can be fused (done later)
# - if none of these apply, and the symbol is not a prim, replace the symbol
# with its subsymbols (which will then be processed using the above),
# - if none of the above apply and we have a prim, raise an error
class OpExProcessor(TraceSubstitutionProcessor):
def process_bsym(self, bsym):
if bsym.sym.python_impl is not None:
# keep the bound symbol and use the python impl
self.add_processed_bsyms([bsym])
self.set_result(bsym.output)
return
Expand All @@ -80,15 +75,15 @@ def process_bsym(self, bsym):
execution_transform: None | Callable = ex.get_execution_transform(bsym.sym)
out: Any
if execution_transform is not None:
self.bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs)
self.add_bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs)
return
elif isinstance(ex, OperatorExecutor):
# NOTE execution_transform is None and the executor is an operator executor
# Calls the operator executor's operation
# TODO Instead of directly acquiring the symbol from the implmap, we probably
# want to hide this behind a function
op = ex.implmap[bsym.sym.id].symbol
self.bsyms_from_function(op, *bsym.args, **bsym.kwargs)
self.add_bsyms_from_function(op, *bsym.args, **bsym.kwargs)
return
elif isinstance(ex, FusionExecutor):
# NOTE execution_transform is None and the executor is a fusion executor
Expand Down
4 changes: 1 addition & 3 deletions thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,9 @@ def make_compiled(

region_trace.bound_symbols += list(bsyms)
region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None))
# for a in region_trace.args:
# region_trace.add_name(a.name)
for bsym in region_trace.bound_symbols:
for o in bsym.flat_outs:
if o is not None: # TODO: investigate
if o is not None:
region_trace.add_name(o.name)

# maybe make this the default if no sig info is present?
Expand Down

0 comments on commit 4f798aa

Please sign in to comment.