Basic scaffolding of a custom kernel DSL. (#67)
Commit 65c82e8 (1 parent: dd1c771)
Showing 8 changed files with 459 additions and 0 deletions.
@@ -0,0 +1,19 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import gen
from . import lang


# Helpers that are good to have in the global scope.
def __getattr__(name):
    if name == "DEBUG":
        return lang.is_debug()
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


# Dynamic attributes so that IDEs see them.
DEBUG: bool
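The `DEBUG` attribute above is resolved through PEP 562's module-level `__getattr__`, so it is computed on each access rather than frozen at import time, while the bare `DEBUG: bool` annotation keeps IDEs and type checkers aware of it. A minimal standalone sketch of the same pattern (the module and names here are illustrative, not part of this commit):

# demo_mod.py: hypothetical module showing the dynamic-attribute pattern.
_debug = True


def __getattr__(name):
    # Only called when normal module attribute lookup fails (PEP 562).
    if name == "DEBUG":
        return _debug
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


# Annotation only: it creates no module attribute, so lookups still hit
# __getattr__, but static tooling sees that DEBUG exists and is a bool.
DEBUG: bool

With this in place, `import demo_mod; demo_mod.DEBUG` returns whatever `_debug` holds at the moment of access.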
@@ -0,0 +1,189 @@
from abc import ABC, abstractmethod
from typing import Optional, TypeVar, Callable, assert_type, cast

import functools
import threading
import warnings

import torch.fx as fx

from ..lang.types import (
    KernelBuffer,
    Grid,
)

_tls = threading.local()
TCallable = TypeVar("TCallable", bound=Callable)

###############################################################################
# Wrapped tracing trampolines for proxy objects.
# These only get called during tracing of proxy objects.
###############################################################################


@fx.wrap
def _kernel_buffer_setitem(kernel_buffer: KernelBuffer, key, item) -> None:
    ...


###############################################################################
# Tracing machinery
###############################################################################


class KernelBufferProxy(fx.Proxy):
    """Custom proxy for KernelBuffer so that we can override special methods."""

    def __setitem__(self, key, item):
        _kernel_buffer_setitem(self, key, item)


class KernelTracer(fx.Tracer):
    """Custom Tracer for generating a trace of a kernel computation."""

    def proxy(self, node: fx.Node) -> fx.Proxy:
        if node.type == KernelBuffer:
            return KernelBufferProxy(node, self)
        return super().proxy(node)


class CapturedTrace:
    def __init__(self, gm: fx.GraphModule):
        self.gm = gm


###############################################################################
# Execution context.
# A valid BaseContext derived instance (EagerContext or CompiledContext) must
# be active for any evaluation of a generated/traced function.
###############################################################################


class BaseContext:
    def __init__(self, *, eager: bool):
        self.eager = eager

    @staticmethod
    def current() -> "BaseContext":
        try:
            return _tls.context[-1]
        except (AttributeError, IndexError):
            raise RuntimeError("No context is on the stack")

    def __enter__(self) -> "BaseContext":
        try:
            stack = _tls.context
        except AttributeError:
            stack = []
            _tls.context = stack
        stack.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _tls.context.pop()


class EagerContext(BaseContext):
    def __init__(self, rank: int = 0):
        super().__init__(eager=True)
        self.rank = rank
        self.current_thread: list[int] = rank * [0]


class CompiledContext(BaseContext):
    def __init__(self, tracer: KernelTracer):
        super().__init__(eager=False)
        self.tracer = tracer


###############################################################################
# Launch context
# The launch context controls how the call into a kernel is dispatched.
# This can either be to run it eagerly for debugging or some higher order
# integration.
###############################################################################


class Launchable(ABC):
    """Base class for objects which behave like a kernel launch when called."""

    def __init__(self, eager_function: Callable):
        self._eager_function = eager_function

    def __call__(self, *args, **kwargs):
        launch_context = LaunchContext.current()
        return launch_context.launch(self, args, kwargs)

    @abstractmethod
    def eager_execute(self, args, kwargs):
        ...


class LaunchContext(ABC):
    @staticmethod
    def current() -> "LaunchContext":
        try:
            return _tls.launch[-1]
        except (AttributeError, IndexError):
            warnings.warn(
                "defaulting to debug/eager execution of tk kernel launch "
                "because no launch context has been established"
            )
            return DebugLaunchContext()

    def __enter__(self) -> "LaunchContext":
        try:
            stack = _tls.launch
        except AttributeError:
            stack = []
            _tls.launch = stack
        stack.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _tls.launch.pop()

    @abstractmethod
    def launch(self, launchable: Launchable, args, kwargs):
        ...


class DebugLaunchContext(LaunchContext):
    def launch(self, launchable: Launchable, args, kwargs):
        return launchable.eager_execute(args, kwargs)


###############################################################################
# Helpers
###############################################################################


def eager_context() -> EagerContext:
    context = BaseContext.current()
    assert context.eager, "Expected to be executed against an EagerContext"
    assert_type(context, EagerContext)
    return context


def custom_primitive_fn(
    f: Optional[TCallable] = None, *, compiled: Callable
) -> TCallable:
    """Decorator for a primitive function with a custom callback for tracing.

    The wrapped function will be invoked as-is when executing eagerly. When
    tracing, the `compiled` callback will be invoked with the same signature
    but with the `CompiledContext` added as a first positional argument.
    """
    if f is None:
        return functools.partial(custom_primitive_fn, compiled=compiled)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):  # type: ignore
        context = BaseContext.current()
        if context.eager:
            return f(*args, **kwargs)
        else:
            assert_type(context, CompiledContext)
            return compiled(context, *args, **kwargs)

    return cast(TCallable, wrapper)
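A sketch of how a new primitive would be defined with `custom_primitive_fn`, mirroring the `program_id` definition later in this diff; the import path `tracing` and the `scale` primitive are assumptions for illustration:

# Hypothetical import path; the module's real package path is not shown here.
from tracing import CompiledContext, custom_primitive_fn


def _compiled_scale(context: CompiledContext, x, factor):
    # Open-coded tracing: emit a call_function node for the primitive
    # directly into the graph being captured by the active tracer.
    return context.tracer.create_proxy("call_function", scale, (x, factor), {})


@custom_primitive_fn(compiled=_compiled_scale)
def scale(x, factor):
    # Eager path: executes as plain Python under an EagerContext.
    return x * factor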
@@ -0,0 +1 @@
from .thread import *
@@ -0,0 +1,84 @@
from typing import Generic, Optional, TypeVar, Callable, Union, assert_type, cast

import functools
import math

import torch.fx as fx

from ..lang import (
    KernelBuffer,
    Grid,
)

from .._support.tracing import (
    CapturedTrace,
    CompiledContext,
    EagerContext,
    KernelTracer,
    Launchable,
)

__all__ = [
    "thread",
]

TCallable = TypeVar("TCallable", bound=Callable)


def thread(f: TCallable) -> TCallable:
    # Eagerly capture the trace and attach it to the wrapped function.
    tracer = KernelTracer()
    with CompiledContext(tracer) as context:
        g = tracer.trace(f)
        gm = fx.GraphModule(tracer.root, g, f.__name__)

    return UnconfiguredThread[TCallable](f.__name__, f, CapturedTrace(gm))


class UnconfiguredThread(Generic[TCallable]):
    def __init__(self, name: str, wrapped_f: TCallable, trace: CapturedTrace):
        self._name = name
        self._wrapped_f = wrapped_f
        self._trace = trace

    def __getitem__(self, grid: Union[int, Grid]) -> TCallable:
        if isinstance(grid, int):
            grid = (grid,)
        assert isinstance(grid, tuple) and all(isinstance(i, int) for i in grid)
        return cast(
            TCallable, LaunchableThread(grid, self._name, self._wrapped_f, self._trace)
        )

    def __repr__(self):
        return f"tk.gen.thread @{self._name}[no grid]"


class LaunchableThread(Launchable):
    def __init__(
        self, grid: Grid, name: str, eager_function: Callable, trace: CapturedTrace
    ):
        super().__init__(eager_function)
        self.grid = grid
        self._name = name
        self._trace = trace

    def eager_execute(self, args, kwargs):
        grid = self.grid
        rank = len(grid)
        with EagerContext(rank=rank) as context:
            # Transform args to KernelBuffers.
            buffer_args = [
                arg if isinstance(arg, KernelBuffer) else KernelBuffer(arg)
                for arg in args
            ]
            volume = math.prod(grid)
            current_thread = context.current_thread
            for it in range(volume):
                # Decompose the linear index row-major: each leading axis
                # strides over the product of the trailing axes.
                remaining = it
                for i in range(rank - 1):
                    stride = math.prod(grid[i + 1 :])
                    current_thread[i] = remaining // stride
                    remaining = remaining % stride
                current_thread[-1] = remaining
                self._eager_function(*buffer_args, **kwargs)

    def __repr__(self):
        return f"tk.gen.thread @{self._name}[{', '.join(str(i) for i in self.grid)}]"
@@ -0,0 +1,2 @@
from .prims import *
from .types import *
@@ -0,0 +1,35 @@
from typing import assert_type

from .._support.tracing import (
    BaseContext,
    CompiledContext,
    custom_primitive_fn,
    eager_context,
)

__all__ = [
    "is_debug",
    "program_id",
]


def is_debug() -> bool:
    """Returns whether we are currently executing a kernel in eager debug mode."""
    return BaseContext.current().eager


def _compiled_program_id(context: CompiledContext, axis):
    # Compiled. Note that tracing must be open coded on this
    # function because it does not take a proxy as an argument
    # (and therefore, the symbolic tracer exempts it from tracing
    # according to its heuristic).
    proxy = context.tracer.create_proxy("call_function", program_id, (axis,), {})
    return proxy


@custom_primitive_fn(compiled=_compiled_program_id)
def program_id(axis: int) -> int:
    """Access the program id value for the given grid axis."""
    context = eager_context()
    assert 0 <= axis < context.rank
    return context.current_thread[axis]
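`is_debug()` simply reports which kind of context is active, for example (a sketch assuming the tracing module and these prims are importable as `tracing` and `prims`):

from prims import is_debug  # assumed import paths, for illustration only
from tracing import CompiledContext, EagerContext, KernelTracer

with EagerContext(rank=1):
    assert is_debug()  # EagerContext sets eager=True

with CompiledContext(KernelTracer()):
    assert not is_debug()  # compiled/tracing contexts are not eager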
@@ -0,0 +1,39 @@
import torch

__all__ = [
    "KernelBuffer",
    "Grid",
]

Grid = tuple[int, ...]


class KernelBuffer:
    """Represents a buffer in global memory.

    Top level kernels always operate on global memory via these
    buffers, and the primary operations that can be performed on
    them are loads/stores and DMAs to some form of compute
    capable local buffer.

    When executing eagerly, these are backed by a normal torch
    Tensor. When compiling, an appropriate duck-typed proxy
    is used.
    """

    __slots__ = [
        "_tensor",
    ]

    def __init__(self, tensor: torch.Tensor):
        assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
        self._tensor = tensor

    def __repr__(self):
        return f"KernelBuffer({self._tensor})"

    def __setitem__(self, key, item):
        self._tensor.__setitem__(key, item)

    def __getitem__(self, key):
        return self._tensor.__getitem__(key)
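In eager mode, `KernelBuffer` is a thin indexing wrapper over the tensor it was constructed with, for example:

import torch

buf = KernelBuffer(torch.zeros(2, 3))
buf[0, 1] = 5.0   # delegates to Tensor.__setitem__
print(buf[0, 1])  # tensor(5.)
print(buf)        # KernelBuffer(tensor([[0., 5., 0.], [0., 0., 0.]]))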