diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index 48a0084a7c..edff5937a9 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -207,6 +207,22 @@ def do_swap(v): class TraceSubstitutionProcessor: + """This processes a trace in an interpretation-style way by looping over the bound symbols. + This processing aims to preserve as much information on the proxies as possible. + + Args: + trace: trace to process + *args: arguments to process the trace with + **kwargs: keyword arguments to process the trace with + + The user is expected to subclass the trace and implement process_bsym with the help of add_unprocessed_bsyms (useful eg for using subsymbols to compute a symbol), add_processed_bsyms, and add_bsyms_from_function. + + Calling the instantiated object initiates the processing and returns + the new trace and a mapping of the outputs. + + See the OpExProcessor in thunder.executors.passes._transform_for_operator_executor_execution for an example of subclassing. + """ + NULL = object() def __init__(self, trace, *args, **kwargs): @@ -267,7 +283,7 @@ def do_swap(self, v): def add_unprocessed_bsyms(self, bsyms): self.unprocessed_bsyms[:0] = bsyms - def bsyms_from_function(self, fn, /, *args, **kwargs): + def add_bsyms_from_function(self, fn, /, *args, **kwargs): self.new_trace.push_scope([]) result = fn(*args, **kwargs) self.new_bsyms += self.new_trace.pop_scope() @@ -275,8 +291,6 @@ def bsyms_from_function(self, fn, /, *args, **kwargs): return result def add_processed_bsyms(self, bsyms): - - ### replacements of inputs! self.new_bsyms += bsyms def set_result(self, result): @@ -298,8 +312,6 @@ def process_args(self, *args, **kwargs): safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values())) def __call__(self): - # if not self.have_processed_args and self.trace.args is not None: - # self.process_args(*self.args, **self.kwargs) with tracectx(self.new_trace): self.unprocessed_bsyms = self.trace.bound_symbols[:] diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index dc8d053075..bfe4123dc6 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -111,7 +111,12 @@ def check(inp, log_str): check(copy_to_out, "output") -def remove_duplicate_number_proxies(bsyms): +def remove_duplicate_number_proxies(bsyms: Sequence[BoundSymbol]) -> list[BoundSymbol]: + """This removes duplicate number proxies when they are returned multiple times. + The remaining DCE pass does not see them (because they often are in a tuple?). + In particular, proxies may be extracted multiple times when using the thunder.jit's + symbolic constraints mode. + """ seen = set() def keep_or_swap(p): @@ -193,8 +198,10 @@ def _helper(x): dcetrace = from_trace(trace) dced_bound_symbols = list(reversed(dced)) - dced_bound_symbols = dced_bound_symbols - dcetrace.bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols) + # duplicate number proxies happen with the symbolic shapes and are + # not covered by the above (due to being in tuples?). + dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols) + dcetrace.bound_symbols = dced_bound_symbols end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 6060e0a50a..f7bbe20ca5 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -28,10 +28,6 @@ # Transforms a trace by determining which execution transforms to call given the list of executors in priority order # This pass tries to preserve the original trace and proxies. -# Implementation Steps - -# 1. The trace is updated with `visitor_transform` with `visit_helper_` (where executors try to claim the symbols). Note that this replaces the output proxies in the trace. -# 2. `visit_helper_` also creates a swapmap from the new symbols back to old one. -# 3. After the `visitor_transform`, it iterates over the updated trace and puts back the old proxies. def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: start_time_ns = time.perf_counter_ns() @@ -50,22 +46,21 @@ def update_swapmap(o: Any, no: Any) -> None: return swapmap[vno] = o - def preserve_bsym(bsym: BoundSymbol) -> Any: - trace = get_tracectx() - trace.scopes[-1].append(bsym) - for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): - trace.names.add(p.name) - return bsym.output - - # TODO Consider using an enum for this function's return values - # Tries to find an executor for the BoundSymbol - # If the BoundSymbol already has an executor then None is returned - # If the executor has an execution transform, it's called and True is returned - # If no executor can execute the BoundSymbol, False is returned - + # This processes the bsyms to map symbols to operator executors: + # - if a bsym has a python impl, that will be called, so we can keep it. + # - in the order of the executor list + # - if the executor defines an execution transform, call that to + # create symbols for the trace, + # - for operator executors, if we have an implmap entry for the symbol, + # execute that + # - for fusion executors, check if the symbol can be fused (done later) + # - if none of these apply, and the symbol is not a prim, replace the symbol + # with its subsymbols (which will then be processed using the above), + # - if none of the above apply and we have a prim, raise an error class OpExProcessor(TraceSubstitutionProcessor): def process_bsym(self, bsym): if bsym.sym.python_impl is not None: + # keep the bound symbol and use the python impl self.add_processed_bsyms([bsym]) self.set_result(bsym.output) return @@ -80,7 +75,7 @@ def process_bsym(self, bsym): execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) out: Any if execution_transform is not None: - self.bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs) + self.add_bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs) return elif isinstance(ex, OperatorExecutor): # NOTE execution_transform is None and the executor is an operator executor @@ -88,7 +83,7 @@ def process_bsym(self, bsym): # TODO Instead of directly acquiring the symbol from the implmap, we probably # want to hide this behind a function op = ex.implmap[bsym.sym.id].symbol - self.bsyms_from_function(op, *bsym.args, **bsym.kwargs) + self.add_bsyms_from_function(op, *bsym.args, **bsym.kwargs) return elif isinstance(ex, FusionExecutor): # NOTE execution_transform is None and the executor is a fusion executor diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index c688e134b6..345597f669 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -82,11 +82,9 @@ def make_compiled( region_trace.bound_symbols += list(bsyms) region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None)) - # for a in region_trace.args: - # region_trace.add_name(a.name) for bsym in region_trace.bound_symbols: for o in bsym.flat_outs: - if o is not None: # TODO: investigate + if o is not None: region_trace.add_name(o.name) # maybe make this the default if no sig info is present?