2
2
from typing import Optional , TypeVar , Callable , Type , assert_type , cast
3
3
4
4
import functools
5
- import threading
6
5
import warnings
7
6
8
7
import torch .fx as fx
11
10
KernelBuffer ,
12
11
)
13
12
14
- _tls = threading .local ()
15
- TCallable = TypeVar ("TCallable" , bound = Callable )
16
-
17
- ###############################################################################
18
- # Wrapped tracing trampolines for proxy objects.
19
- # These only get called during tracing of proxy objects.
20
- ###############################################################################
13
+ from ..lang .types import (
14
+ Index ,
15
+ )
21
16
17
+ from .. import ops
18
+ from ..ops .base import (
19
+ OpDispatcher ,
20
+ )
22
21
23
- @fx .wrap
24
- def _kernel_buffer_setitem (kernel_buffer : KernelBuffer , key , item ) -> None :
25
- ...
22
+ from . import context
26
23
24
+ TCallable = TypeVar ("TCallable" , bound = Callable )
27
25
28
26
###############################################################################
29
27
# Tracing machinery
@@ -42,8 +40,11 @@ def __init__(
42
40
self .symbolic_shape = orig_type .symbolic_shape
43
41
self .rank = orig_type .rank
44
42
43
def __getitem__(self, key):
    """Handle ``self[key]`` by delegating to ``ops.kernel_buffer_getitem``.

    The op receives this object and ``key``; its result is returned
    unchanged.
    """
    loaded = ops.kernel_buffer_getitem(self, key)
    return loaded
45
+
45
46
def __setitem__(self, key, item):
    """Handle ``self[key] = item`` by delegating to ``ops.kernel_buffer_setitem``.

    The op receives this object, ``key`` and ``item``; its result is
    discarded.
    """
    ops.kernel_buffer_setitem(self, key, item)
47
48
48
49
49
50
class KernelTracer (fx .Tracer ):
@@ -68,28 +69,23 @@ def __init__(self, gm: fx.GraphModule):
68
69
###############################################################################
69
70
70
71
71
class BaseContext(OpDispatcher):
    """Execution context that is also registered as the op dispatcher.

    Used as a context manager: entering pushes the instance through the
    ``context`` module under both ``OpDispatcher`` and ``BaseContext``;
    exiting pops it from both again.
    """

    __tk_context_idname__ = "ExecutionContext"

    def __init__(self, *, eager: bool):
        """Store the keyword-only ``eager`` flag on the instance."""
        self.eager = eager

    @staticmethod
    def current() -> "BaseContext":
        """Return whatever ``context.current`` reports for ``BaseContext``.

        No exception handling is done here; errors raised by the
        ``context`` module propagate to the caller.
        """
        active = context.current(BaseContext)
        return active

    def __enter__(self) -> "BaseContext":
        """Push this instance as the dispatcher, then as the base context.

        The return value is the result of the second push.
        """
        context.push(OpDispatcher, self)
        entered = context.push(BaseContext, self)
        return entered

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Pop this instance from both registrations made by ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        # Same order as the original: dispatcher first, then base context.
        for registered_as in (OpDispatcher, BaseContext):
            context.pop(registered_as, self)
93
89
94
90
95
91
class EagerContext (BaseContext ):
@@ -98,12 +94,44 @@ def __init__(self, rank: int = 0):
98
94
self .rank = rank
99
95
self .current_thread : list [int ] = rank * [0 ]
100
96
97
def handle_thread_program_id(self, op, axis: int) -> Index:
    """Return the ``current_thread`` entry for ``axis`` wrapped in ``Index``.

    Args:
        op: The op being dispatched; not used by this handler.
        axis: Axis to query; must satisfy ``0 <= axis < self.rank``.

    Returns:
        ``Index(self.current_thread[axis])``.

    Raises:
        AssertionError: If ``axis`` is outside ``[0, self.rank)``.
    """
    assert 0 <= axis < self.rank, f"axis {axis} out of range for rank {self.rank}"
    return Index(self.current_thread[axis])
100
+
101
def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
    """Read ``key`` directly from the buffer's ``_tensor`` attribute.

    Args:
        op: The op being dispatched; not used by this handler.
        kernel_buffer: Buffer whose ``_tensor`` is indexed.
        key: Index expression forwarded unchanged to the tensor.

    Returns:
        The result of ``kernel_buffer._tensor[key]``.
    """
    # Subscript syntax instead of calling the dunder method by name.
    return kernel_buffer._tensor[key]
103
+
104
def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
    """Write ``item`` at ``key`` directly into the buffer's ``_tensor`` attribute.

    Args:
        op: The op being dispatched; not used by this handler.
        kernel_buffer: Buffer whose ``_tensor`` is assigned into.
        key: Index expression forwarded unchanged to the tensor.
        item: Value to store.
    """
    # Subscript assignment instead of calling the dunder method by name.
    kernel_buffer._tensor[key] = item
106
+
101
107
102
108
class CompiledContext(BaseContext):
    """Non-eager context whose handlers create tracer proxies for each op.

    Every handler forwards the dispatched ``op`` and its operands to
    ``self.tracer.create_proxy`` as a ``"call_function"`` entry.
    """

    def __init__(self, tracer: KernelTracer):
        """Initialise the base context with ``eager=False`` and keep ``tracer``."""
        super().__init__(eager=False)
        self.tracer = tracer

    def handle_thread_program_id(self, op, axis: int) -> Index:
        """Create and return a proxy for ``op(axis)``, passing ``type_expr=Index``."""
        return self.tracer.create_proxy(
            "call_function",
            op,
            args=(axis,),
            kwargs={},
            type_expr=Index,
        )

    def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
        """Create and return a proxy for ``op(kernel_buffer, key)``."""
        operands = (kernel_buffer, key)
        proxy = self.tracer.create_proxy(
            "call_function",
            op,
            args=operands,
            kwargs={},
        )
        return proxy

    def handle_kernel_buffer_setitem(self, op, kernel_buffer: KernelBuffer, key, item):
        """Create a proxy for ``op(kernel_buffer, key, item)``; nothing is returned."""
        operands = (kernel_buffer, key, item)
        self.tracer.create_proxy(
            "call_function",
            target=op,
            args=operands,
            kwargs={},
        )
134
+
107
135
108
136
###############################################################################
109
137
# Launch context
@@ -129,28 +157,24 @@ def eager_execute(self, args, kwargs):
129
157
130
158
131
159
class LaunchContext (ABC ):
160
+ __tk_context_idname__ = "ExecutionContext"
161
+
132
162
@staticmethod
def current() -> "LaunchContext":
    """Return the active launch context, or a debug one if the lookup fails.

    The lookup is ``context.current(LaunchContext)``. If it raises
    ``IndexError``, a warning is emitted and a new ``DebugLaunchContext``
    is returned instead.
    """
    # NOTE(review): LaunchContext and BaseContext both declare
    # __tk_context_idname__ = "ExecutionContext" -- confirm the context
    # module does not key its lookups on that attribute.
    try:
        active = context.current(LaunchContext)
    except IndexError:
        warnings.warn(
            "defaulting to debug/eager execution of tk kernel launch "
            "because no launch context has been established"
        )
        return DebugLaunchContext()
    return active
142
172
143
173
def __enter__(self) -> "LaunchContext":
    """Push this instance under ``LaunchContext`` and return the push result."""
    entered = context.push(LaunchContext, self)
    return entered
151
175
152
176
def __exit__(self, exc_type, exc_val, exc_tb):
    """Pop this instance from the ``LaunchContext`` registration.

    Returns ``None``, so an exception raised inside the ``with`` body is
    not suppressed.
    """
    context.pop(LaunchContext, self)
154
178
155
179
@abstractmethod
156
180
def launch (self , launchable : Launchable , args , kwargs ):
0 commit comments