From 65c82e8985ad6dab7b8a60fceedeb530a9032ddc Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Thu, 30 Nov 2023 19:32:24 -0800
Subject: [PATCH] Basic scaffolding of a custom kernel DSL. (#67)

---
 python/shark_turbine/kernel/__init__.py      |  19 ++
 .../shark_turbine/kernel/_support/tracing.py | 189 ++++++++++++++++++
 python/shark_turbine/kernel/gen/__init__.py  |   1 +
 python/shark_turbine/kernel/gen/thread.py    |  84 ++++++++
 python/shark_turbine/kernel/lang/__init__.py |   2 +
 python/shark_turbine/kernel/lang/prims.py    |  35 ++++
 python/shark_turbine/kernel/lang/types.py    |  39 ++++
 tests/kernel/simple_kernel_test.py           |  90 +++++++++
 8 files changed, 459 insertions(+)
 create mode 100644 python/shark_turbine/kernel/__init__.py
 create mode 100644 python/shark_turbine/kernel/_support/tracing.py
 create mode 100644 python/shark_turbine/kernel/gen/__init__.py
 create mode 100644 python/shark_turbine/kernel/gen/thread.py
 create mode 100644 python/shark_turbine/kernel/lang/__init__.py
 create mode 100644 python/shark_turbine/kernel/lang/prims.py
 create mode 100644 python/shark_turbine/kernel/lang/types.py
 create mode 100644 tests/kernel/simple_kernel_test.py

diff --git a/python/shark_turbine/kernel/__init__.py b/python/shark_turbine/kernel/__init__.py
new file mode 100644
index 000000000..333dce24f
--- /dev/null
+++ b/python/shark_turbine/kernel/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2023 Nod Labs, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from . import gen
+from . import lang
+
+
+# Helpers that are good to have in the global scope.
+def __getattr__(name):
+    if name == "DEBUG":
+        return lang.is_debug()
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+
+
+# Declare the dynamic attributes so that IDEs see them.
+DEBUG: bool
diff --git a/python/shark_turbine/kernel/_support/tracing.py b/python/shark_turbine/kernel/_support/tracing.py
new file mode 100644
index 000000000..8c77855da
--- /dev/null
+++ b/python/shark_turbine/kernel/_support/tracing.py
@@ -0,0 +1,189 @@
+from abc import ABC, abstractmethod
+from typing import Optional, TypeVar, Callable, cast
+
+import functools
+import threading
+import warnings
+
+import torch.fx as fx
+
+from ..lang.types import (
+    KernelBuffer,
+    Grid,
+)
+
+_tls = threading.local()
+TCallable = TypeVar("TCallable", bound=Callable)
+
+###############################################################################
+# Wrapped tracing trampolines for proxy objects.
+# These only get called during tracing, when the arguments are proxies.
+###############################################################################
+
+
+@fx.wrap
+def _kernel_buffer_setitem(kernel_buffer: KernelBuffer, key, item) -> None:
+    ...
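+
+
+# Note: `fx.wrap` registers `_kernel_buffer_setitem` as a leaf function, so
+# during tracing a call to it is recorded as a `call_function` node in the
+# graph rather than being traced through. This is how subscript stores on a
+# KernelBuffer are captured (see `KernelBufferProxy` below).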
+
+
+###############################################################################
+# Tracing machinery
+###############################################################################
+
+
+class KernelBufferProxy(fx.Proxy):
+    """Custom proxy for KernelBuffer so that we can override special methods."""
+
+    def __setitem__(self, key, item):
+        _kernel_buffer_setitem(self, key, item)
+
+
+class KernelTracer(fx.Tracer):
+    """Custom Tracer for generating a trace of a kernel computation."""
+
+    def proxy(self, node: fx.Node) -> fx.Proxy:
+        if node.type == KernelBuffer:
+            return KernelBufferProxy(node, self)
+        return super().proxy(node)
+
+
+class CapturedTrace:
+    def __init__(self, gm: fx.GraphModule):
+        self.gm = gm
+
+
+###############################################################################
+# Execution context.
+# A valid BaseContext-derived instance (EagerContext or CompiledContext) must
+# be active for any evaluation of a generated/traced function.
+###############################################################################
+
+
+class BaseContext:
+    def __init__(self, *, eager: bool):
+        self.eager = eager
+
+    @staticmethod
+    def current() -> "BaseContext":
+        try:
+            return _tls.context[-1]
+        except (AttributeError, IndexError):
+            raise RuntimeError("No context is on the stack")
+
+    def __enter__(self) -> "BaseContext":
+        try:
+            stack = _tls.context
+        except AttributeError:
+            stack = []
+            _tls.context = stack
+        stack.append(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        _tls.context.pop()
+
+
+class EagerContext(BaseContext):
+    def __init__(self, rank: int = 0):
+        super().__init__(eager=True)
+        self.rank = rank
+        self.current_thread: list[int] = [0] * rank
+
+
+class CompiledContext(BaseContext):
+    def __init__(self, tracer: KernelTracer):
+        super().__init__(eager=False)
+        self.tracer = tracer
+
+
+###############################################################################
+# Launch context
+# The launch context controls how a call into a kernel is dispatched. This
+# can mean running it eagerly for debugging or deferring to some higher-order
+# integration.
+###############################################################################
+
+
+class Launchable(ABC):
+    """Base class for objects which behave like a kernel launch when called."""
+
+    def __init__(self, eager_function: Callable):
+        self._eager_function = eager_function
+
+    def __call__(self, *args, **kwargs):
+        launch_context = LaunchContext.current()
+        return launch_context.launch(self, args, kwargs)
+
+    @abstractmethod
+    def eager_execute(self, args, kwargs):
+        ...
+
+
+class LaunchContext(ABC):
+    @staticmethod
+    def current() -> "LaunchContext":
+        try:
+            return _tls.launch[-1]
+        except (AttributeError, IndexError):
+            warnings.warn(
+                "defaulting to debug/eager execution of tk kernel launch "
+                "because no launch context has been established"
+            )
+            return DebugLaunchContext()
+
+    def __enter__(self) -> "LaunchContext":
+        try:
+            stack = _tls.launch
+        except AttributeError:
+            stack = []
+            _tls.launch = stack
+        stack.append(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        _tls.launch.pop()
+
+    @abstractmethod
+    def launch(self, launchable: Launchable, args, kwargs):
+        ...
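+
+
+# Illustrative usage (`my_kernel` is a hypothetical kernel): a LaunchContext
+# is entered as a context manager around kernel calls:
+#
+#   with DebugLaunchContext():
+#       my_kernel[grid](out)
+#
+# When no launch context is active, LaunchContext.current() falls back to
+# DebugLaunchContext with a warning, as above.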
+
+
+class DebugLaunchContext(LaunchContext):
+    def launch(self, launchable: Launchable, args, kwargs):
+        return launchable.eager_execute(args, kwargs)
+
+
+###############################################################################
+# Helpers
+###############################################################################
+
+
+def eager_context() -> EagerContext:
+    context = BaseContext.current()
+    assert isinstance(
+        context, EagerContext
+    ), "Expected to be executed against an EagerContext"
+    return context
+
+
+def custom_primitive_fn(
+    f: Optional[TCallable] = None, *, compiled: Callable
+) -> TCallable:
+    """Decorator for a primitive function with a custom callback for tracing.
+
+    The wrapped function will be invoked as-is when executing eagerly. When
+    tracing, the `compiled` callback will be invoked with the same signature
+    but with the `CompiledContext` added as a first positional argument.
+    """
+    if f is None:
+        return functools.partial(custom_primitive_fn, compiled=compiled)
+
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):  # type: ignore
+        context = BaseContext.current()
+        if context.eager:
+            return f(*args, **kwargs)
+        else:
+            assert isinstance(context, CompiledContext)
+            return compiled(context, *args, **kwargs)
+
+    return cast(TCallable, wrapper)
diff --git a/python/shark_turbine/kernel/gen/__init__.py b/python/shark_turbine/kernel/gen/__init__.py
new file mode 100644
index 000000000..0eb9037c8
--- /dev/null
+++ b/python/shark_turbine/kernel/gen/__init__.py
@@ -0,0 +1 @@
+from .thread import *
diff --git a/python/shark_turbine/kernel/gen/thread.py b/python/shark_turbine/kernel/gen/thread.py
new file mode 100644
index 000000000..29a9fe16d
--- /dev/null
+++ b/python/shark_turbine/kernel/gen/thread.py
@@ -0,0 +1,84 @@
+from typing import Generic, TypeVar, Callable, Union, cast
+
+import math
+
+import torch.fx as fx
+
+from ..lang import (
+    KernelBuffer,
+    Grid,
+)
+
+from .._support.tracing import (
+    CapturedTrace,
+    CompiledContext,
+    EagerContext,
+    KernelTracer,
+    Launchable,
+)
+
+__all__ = [
+    "thread",
+]
+
+TCallable = TypeVar("TCallable", bound=Callable)
+
+
+def thread(f: TCallable) -> TCallable:
+    # Eagerly capture the trace and attach it to the wrapped function.
+    tracer = KernelTracer()
+    with CompiledContext(tracer):
+        g = tracer.trace(f)
+        gm = fx.GraphModule(tracer.root, g, f.__name__)
+
+    return UnconfiguredThread[TCallable](f.__name__, f, CapturedTrace(gm))
+
+
+class UnconfiguredThread(Generic[TCallable]):
+    def __init__(self, name: str, wrapped_f: TCallable, trace: CapturedTrace):
+        self._name = name
+        self._wrapped_f = wrapped_f
+        self._trace = trace
+
+    def __getitem__(self, grid: Union[int, Grid]) -> TCallable:
+        if isinstance(grid, int):
+            grid = (grid,)
+        assert isinstance(grid, tuple) and all(isinstance(i, int) for i in grid)
+        return cast(
+            TCallable, LaunchableThread(grid, self._name, self._wrapped_f, self._trace)
+        )
+
+    def __repr__(self):
+        return f"tk.gen.thread @{self._name}[no grid]"
+
+
+class LaunchableThread(Launchable):
+    def __init__(
+        self, grid: Grid, name: str, eager_function: Callable, trace: CapturedTrace
+    ):
+        super().__init__(eager_function)
+        self.grid = grid
+        self._name = name
+        self._trace = trace
+
+    def eager_execute(self, args, kwargs):
+        grid = self.grid
+        rank = len(grid)
+        with EagerContext(rank=rank) as context:
+            # Transform args to KernelBuffers.
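+            # (Raw Tensors are wrapped so that the kernel body sees the same
+            # KernelBuffer interface eagerly as it does under tracing.)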
+            buffer_args = [
+                arg if isinstance(arg, KernelBuffer) else KernelBuffer(arg)
+                for arg in args
+            ]
+            volume = math.prod(grid)
+            current_thread = context.current_thread
+            for it in range(volume):
+                # Decompose the linear index into per-axis coordinates
+                # (row-major: axis 0 varies slowest), dividing by the volume
+                # of the trailing axes so each coordinate stays within its
+                # axis bound.
+                remaining = it
+                for i in range(rank - 1):
+                    stride = math.prod(grid[i + 1 :])
+                    current_thread[i] = remaining // stride
+                    remaining = remaining % stride
+                current_thread[-1] = remaining
+                self._eager_function(*buffer_args, **kwargs)
+
+    def __repr__(self):
+        return f"tk.gen.thread @{self._name}[{', '.join(str(i) for i in self.grid)}]"
diff --git a/python/shark_turbine/kernel/lang/__init__.py b/python/shark_turbine/kernel/lang/__init__.py
new file mode 100644
index 000000000..332381984
--- /dev/null
+++ b/python/shark_turbine/kernel/lang/__init__.py
@@ -0,0 +1,2 @@
+from .prims import *
+from .types import *
diff --git a/python/shark_turbine/kernel/lang/prims.py b/python/shark_turbine/kernel/lang/prims.py
new file mode 100644
index 000000000..54ed9f2fb
--- /dev/null
+++ b/python/shark_turbine/kernel/lang/prims.py
@@ -0,0 +1,35 @@
+from .._support.tracing import (
+    BaseContext,
+    CompiledContext,
+    custom_primitive_fn,
+    eager_context,
+)
+
+__all__ = [
+    "is_debug",
+    "program_id",
+]
+
+
+def is_debug() -> bool:
+    """Returns whether we are currently executing a kernel in eager debug mode."""
+    return BaseContext.current().eager
+
+
+def _compiled_program_id(context: CompiledContext, axis):
+    # Compiled path. Tracing must be open-coded in this function because it
+    # does not take a proxy as an argument (the symbolic tracer's heuristic
+    # therefore exempts it from tracing).
+    proxy = context.tracer.create_proxy("call_function", program_id, (axis,), {})
+    return proxy
+
+
+@custom_primitive_fn(compiled=_compiled_program_id)
+def program_id(axis: int) -> int:
+    """Access the program id value for the given grid axis."""
+    context = eager_context()
+    assert 0 <= axis < context.rank
+    return context.current_thread[axis]
diff --git a/python/shark_turbine/kernel/lang/types.py b/python/shark_turbine/kernel/lang/types.py
new file mode 100644
index 000000000..138c149c7
--- /dev/null
+++ b/python/shark_turbine/kernel/lang/types.py
@@ -0,0 +1,39 @@
+import torch
+
+__all__ = [
+    "KernelBuffer",
+    "Grid",
+]
+
+Grid = tuple[int, ...]
+
+
+class KernelBuffer:
+    """Represents a buffer in global memory.
+
+    Top level kernels always operate on global memory via these
+    buffers, and the primary operations that can be performed on
+    them are loads/stores and DMAs to some form of compute-capable
+    local buffer.
+
+    When executing eagerly, these are backed by a normal torch
+    Tensor. When compiling, an appropriate duck-typed proxy
+    is used.
+    """
+
+    __slots__ = [
+        "_tensor",
+    ]
+
+    def __init__(self, tensor: torch.Tensor):
+        assert isinstance(tensor, torch.Tensor), f"Expected Tensor but got {tensor}"
+        self._tensor = tensor
+
+    def __repr__(self):
+        return f"KernelBuffer({self._tensor})"
+
+    def __setitem__(self, key, item):
+        self._tensor[key] = item
+
+    def __getitem__(self, key):
+        return self._tensor[key]
diff --git a/tests/kernel/simple_kernel_test.py b/tests/kernel/simple_kernel_test.py
new file mode 100644
index 000000000..0bbb46a3e
--- /dev/null
+++ b/tests/kernel/simple_kernel_test.py
@@ -0,0 +1,90 @@
+# Copyright 2023 Nod Labs, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import unittest
+
+import torch
+
+import shark_turbine.kernel as tk
+
+
+class Test(unittest.TestCase):
+    def testIotaEager(self):
+        @tk.gen.thread
+        def iota_kernel(out: tk.lang.KernelBuffer):
+            i = tk.lang.program_id(0)
+            out[i] = i
+
+        print("iota_kernel:", iota_kernel)
+        print("iota_kernel[8]:", iota_kernel[8])
+        print("iota_kernel[8, 1]:", iota_kernel[8, 1])
+        out = torch.empty(8, dtype=torch.int32)
+        iota_kernel[8](out)
+        print(out)
+
+    def testIotaFx(self):
+        @tk.gen.thread
+        def iota_kernel(out: tk.lang.KernelBuffer):
+            i = tk.lang.program_id(0)
+            out[i] = i
+
+        print(iota_kernel._trace.gm.graph)
+        # Prints:
+        # graph():
+        # %out : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=out]
+        # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {})
+        # %_kernel_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%out, %program_id, %program_id), kwargs = {})
+        # return None
+
+    def testSoftmax(self):
+        @tk.gen.thread
+        def softmax_kernel(input: tk.lang.KernelBuffer, output: tk.lang.KernelBuffer):
+            row_index = tk.lang.program_id(0)
+            input_row = input[row_index, :]
+            numerator = torch.exp(input_row - torch.max(input_row))
+            output_row = numerator / torch.sum(numerator)
+            output[row_index, :] = output_row
+            # Print some debugging info if in debug mode and processing the first row.
+            if tk.DEBUG and row_index == 0:
+                print(f"*** Input: {input}")
+                print(f"*** Output: {output}")
+                print(
+                    f"*** Input Row[{row_index}]: {type(input_row).__name__}({input_row.shape})"
+                )
+                print(
+                    f"*** Output Row: {type(output_row).__name__}({output_row.shape})"
+                )
+
+        def softmax(x):
+            y = torch.empty_like(x)
+            softmax_kernel[x.shape[0]](x, y)
+            return y
+
+        input = torch.rand((128, 64))
+        generated = softmax(input)
+        actual = torch.softmax(input, -1)
+        torch.testing.assert_close(generated, actual)
+        print(softmax_kernel._trace.gm.graph)
+        # Prints:
+        # graph():
+        # %input_1 : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=input]
+        # %output : shark_turbine.kernel.lang.types.KernelBuffer [num_users=1] = placeholder[target=output]
+        # %program_id : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {})
+        # %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%input_1, (%program_id, slice(None, None, None))), kwargs = {})
+        # %max_1 : [num_users=1] = call_function[target=torch.max](args = (%getitem,), kwargs = {})
+        # %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem, %max_1), kwargs = {})
+        # %exp : [num_users=2] = call_function[target=torch.exp](args = (%sub,), kwargs = {})
+        # %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%exp,), kwargs = {})
+        # %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%exp, %sum_1), kwargs = {})
+        # %program_id_1 : [num_users=1] = call_function[target=shark_turbine.kernel.lang.prims.program_id](args = (0,), kwargs = {})
+        # %_kernel_buffer_setitem : [num_users=0] = call_function[target=shark_turbine.kernel._support.tracing._kernel_buffer_setitem](args = (%output, (%program_id_1, slice(None, None, None)), %truediv), kwargs = {})
+        # return None
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()