From 8517e57bce554ddaa4d3f6ce2884a5718bd60638 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 22 Mar 2024 14:42:32 -0700 Subject: [PATCH] only use executor lookasides when their executor is used --- thunder/__init__.py | 6 ++++++ thunder/common.py | 4 ++++ thunder/core/jit_ext.py | 26 +++++++++++++++++++++++--- thunder/extend/__init__.py | 11 ++--------- thunder/tests/test_extend.py | 7 +++++++ 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 5d6f698a76..21ba6499ec 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -349,6 +349,11 @@ def jit( # if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON and sharp_edges is None: # sharp_edges = SHARP_EDGES_OPTIONS.WARN + executor_lookasides = {} + for ex in executors or []: + # TODO: sharp edge if lookasides are shadowed? + executor_lookasides.update(ex._lookasides) + # TODO RC1 Refine the compile data option to remove unused options cd = CompileData( fn=fn, @@ -364,6 +369,7 @@ def jit( only_execute_prims=False, disable_preprocessing=True, compile_options=compile_options, + executor_lookasides=executor_lookasides, ) cs = CompileStats() diff --git a/thunder/common.py b/thunder/common.py index 67afb496a0..54dbe13c34 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -154,6 +154,7 @@ def __init__( debug_log: None | StringIO = None, compile_options: dict[str, Any] = {}, get_computation_and_inputs: Callable | None = None, + executor_lookasides: dict[Callable, Callable] | None = None, ): # Records whether we're using the thunder.jit() entrypoint or not # The thunder.jit() entrypoint introduces important architectural updates, @@ -164,6 +165,9 @@ def __init__( # runs prologues to get the compute/backward/epilogue function and inputs self.get_computation_and_inputs = get_computation_and_inputs + # lookasides provided by the executors + self.executor_lookasides = executor_lookasides + # Resolves cache option self.cache_option = resolve_cache_option(cache_option) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index e626a77f2a..897c1935a0 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -519,7 +519,13 @@ def get_postfix(pr: ProvenanceRecord): class GeneralJitCtx(MinimalCtx): def __init__( - self, prologue_trace, computation_trace, *, sharp_edges: SHARP_EDGES_OPTIONS, process_group_for_ddp=None + self, + prologue_trace, + computation_trace, + *, + sharp_edges: SHARP_EDGES_OPTIONS, + process_group_for_ddp=None, + executor_lookasides, ): super().__init__(sharp_edges=sharp_edges) @@ -529,6 +535,7 @@ def __init__( self._process_group_for_ddp = process_group_for_ddp self._additional_outputs = collections.defaultdict(list) self._proxy_swapmap: dict[Variable, Proxy] = {} + self._executor_lookasides: dict[Callable, Callable] = executor_lookasides @property def prologue_trace(self) -> TraceCtx: @@ -871,7 +878,12 @@ def proxy_recursion(v): def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable: # Identifies the lookaside lookaside: None | Callable - if isinstance(fn, Symbol) or fn in _clang_fn_set: + + ctx: GeneralJitCtx = get_general_jit_ctx() + + if (executor_lookaside := ctx._executor_lookasides.get(fn, None)) is not None: + lookaside = executor_lookaside + elif isinstance(fn, Symbol) or fn in _clang_fn_set: # Performs symbol lookasides # NOTE Symbols "lookaside" to themselves; this just prevents their internals from being jitted # NOTE clang operations are not symbols, but we still prevent their internals from being jitted @@ -1369,9 +1381,17 @@ def thunder_general_jit( si.varkwargs = ("kwargs", None) prologue_trace._siginfo = si + compile_data = get_compile_data() + process_group_for_ddp = getattr(fn, "process_group_for_ddp", None) + executor_lookasides = {k: interpreter_needs_wrap(v) for k, v in compile_data.executor_lookasides.items()} + ctx: GeneralJitCtx = GeneralJitCtx( - prologue_trace, computation_trace, sharp_edges=sharp_edges, process_group_for_ddp=process_group_for_ddp + prologue_trace, + computation_trace, + sharp_edges=sharp_edges, + process_group_for_ddp=process_group_for_ddp, + executor_lookasides=executor_lookasides, ) jfn = interpret( fn, diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 2acdcf6427..9a087eb6ee 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -25,7 +25,6 @@ "add_always_executor", "remove_default_executor", "remove_always_executor", - "register_lookaside", ] @@ -50,6 +49,7 @@ def __init__(self, name: Hashable, *, version: None | Any = None): self._version = version self._implmap: dict[Hashable, ImplInfo] = {} + self._lookasides: dict[Callable, Callable] = {} @property def name(self) -> Hashable: @@ -240,7 +240,7 @@ def _bind_postprocess(bsym: BoundSymbol) -> None: self.opmap[name] = sym if replaces is not None: - register_lookaside(replaces, sym) + self._lookasides[replaces] = sym return sym @@ -386,10 +386,3 @@ def deregister_executor(ex: Hashable | Executor) -> None: remove_always_executor(id) remove_default_executor(id) - - -def register_lookaside(function, symbol) -> None: - """register `symbol` as a lookaside for `function`""" - import thunder.core.jit_ext - - thunder.core.jit_ext._general_jit_lookaside_map[function] = thunder.core.jit_ext.interpreter_needs_wrap(symbol) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index b7dd300d12..2846e12519 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -184,4 +184,11 @@ def myadd_grad_trafo(a, b): a.requires_grad_() + # without the executor, we just (should and do) jit through official_add + cfn = thunder.jit(fn) + res = cfn(a, b) + + s = str(thunder.last_traces(cfn)[-1]) + assert "myadd2" not in s and "myadd1" not in s + deregister_executor(myex)