Skip to content

Commit 0c658bd

Browse files
[tk] Implement basic vector and scalar code generation (nod-ai#220)
* Python value/type propagation * Loads/stores between KernelBuffer and vectors * Extended Python integer types * Python scalar operations
1 parent fac744b commit 0c658bd

File tree

13 files changed

+610
-79
lines changed

13 files changed

+610
-79
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import Optional, Type, TypeVar
2+
3+
import threading
4+
5+
_tls = threading.local()
6+
7+
T = TypeVar("T")
8+
9+
10+
def push(context_type: Type[T], instance: T) -> T:
11+
"""Pushes an instance onto a thread-local context stack.
12+
13+
The context type must define an attribute __tk_context_idname__ which is
14+
a valid/unique identifier.
15+
"""
16+
assert isinstance(instance, context_type)
17+
key = context_type.__tk_context_idname__
18+
try:
19+
stack: list = getattr(_tls, key)
20+
except AttributeError:
21+
stack = []
22+
setattr(_tls, key, stack)
23+
stack.append(instance)
24+
return instance
25+
26+
27+
def pop(context_type: Type[T], expected: Optional[T] = None):
28+
"""Pops the current context off of the stack.
29+
30+
Raises IndexError if no current.
31+
"""
32+
stack: list = getattr(_tls, context_type.__tk_context_idname__)
33+
instance = stack.pop()
34+
assert (
35+
expected is None or expected is instance
36+
), f"mismatched context push/pop for {context_type}"
37+
38+
39+
def current(context_type: Type[T]) -> T:
40+
"""Returns the current context from the stack.
41+
42+
Raises IndexError on failure.
43+
"""
44+
try:
45+
stack: list = getattr(_tls, context_type.__tk_context_idname__)
46+
except AttributeError:
47+
raise IndexError(f"No current context for {context_type}")
48+
try:
49+
instance = stack[-1]
50+
except IndexError:
51+
raise IndexError(f"No current context for {context_type}")
52+
assert isinstance(instance, context_type)
53+
return instance

python/shark_turbine/kernel/_support/indexing.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
from abc import ABC, abstractmethod
44
from enum import Enum
5-
import threading
65

76
import torch
87

8+
from .. import ops
9+
10+
from . import context
11+
912
__all__ = [
1013
"KernelBuffer",
1114
"Grid",
@@ -16,8 +19,6 @@
1619
"sym",
1720
]
1821

19-
_tls = threading.local()
20-
2122

2223
class NotSetType:
2324
...
@@ -329,10 +330,10 @@ def __repr__(self):
329330
return f"{type(self)}({self._tensor})"
330331

331332
def __setitem__(self, key, item):
332-
self._tensor.__setitem__(key, item)
333+
ops.kernel_buffer_setitem(self, key, item)
333334

334335
def __getitem__(self, key):
335-
return self._tensor.__getitem__(key)
336+
return ops.kernel_buffer_getitem(self, key)
336337

337338

338339
class InputBuffer(KernelBuffer):
@@ -357,6 +358,8 @@ class IndexingContext:
357358
symbols to concrete values.
358359
"""
359360

361+
__tk_context_idname__ = "IndexingContext"
362+
360363
def __init__(self):
361364
self.constant_bindings: dict[SymbolDef, int] = {}
362365

@@ -375,19 +378,10 @@ def get_static_value(self, sym: SymbolDef) -> Optional[int]:
375378
##### Context management.
376379
@staticmethod
377380
def current() -> "IndexingContext":
378-
try:
379-
return _tls.indexing_stack[-1]
380-
except (AttributeError, IndexError):
381-
raise AssertionError("no IndexingContext is active")
381+
return context.current(IndexingContext)
382382

383383
def __enter__(self) -> "IndexingContext":
384-
try:
385-
stack = _tls.indexing_stack
386-
except AttributeError:
387-
stack = []
388-
_tls.indexing_stack = stack
389-
stack.append(self)
390-
return self
384+
return context.push(IndexingContext, self)
391385

392386
def __exit__(self, exc_type, exc_val, exc_tb):
393-
_tls.indexing_stack.pop()
387+
context.pop(IndexingContext, self)

python/shark_turbine/kernel/_support/tracing.py

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Optional, TypeVar, Callable, Type, assert_type, cast
33

44
import functools
5-
import threading
65
import warnings
76

87
import torch.fx as fx
@@ -11,19 +10,18 @@
1110
KernelBuffer,
1211
)
1312

14-
_tls = threading.local()
15-
TCallable = TypeVar("TCallable", bound=Callable)
16-
17-
###############################################################################
18-
# Wrapped tracing trampolines for proxy objects.
19-
# These only get called during tracing of proxy objects.
20-
###############################################################################
13+
from ..lang.types import (
14+
Index,
15+
)
2116

17+
from .. import ops
18+
from ..ops.base import (
19+
OpDispatcher,
20+
)
2221

23-
@fx.wrap
24-
def _kernel_buffer_setitem(kernel_buffer: KernelBuffer, key, item) -> None:
25-
...
22+
from . import context
2623

24+
TCallable = TypeVar("TCallable", bound=Callable)
2725

2826
###############################################################################
2927
# Tracing machinery
@@ -42,8 +40,11 @@ def __init__(
4240
self.symbolic_shape = orig_type.symbolic_shape
4341
self.rank = orig_type.rank
4442

43+
def __getitem__(self, key):
44+
return ops.kernel_buffer_getitem(self, key)
45+
4546
def __setitem__(self, key, item):
46-
_kernel_buffer_setitem(self, key, item)
47+
ops.kernel_buffer_setitem(self, key, item)
4748

4849

4950
class KernelTracer(fx.Tracer):
@@ -68,28 +69,23 @@ def __init__(self, gm: fx.GraphModule):
6869
###############################################################################
6970

7071

71-
class BaseContext:
72+
class BaseContext(OpDispatcher):
73+
__tk_context_idname__ = "ExecutionContext"
74+
7275
def __init__(self, *, eager: bool):
7376
self.eager = eager
7477

7578
@staticmethod
7679
def current() -> "BaseContext":
77-
try:
78-
return _tls.context[-1]
79-
except (AttributeError, IndexError):
80-
raise RuntimeError("No context is on the stack")
80+
return context.current(BaseContext)
8181

8282
def __enter__(self) -> "BaseContext":
83-
try:
84-
stack = _tls.context
85-
except AttributeError:
86-
stack = []
87-
_tls.context = stack
88-
stack.append(self)
89-
return self
83+
context.push(OpDispatcher, self)
84+
return context.push(BaseContext, self)
9085

9186
def __exit__(self, exc_type, exc_val, exc_tb):
92-
_tls.context.pop()
87+
context.pop(OpDispatcher, self)
88+
context.pop(BaseContext, self)
9389

9490

9591
class EagerContext(BaseContext):
@@ -98,12 +94,44 @@ def __init__(self, rank: int = 0):
9894
self.rank = rank
9995
self.current_thread: list[int] = rank * [0]
10096

97+
def handle_thread_program_id(self, op, axis: int) -> int:
98+
assert axis >= 0 and axis < self.rank
99+
return Index(self.current_thread[axis])
100+
101+
def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
102+
return kernel_buffer._tensor.__getitem__(key)
103+
104+
def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
105+
kernel_buffer._tensor.__setitem__(key, item)
106+
101107

102108
class CompiledContext(BaseContext):
103109
def __init__(self, tracer: KernelTracer):
104110
super().__init__(eager=False)
105111
self.tracer = tracer
106112

113+
def handle_thread_program_id(self, op, axis: int) -> Index:
114+
proxy = self.tracer.create_proxy(
115+
"call_function", op, args=(axis,), kwargs={}, type_expr=Index
116+
)
117+
return proxy
118+
119+
def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
120+
return self.tracer.create_proxy(
121+
"call_function",
122+
op,
123+
args=(kernel_buffer, key),
124+
kwargs={},
125+
)
126+
127+
def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
128+
self.tracer.create_proxy(
129+
"call_function",
130+
target=op,
131+
args=(kernel_buffer, key, item),
132+
kwargs={},
133+
)
134+
107135

108136
###############################################################################
109137
# Launch context
@@ -129,28 +157,24 @@ def eager_execute(self, args, kwargs):
129157

130158

131159
class LaunchContext(ABC):
160+
__tk_context_idname__ = "ExecutionContext"
161+
132162
@staticmethod
133163
def current() -> "LaunchContext":
134164
try:
135-
return _tls.launch[-1]
136-
except (AttributeError, IndexError):
165+
return context.current(LaunchContext)
166+
except IndexError:
137167
warnings.warn(
138168
"defaulting to debug/eager execution of tk kernel launch "
139169
"because no launch context has been established"
140170
)
141171
return DebugLaunchContext()
142172

143173
def __enter__(self) -> "LaunchContext":
144-
try:
145-
stack = _tls.launch
146-
except AttributeError:
147-
stack = []
148-
_tls.launch = stack
149-
stack.append(self)
150-
return self
174+
return context.push(LaunchContext, self)
151175

152176
def __exit__(self, exc_type, exc_val, exc_tb):
153-
_tls.launch.pop()
177+
context.pop(LaunchContext, self)
154178

155179
@abstractmethod
156180
def launch(self, launchable: Launchable, args, kwargs):
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,30 @@
11
from iree.compiler.ir import (
2+
AffineConstantExpr,
3+
AffineExpr,
4+
AffineMap,
5+
AffineMapAttr,
6+
Attribute,
27
Context,
8+
DenseElementsAttr,
39
F32Type,
10+
FloatAttr,
411
FunctionType,
512
IndexType,
613
InsertionPoint,
14+
IntegerAttr,
15+
IntegerType,
716
Location,
817
Operation,
918
MemRefType,
19+
ShapedType,
1020
Type as IrType,
1121
Value,
22+
VectorType,
1223
)
1324

1425
from iree.compiler.dialects import (
26+
arith as arith_d,
1527
builtin as builtin_d,
1628
func as func_d,
29+
vector as vector_d,
1730
)

0 commit comments

Comments
 (0)