Skip to content

Commit

Permalink
[tk] Implement basic vector and scalar code generation (nod-ai#220)
Browse files Browse the repository at this point in the history
* Python value/type propagation
* Loads/stores between KernelBuffer and vectors
* Extended Python integer types
* Python scalar operations
  • Loading branch information
stellaraccident authored Dec 5, 2023
1 parent fac744b commit 0c658bd
Show file tree
Hide file tree
Showing 13 changed files with 610 additions and 79 deletions.
53 changes: 53 additions & 0 deletions python/shark_turbine/kernel/_support/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Optional, Type, TypeVar

import threading

_tls = threading.local()

T = TypeVar("T")


def push(context_type: Type[T], instance: T) -> T:
"""Pushes an instance onto a thread-local context stack.
The context type must define an attribute __tk_context_idname__ which is
a valid/unique identifier.
"""
assert isinstance(instance, context_type)
key = context_type.__tk_context_idname__
try:
stack: list = getattr(_tls, key)
except AttributeError:
stack = []
setattr(_tls, key, stack)
stack.append(instance)
return instance


def pop(context_type: Type[T], expected: Optional[T] = None):
"""Pops the current context off of the stack.
Raises IndexError if no current.
"""
stack: list = getattr(_tls, context_type.__tk_context_idname__)
instance = stack.pop()
assert (
expected is None or expected is instance
), f"mismatched context push/pop for {context_type}"


def current(context_type: Type[T]) -> T:
"""Returns the current context from the stack.
Raises IndexError on failure.
"""
try:
stack: list = getattr(_tls, context_type.__tk_context_idname__)
except AttributeError:
raise IndexError(f"No current context for {context_type}")
try:
instance = stack[-1]
except IndexError:
raise IndexError(f"No current context for {context_type}")
assert isinstance(instance, context_type)
return instance
28 changes: 11 additions & 17 deletions python/shark_turbine/kernel/_support/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from abc import ABC, abstractmethod
from enum import Enum
import threading

import torch

from .. import ops

from . import context

__all__ = [
"KernelBuffer",
"Grid",
Expand All @@ -16,8 +19,6 @@
"sym",
]

_tls = threading.local()


class NotSetType:
...
Expand Down Expand Up @@ -329,10 +330,10 @@ def __repr__(self):
return f"{type(self)}({self._tensor})"

def __setitem__(self, key, item):
self._tensor.__setitem__(key, item)
ops.kernel_buffer_setitem(self, key, item)

def __getitem__(self, key):
return self._tensor.__getitem__(key)
return ops.kernel_buffer_getitem(self, key)


class InputBuffer(KernelBuffer):
Expand All @@ -357,6 +358,8 @@ class IndexingContext:
symbols to concrete values.
"""

__tk_context_idname__ = "IndexingContext"

def __init__(self):
self.constant_bindings: dict[SymbolDef, int] = {}

Expand All @@ -375,19 +378,10 @@ def get_static_value(self, sym: SymbolDef) -> Optional[int]:
##### Context management.
@staticmethod
def current() -> "IndexingContext":
try:
return _tls.indexing_stack[-1]
except (AttributeError, IndexError):
raise AssertionError("no IndexingContext is active")
return context.current(IndexingContext)

def __enter__(self) -> "IndexingContext":
try:
stack = _tls.indexing_stack
except AttributeError:
stack = []
_tls.indexing_stack = stack
stack.append(self)
return self
return context.push(IndexingContext, self)

def __exit__(self, exc_type, exc_val, exc_tb):
_tls.indexing_stack.pop()
context.pop(IndexingContext, self)
94 changes: 59 additions & 35 deletions python/shark_turbine/kernel/_support/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Optional, TypeVar, Callable, Type, assert_type, cast

import functools
import threading
import warnings

import torch.fx as fx
Expand All @@ -11,19 +10,18 @@
KernelBuffer,
)

_tls = threading.local()
TCallable = TypeVar("TCallable", bound=Callable)

###############################################################################
# Wrapped tracing trampolines for proxy objects.
# These only get called during tracing of proxy objects.
###############################################################################
from ..lang.types import (
Index,
)

from .. import ops
from ..ops.base import (
OpDispatcher,
)

@fx.wrap
def _kernel_buffer_setitem(kernel_buffer: KernelBuffer, key, item) -> None:
...
from . import context

TCallable = TypeVar("TCallable", bound=Callable)

###############################################################################
# Tracing machinery
Expand All @@ -42,8 +40,11 @@ def __init__(
self.symbolic_shape = orig_type.symbolic_shape
self.rank = orig_type.rank

def __getitem__(self, key):
return ops.kernel_buffer_getitem(self, key)

def __setitem__(self, key, item):
_kernel_buffer_setitem(self, key, item)
ops.kernel_buffer_setitem(self, key, item)


class KernelTracer(fx.Tracer):
Expand All @@ -68,28 +69,23 @@ def __init__(self, gm: fx.GraphModule):
###############################################################################


class BaseContext:
class BaseContext(OpDispatcher):
__tk_context_idname__ = "ExecutionContext"

def __init__(self, *, eager: bool):
self.eager = eager

@staticmethod
def current() -> "BaseContext":
try:
return _tls.context[-1]
except (AttributeError, IndexError):
raise RuntimeError("No context is on the stack")
return context.current(BaseContext)

def __enter__(self) -> "BaseContext":
try:
stack = _tls.context
except AttributeError:
stack = []
_tls.context = stack
stack.append(self)
return self
context.push(OpDispatcher, self)
return context.push(BaseContext, self)

def __exit__(self, exc_type, exc_val, exc_tb):
_tls.context.pop()
context.pop(OpDispatcher, self)
context.pop(BaseContext, self)


class EagerContext(BaseContext):
Expand All @@ -98,12 +94,44 @@ def __init__(self, rank: int = 0):
self.rank = rank
self.current_thread: list[int] = rank * [0]

def handle_thread_program_id(self, op, axis: int) -> int:
assert axis >= 0 and axis < self.rank
return Index(self.current_thread[axis])

def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
return kernel_buffer._tensor.__getitem__(key)

def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
kernel_buffer._tensor.__setitem__(key, item)


class CompiledContext(BaseContext):
def __init__(self, tracer: KernelTracer):
super().__init__(eager=False)
self.tracer = tracer

def handle_thread_program_id(self, op, axis: int) -> Index:
proxy = self.tracer.create_proxy(
"call_function", op, args=(axis,), kwargs={}, type_expr=Index
)
return proxy

def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
return self.tracer.create_proxy(
"call_function",
op,
args=(kernel_buffer, key),
kwargs={},
)

def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
self.tracer.create_proxy(
"call_function",
target=op,
args=(kernel_buffer, key, item),
kwargs={},
)


###############################################################################
# Launch context
Expand All @@ -129,28 +157,24 @@ def eager_execute(self, args, kwargs):


class LaunchContext(ABC):
__tk_context_idname__ = "ExecutionContext"

@staticmethod
def current() -> "LaunchContext":
try:
return _tls.launch[-1]
except (AttributeError, IndexError):
return context.current(LaunchContext)
except IndexError:
warnings.warn(
"defaulting to debug/eager execution of tk kernel launch "
"because no launch context has been established"
)
return DebugLaunchContext()

def __enter__(self) -> "LaunchContext":
try:
stack = _tls.launch
except AttributeError:
stack = []
_tls.launch = stack
stack.append(self)
return self
return context.push(LaunchContext, self)

def __exit__(self, exc_type, exc_val, exc_tb):
_tls.launch.pop()
context.pop(LaunchContext, self)

@abstractmethod
def launch(self, launchable: Launchable, args, kwargs):
Expand Down
13 changes: 13 additions & 0 deletions python/shark_turbine/kernel/compiler/ir.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
from iree.compiler.ir import (
AffineConstantExpr,
AffineExpr,
AffineMap,
AffineMapAttr,
Attribute,
Context,
DenseElementsAttr,
F32Type,
FloatAttr,
FunctionType,
IndexType,
InsertionPoint,
IntegerAttr,
IntegerType,
Location,
Operation,
MemRefType,
ShapedType,
Type as IrType,
Value,
VectorType,
)

from iree.compiler.dialects import (
arith as arith_d,
builtin as builtin_d,
func as func_d,
vector as vector_d,
)
Loading

0 comments on commit 0c658bd

Please sign in to comment.