Basic scaffolding of a custom kernel DSL. (#67)
stellaraccident authored Dec 1, 2023
1 parent dd1c771 commit 65c82e8
Showing 8 changed files with 459 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/shark_turbine/kernel/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import gen
from . import lang


# Helpers that are good to have in the global scope.
def __getattr__(name):
if name == "DEBUG":
return lang.is_debug()
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


# Dynamic attributes so that IDEs see them.
DEBUG: bool
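
As a hedged illustration of the lazily resolved DEBUG attribute (the explicit EagerContext below is only for demonstration; normally a context is established by the launch machinery):

import shark_turbine.kernel as tk
from shark_turbine.kernel._support.tracing import EagerContext

# Reading tk.DEBUG goes through the module-level __getattr__ above, which
# requires an active execution context.
with EagerContext():
    print(tk.DEBUG)  # True: the active context is eager
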
189 changes: 189 additions & 0 deletions python/shark_turbine/kernel/_support/tracing.py
@@ -0,0 +1,189 @@
from abc import ABC, abstractmethod
from typing import Optional, TypeVar, Callable, assert_type, cast

import functools
import threading
import warnings

import torch.fx as fx

from ..lang.types import (
KernelBuffer,
Grid,
)

_tls = threading.local()
TCallable = TypeVar("TCallable", bound=Callable)

###############################################################################
# Wrapped tracing trampolines for proxy objects.
# These only get called during tracing of proxy objects.
###############################################################################


@fx.wrap
def _kernel_buffer_setitem(kernel_buffer: KernelBuffer, key, item) -> None:
...


###############################################################################
# Tracing machinery
###############################################################################


class KernelBufferProxy(fx.Proxy):
"""Custom proxy for KernelBuffer so that we can override special methods."""

def __setitem__(self, key, item):
_kernel_buffer_setitem(self, key, item)


class KernelTracer(fx.Tracer):
"""Custom Tracer for generating a trace of a kernel computation."""

def proxy(self, node: fx.Node) -> fx.Proxy:
if node.type == KernelBuffer:
return KernelBufferProxy(node, self)
return super().proxy(node)


class CapturedTrace:
def __init__(self, gm: fx.GraphModule):
self.gm = gm


###############################################################################
# Execution context.
# A valid BaseContext derived instance (EagerContext or CompiledContext) must
# be active for any evaluation of a generated/traced function.
###############################################################################


class BaseContext:
def __init__(self, *, eager: bool):
self.eager = eager

@staticmethod
def current() -> "BaseContext":
try:
return _tls.context[-1]
except (AttributeError, IndexError):
raise RuntimeError("No context is on the stack")

def __enter__(self) -> "BaseContext":
try:
stack = _tls.context
except AttributeError:
stack = []
_tls.context = stack
stack.append(self)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
_tls.context.pop()


class EagerContext(BaseContext):
def __init__(self, rank: int = 0):
super().__init__(eager=True)
self.rank = rank
self.current_thread: list[int] = rank * [0]


class CompiledContext(BaseContext):
def __init__(self, tracer: KernelTracer):
super().__init__(eager=False)
self.tracer = tracer


###############################################################################
# Launch context
# The launch context controls how a call into a kernel is dispatched,
# either running it eagerly for debugging or handing it off to some
# higher-order integration.
###############################################################################


class Launchable(ABC):
"""Base class for objects which behave like a kernel launch when called."""

def __init__(self, eager_function: Callable):
self._eager_function = eager_function

def __call__(self, *args, **kwargs):
launch_context = LaunchContext.current()
return launch_context.launch(self, args, kwargs)

@abstractmethod
def eager_execute(self, args, kwargs):
...


class LaunchContext(ABC):
@staticmethod
def current() -> "LaunchContext":
try:
return _tls.launch[-1]
except (AttributeError, IndexError):
warnings.warn(
"defaulting to debug/eager execution of tk kernel launch "
"because no launch context has been established"
)
return DebugLaunchContext()

def __enter__(self) -> "LaunchContext":
try:
stack = _tls.launch
except AttributeError:
stack = []
_tls.launch = stack
stack.append(self)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
_tls.launch.pop()

@abstractmethod
def launch(self, launchable: Launchable, args, kwargs):
...


class DebugLaunchContext(LaunchContext):
def launch(self, launchable: Launchable, args, kwargs):
return launchable.eager_execute(args, kwargs)


###############################################################################
# Helpers
###############################################################################


def eager_context() -> EagerContext:
context = BaseContext.current()
assert context.eager, "Expected to be executed against an EagerContext"
assert_type(context, EagerContext)
return context


def custom_primitive_fn(
f: Optional[TCallable] = None, *, compiled: Callable
) -> TCallable:
"""Decorator for a primitive function with a custom callback for tracing.
The wrapped function will be invoked as-is when executing eagerly. When
tracing, the `compiled` callback will be invoked with the same signature
but with the `CompiledContext` added as a first postional argument.
"""
if f is None:
return functools.partial(custom_primitive_fn, compiled=compiled)

@functools.wraps(f)
def wrapper(*args, **kwargs): # type: ignore
context = BaseContext.current()
if context.eager:
return f(*args, **kwargs)
else:
assert_type(context, CompiledContext)
return compiled(context, *args, **kwargs)

return cast(TCallable, wrapper)
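
A hedged sketch of how custom_primitive_fn splits eager and traced execution; the scale primitive and body function below are hypothetical and only mirror the pattern that program_id uses in lang/prims.py:

from shark_turbine.kernel._support.tracing import (
    CompiledContext,
    KernelTracer,
    custom_primitive_fn,
)

def _compiled_scale(context: CompiledContext, value, factor):
    # Traced path: hand-emit a call_function node through the tracer.
    return context.tracer.create_proxy("call_function", scale, (value, factor), {})

@custom_primitive_fn(compiled=_compiled_scale)
def scale(value, factor):
    # Eager path: runs as-is.
    return value * factor

def body(x):
    return scale(x, 2)

tracer = KernelTracer()
with CompiledContext(tracer):
    graph = tracer.trace(body)
# The graph now contains a call_function node targeting `scale`, produced by
# _compiled_scale rather than by running the eager body.
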
1 change: 1 addition & 0 deletions python/shark_turbine/kernel/gen/__init__.py
@@ -0,0 +1 @@
from .thread import *
84 changes: 84 additions & 0 deletions python/shark_turbine/kernel/gen/thread.py
@@ -0,0 +1,84 @@
from typing import Generic, Optional, TypeVar, Callable, Union, assert_type, cast

import functools
import math

import torch.fx as fx

from ..lang import (
KernelBuffer,
Grid,
)

from .._support.tracing import (
CapturedTrace,
CompiledContext,
EagerContext,
KernelTracer,
Launchable,
)

__all__ = [
"thread",
]

TCallable = TypeVar("TCallable", bound=Callable)


def thread(f: TCallable) -> TCallable:
    # Trace the kernel function at decoration time and capture the graph.
tracer = KernelTracer()
with CompiledContext(tracer) as context:
g = tracer.trace(f)
gm = fx.GraphModule(tracer.root, g, f.__name__)

return UnconfiguredThread[TCallable](f.__name__, f, CapturedTrace(gm))


class UnconfiguredThread(Generic[TCallable]):
def __init__(self, name: str, wrapped_f: TCallable, trace: CapturedTrace):
self._name = name
self._wrapped_f = wrapped_f
self._trace = trace

def __getitem__(self, grid: Union[int, Grid]) -> TCallable:
if isinstance(grid, int):
grid = (grid,)
assert isinstance(grid, tuple) and all(isinstance(i, int) for i in grid)
return cast(
TCallable, LaunchableThread(grid, self._name, self._wrapped_f, self._trace)
)

def __repr__(self):
return f"tk.gen.thread @{self._name}[no grid]"


class LaunchableThread(Launchable):
def __init__(
self, grid: Grid, name: str, eager_function: Callable, trace: CapturedTrace
):
super().__init__(eager_function)
self.grid = grid
self._name = name
self._trace = trace

def eager_execute(self, args, kwargs):
grid = self.grid
rank = len(grid)
with EagerContext(rank=rank) as context:
# Transform args to KernelBuffers.
buffer_args = [
arg if isinstance(arg, KernelBuffer) else KernelBuffer(arg)
for arg in args
]
volume = math.prod(grid)
current_thread = context.current_thread
            for it in range(volume):
                # Delinearize the flat index into per-axis coordinates
                # (row-major: the last grid axis varies fastest).
                linear = it
                for i in range(rank - 1):
                    block = math.prod(grid[i + 1 :])
                    current_thread[i] = linear // block
                    linear = linear % block
                current_thread[-1] = linear
                self._eager_function(*buffer_args, **kwargs)

def __repr__(self):
return f"tk.gen.thread @{self._name}[{', '.join(str(i) for i in self.grid)}]"
2 changes: 2 additions & 0 deletions python/shark_turbine/kernel/lang/__init__.py
@@ -0,0 +1,2 @@
from .prims import *
from .types import *
35 changes: 35 additions & 0 deletions python/shark_turbine/kernel/lang/prims.py
@@ -0,0 +1,35 @@
from typing import assert_type

from .._support.tracing import (
BaseContext,
CompiledContext,
custom_primitive_fn,
eager_context,
)

__all__ = [
"is_debug",
"program_id",
]


def is_debug() -> bool:
"""Returns whether we are currently executing a kernel in eager debug mode."""
return BaseContext.current().eager


def _compiled_program_id(context: CompiledContext, axis):
# Compiled. Note that tracing must be open coded on this
# function because it does not take a proxy as an argument
# (and therefore, the symbolic tracer exempts it from tracing
# according to its heuristic).
proxy = context.tracer.create_proxy("call_function", program_id, (axis,), {})
return proxy


@custom_primitive_fn(compiled=_compiled_program_id)
def program_id(axis: int) -> int:
"""Access the program id value for the given grid axis."""
context = eager_context()
assert axis >= 0 and axis < context.rank
return context.current_thread[axis]
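
A hedged sketch of program_id over a 2-D grid (illustrative only; it assumes the eager launch visits every (row, col) coordinate of the grid):

import torch
import shark_turbine.kernel as tk

@tk.gen.thread
def index_kernel(out: tk.lang.KernelBuffer):
    row = tk.lang.program_id(0)  # in [0, grid[0])
    col = tk.lang.program_id(1)  # in [0, grid[1])
    out[row, col] = row * 10 + col

# Eager/debug launch over a 2x3 grid.
index_kernel[(2, 3)](torch.zeros(2, 3))
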
39 changes: 39 additions & 0 deletions python/shark_turbine/kernel/lang/types.py
@@ -0,0 +1,39 @@
import torch

__all__ = [
"KernelBuffer",
"Grid",
]

Grid = tuple[int, ...]


class KernelBuffer:
"""Represents a buffer in global memory.
Top level kernels always operate on global memory via these
buffers, and the primary operations that can be performed on
them are loads/stores and DMAs to some form of compute
capable local buffer.
When executing eagerly, these are backed by a normal torch
Tensor. When compiling, an appropriate duck-typed proxy
is used.
"""

__slots__ = [
"_tensor",
]

def __init__(self, tensor: torch.Tensor):
assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
self._tensor = tensor

def __repr__(self):
return f"KernelBuffer({self._tensor})"

def __setitem__(self, key, item):
self._tensor.__setitem__(key, item)

def __getitem__(self, key):
return self._tensor.__getitem__(key)
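
A small sketch of eager KernelBuffer behavior: indexing delegates to the wrapped torch.Tensor.

import torch
from shark_turbine.kernel.lang import KernelBuffer

buf = KernelBuffer(torch.zeros(4))
buf[2] = 7.0
print(buf[2])  # tensor(7.)
print(buf)     # KernelBuffer(tensor([0., 0., 7., 0.]))
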