Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

only use executor lookasides when their executor is used #53

Merged
merged 2 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ def jit(
# if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON and sharp_edges is None:
# sharp_edges = SHARP_EDGES_OPTIONS.WARN

executor_lookasides = {}
for ex in executors or []:
# TODO: sharp edge if lookasides are shadowed?
executor_lookasides.update(ex._lookasides)

# TODO RC1 Refine the compile data option to remove unused options
cd = CompileData(
fn=fn,
Expand All @@ -364,6 +369,7 @@ def jit(
only_execute_prims=False,
disable_preprocessing=True,
compile_options=compile_options,
executor_lookasides=executor_lookasides,
)
cs = CompileStats()

Expand Down
4 changes: 4 additions & 0 deletions thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
debug_log: None | StringIO = None,
compile_options: dict[str, Any] = {},
get_computation_and_inputs: Callable | None = None,
executor_lookasides: dict[Callable, Callable] | None = None,
):
# Records whether we're using the thunder.jit() entrypoint or not
# The thunder.jit() entrypoint introduces important architectural updates,
Expand All @@ -164,6 +165,9 @@ def __init__(
# runs prologues to get the compute/backward/epilogue function and inputs
self.get_computation_and_inputs = get_computation_and_inputs

# lookasides provided by the executors
self.executor_lookasides = executor_lookasides

# Resolves cache option
self.cache_option = resolve_cache_option(cache_option)

Expand Down
25 changes: 22 additions & 3 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,13 @@ def get_postfix(pr: ProvenanceRecord):

class GeneralJitCtx(MinimalCtx):
def __init__(
self, prologue_trace, computation_trace, *, sharp_edges: SHARP_EDGES_OPTIONS, process_group_for_ddp=None
self,
prologue_trace,
computation_trace,
*,
sharp_edges: SHARP_EDGES_OPTIONS,
process_group_for_ddp=None,
executor_lookasides,
):
super().__init__(sharp_edges=sharp_edges)

Expand All @@ -530,6 +536,7 @@ def __init__(
self._process_group_for_ddp = process_group_for_ddp
self._additional_outputs = collections.defaultdict(list)
self._proxy_swapmap: dict[Variable, Proxy] = {}
self._executor_lookasides: dict[Callable, Callable] = executor_lookasides

@property
def prologue_trace(self) -> TraceCtx:
Expand Down Expand Up @@ -872,7 +879,12 @@ def proxy_recursion(v):
def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable:
# Identifies the lookaside
lookaside: None | Callable
if isinstance(fn, Symbol) or fn in _clang_fn_set:

ctx: GeneralJitCtx = get_general_jit_ctx()

if (executor_lookaside := ctx._executor_lookasides.get(fn, None)) is not None:
lookaside = executor_lookaside
elif isinstance(fn, Symbol) or fn in _clang_fn_set:
# Performs symbol lookasides
# NOTE Symbols "lookaside" to themselves; this just prevents their internals from being jitted
# NOTE clang operations are not symbols, but we still prevent their internals from being jitted
Expand Down Expand Up @@ -1384,9 +1396,16 @@ def thunder_general_jit(
si.varkwargs = ("kwargs", None)
prologue_trace._siginfo = si

compile_data = get_compile_data()
executor_lookasides = {k: interpreter_needs_wrap(v) for k, v in compile_data.executor_lookasides.items()}

process_group_for_ddp: Optional["ProcessGroup"] = _get_process_group_from(fn, *args, *kwargs.values())
ctx: GeneralJitCtx = GeneralJitCtx(
prologue_trace, computation_trace, sharp_edges=sharp_edges, process_group_for_ddp=process_group_for_ddp
prologue_trace,
computation_trace,
sharp_edges=sharp_edges,
process_group_for_ddp=process_group_for_ddp,
executor_lookasides=executor_lookasides,
)
jfn = interpret(
fn,
Expand Down
11 changes: 2 additions & 9 deletions thunder/extend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
"add_always_executor",
"remove_default_executor",
"remove_always_executor",
"register_lookaside",
]


Expand All @@ -50,6 +49,7 @@ def __init__(self, name: Hashable, *, version: None | Any = None):
self._version = version

self._implmap: dict[Hashable, ImplInfo] = {}
self._lookasides: dict[Callable, Callable] = {}

@property
def name(self) -> Hashable:
Expand Down Expand Up @@ -240,7 +240,7 @@ def _bind_postprocess(bsym: BoundSymbol) -> None:
self.opmap[name] = sym

if replaces is not None:
register_lookaside(replaces, sym)
self._lookasides[replaces] = sym

return sym

Expand Down Expand Up @@ -382,10 +382,3 @@ def deregister_executor(ex: Hashable | Executor) -> None:

remove_always_executor(id)
remove_default_executor(id)


def register_lookaside(function, symbol) -> None:
"""register `symbol` as a lookaside for `function`"""
import thunder.core.jit_ext

thunder.core.jit_ext._general_jit_lookaside_map[function] = thunder.core.jit_ext.interpreter_needs_wrap(symbol)
7 changes: 7 additions & 0 deletions thunder/tests/test_extend.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,11 @@ def myadd_grad_trafo(a, b):

a.requires_grad_()

# without the executor, we just (should and do) jit through official_add
cfn = thunder.jit(fn)
res = cfn(a, b)

s = str(thunder.last_traces(cfn)[-1])
assert "myadd2" not in s and "myadd1" not in s

deregister_executor(myex)
Loading