Skip to content

Commit 50ccd8d

Browse files
authored
Use a polling thread per device (#778)
* Use a polling thread per device * fix test * Make sure that the poller stops when the device is gone * oh pypy * Small improvements * codegen
1 parent e2c2f13 commit 50ccd8d

File tree

6 files changed

+429
-144
lines changed

6 files changed

+429
-144
lines changed

tests/test_async.py

Lines changed: 45 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import time
2+
import threading
3+
14
import anyio
25

36
from pytest import mark, raises
@@ -12,13 +15,19 @@ class GPUPromise(BaseGPUPromise):
1215
# Subclass with each own set of unresolved promise instances
1316
_UNRESOLVED = set()
1417

18+
def _sync_wait(self):
19+
# Same implementation as the wgpu_native backend.
20+
# If we have a test that has not polling thread, and sync_wait() is called
21+
# when the promise is still pending, this will hang.
22+
self._thread_event.wait()
23+
1524

1625
class SillyLoop:
1726
def __init__(self):
1827
self._pending_calls = []
1928
self.errors = []
2029

21-
def call_soon(self, f, *args):
30+
def call_soon_threadsafe(self, f, *args):
2231
self._pending_calls.append((f, args))
2332

2433
def process_events(self):
@@ -73,23 +82,18 @@ def test_promise_basics():
7382
# %%%%% Promise using sync_wait
7483

7584

76-
def test_promise_sync_need_poll():
77-
promise = GPUPromise("test", None)
78-
79-
with raises(RuntimeError): # cannot poll without poll function
80-
promise.sync_wait()
85+
def run_in_thread(callable):
86+
t = threading.Thread(target=callable)
87+
t.start()
8188

8289

8390
def test_promise_sync_simple():
84-
count = 0
85-
91+
@run_in_thread
8692
def poller():
87-
nonlocal count
88-
count += 1
89-
if count > 5:
90-
promise._wgpu_set_input(42)
93+
time.sleep(0.1)
94+
promise._wgpu_set_input(42)
9195

92-
promise = GPUPromise("test", None, poller=poller)
96+
promise = GPUPromise("test", None)
9397

9498
result = promise.sync_wait()
9599
assert result == 42
@@ -99,15 +103,12 @@ def test_promise_sync_normal():
99103
def handler(input):
100104
return input * 2
101105

102-
count = 0
103-
106+
@run_in_thread
104107
def poller():
105-
nonlocal count
106-
count += 1
107-
if count > 5:
108-
promise._wgpu_set_input(42)
108+
time.sleep(0.1)
109+
promise._wgpu_set_input(42)
109110

110-
promise = GPUPromise("test", handler, poller=poller)
111+
promise = GPUPromise("test", handler)
111112

112113
result = promise.sync_wait()
113114
assert result == 84
@@ -117,15 +118,12 @@ def test_promise_sync_fail1():
117118
def handler(input):
118119
return input * 2
119120

120-
count = 0
121-
121+
@run_in_thread
122122
def poller():
123-
nonlocal count
124-
count += 1
125-
if count > 5:
126-
promise._wgpu_set_error(ZeroDivisionError())
123+
time.sleep(0.1)
124+
promise._wgpu_set_error(ZeroDivisionError())
127125

128-
promise = GPUPromise("test", handler, poller=poller)
126+
promise = GPUPromise("test", handler)
129127

130128
with raises(ZeroDivisionError):
131129
promise.sync_wait()
@@ -135,15 +133,12 @@ def test_promise_sync_fail2():
135133
def handler(input):
136134
return input / 0
137135

138-
count = 0
139-
136+
@run_in_thread
140137
def poller():
141-
nonlocal count
142-
count += 1
143-
if count > 5:
144-
promise._wgpu_set_input(42)
138+
time.sleep(0.1)
139+
promise._wgpu_set_input(42)
145140

146-
promise = GPUPromise("test", handler, poller=poller)
141+
promise = GPUPromise("test", handler)
147142

148143
with raises(ZeroDivisionError):
149144
promise.sync_wait()
@@ -152,25 +147,14 @@ def poller():
152147
# %% Promise using await with poll and loop
153148

154149

155-
@mark.anyio
156-
async def test_promise_async_need_poll_or_loop():
157-
promise = GPUPromise("test", None)
158-
159-
with raises(RuntimeError): # cannot poll without poll function
160-
await promise
161-
162-
163150
@mark.anyio
164151
async def test_promise_async_poll_simple():
165-
count = 0
166-
152+
@run_in_thread
167153
def poller():
168-
nonlocal count
169-
count += 1
170-
if count > 5:
171-
promise._wgpu_set_input(42)
154+
time.sleep(0.1)
155+
promise._wgpu_set_input(42)
172156

173-
promise = GPUPromise("test", None, poller=poller)
157+
promise = GPUPromise("test", None)
174158

175159
result = await promise
176160
assert result == 42
@@ -181,15 +165,12 @@ async def test_promise_async_poll_normal():
181165
def handler(input):
182166
return input * 2
183167

184-
count = 0
185-
168+
@run_in_thread
186169
def poller():
187-
nonlocal count
188-
count += 1
189-
if count > 5:
190-
promise._wgpu_set_input(42)
170+
time.sleep(0.1)
171+
promise._wgpu_set_input(42)
191172

192-
promise = GPUPromise("test", handler, poller=poller)
173+
promise = GPUPromise("test", handler)
193174

194175
result = await promise
195176
assert result == 84
@@ -200,15 +181,12 @@ async def test_promise_async_poll_fail1():
200181
def handler(input):
201182
return input * 2
202183

203-
count = 0
204-
184+
@run_in_thread
205185
def poller():
206-
nonlocal count
207-
count += 1
208-
if count > 5:
209-
promise._wgpu_set_error(ZeroDivisionError())
186+
time.sleep(0.1)
187+
promise._wgpu_set_error(ZeroDivisionError())
210188

211-
promise = GPUPromise("test", handler, poller=poller)
189+
promise = GPUPromise("test", handler)
212190

213191
with raises(ZeroDivisionError):
214192
await promise
@@ -219,15 +197,12 @@ async def test_promise_async_poll_fail2():
219197
def handler(input):
220198
return input / 0
221199

222-
count = 0
223-
200+
@run_in_thread
224201
def poller():
225-
nonlocal count
226-
count += 1
227-
if count > 5:
228-
promise._wgpu_set_input(42)
202+
time.sleep(0.1)
203+
promise._wgpu_set_input(42)
229204

230-
promise = GPUPromise("test", handler, poller=poller)
205+
promise = GPUPromise("test", handler)
231206

232207
with raises(ZeroDivisionError):
233208
await promise

tests/test_wgpu_native_poller.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import gc
2+
import time
3+
import queue
4+
5+
import wgpu
6+
from wgpu.backends.wgpu_native._poller import PollThread, PollToken
7+
8+
from testutils import can_use_wgpu_lib, run_tests, is_pypy
9+
from pytest import mark
10+
11+
12+
def test_poll_thread():
13+
# A timeout to give polling thread time to progress. The GIL switches
14+
# threads about every 5ms, but in this cases likely faster, because it also switches
15+
# when it goes to sleep on a blocking call. So 50ms seems plenty.
16+
timeout = 0.05
17+
18+
count = 0
19+
gpu_work_done_queue = queue.SimpleQueue()
20+
21+
def reset():
22+
nonlocal count
23+
ref_count = count
24+
# Make sure the poller is not waiting in poll_func
25+
gpu_work_done_queue.put(None)
26+
gpu_work_done_queue.put(None)
27+
# Give it time
28+
time.sleep(timeout)
29+
# Check that it did not enter again, i.e. is waiting for tokens
30+
assert count == ref_count, "Looks like a token is still active"
31+
# Reset
32+
count = 0
33+
while True:
34+
try:
35+
gpu_work_done_queue.get(False)
36+
except queue.Empty:
37+
break
38+
39+
def finish_tokens(*tokens):
40+
# This mimics the GPU finishing an async task, and invoking its
41+
# callback that sets the token to done.
42+
gpu_work_done_queue.put(None)
43+
for token in tokens:
44+
assert not token.is_done()
45+
token.set_done()
46+
47+
def poll_func(block):
48+
# This mimics the wgpuDevicePoll.
49+
nonlocal count
50+
count += 1
51+
if block:
52+
gpu_work_done_queue.get() # blocking
53+
else:
54+
try:
55+
gpu_work_done_queue.get(False)
56+
except queue.Empty:
57+
pass
58+
59+
# Start the poller
60+
t = PollThread(poll_func)
61+
t.start()
62+
63+
reset()
64+
65+
# == Normal behavior
66+
67+
token = t.get_token()
68+
assert isinstance(token, PollToken)
69+
time.sleep(timeout)
70+
assert count == 2
71+
72+
finish_tokens(token)
73+
74+
time.sleep(timeout)
75+
assert count == 2
76+
77+
reset()
78+
79+
# == Always at least one poll
80+
81+
token = t.get_token()
82+
token.set_done()
83+
time.sleep(timeout)
84+
assert count in (1, 2) # typically 1, but can sometimes be 2
85+
86+
reset()
87+
88+
# == Mark done through deletion
89+
90+
token = t.get_token()
91+
time.sleep(timeout)
92+
assert count == 2
93+
94+
finish_tokens()
95+
96+
time.sleep(timeout)
97+
assert count == 3
98+
99+
finish_tokens()
100+
101+
time.sleep(timeout)
102+
assert count == 4
103+
104+
del token
105+
gc.collect()
106+
gc.collect()
107+
108+
finish_tokens()
109+
110+
time.sleep(timeout)
111+
assert count == 4
112+
113+
reset()
114+
115+
# More tasks
116+
117+
token1 = t.get_token()
118+
time.sleep(timeout)
119+
assert count == 2
120+
121+
token2 = t.get_token()
122+
time.sleep(timeout)
123+
assert count == 2
124+
125+
token3 = t.get_token()
126+
token4 = t.get_token()
127+
time.sleep(timeout)
128+
assert count == 2
129+
130+
finish_tokens(token1)
131+
time.sleep(timeout)
132+
assert count == 3
133+
134+
finish_tokens(token2, token3)
135+
time.sleep(timeout)
136+
assert count == 4
137+
138+
finish_tokens() # can actually bump more unrelated works
139+
finish_tokens()
140+
time.sleep(timeout)
141+
assert count == 6
142+
143+
token5 = t.get_token()
144+
finish_tokens(token4)
145+
time.sleep(timeout)
146+
assert count == 7
147+
148+
finish_tokens(token5)
149+
time.sleep(timeout)
150+
assert count == 8
151+
152+
reset()
153+
154+
# Shut it down
155+
156+
t.stop()
157+
time.sleep(0.1)
158+
assert not t.is_alive()
159+
160+
161+
@mark.skipif(not can_use_wgpu_lib, reason="Needs wgpu lib")
162+
def test_poller_stops_when_device_gone():
163+
device = wgpu.gpu.request_adapter_sync().request_device_sync()
164+
165+
t = device._poller
166+
assert t.is_alive()
167+
device.__del__()
168+
time.sleep(0.1)
169+
170+
assert not t.is_alive()
171+
172+
device = wgpu.gpu.request_adapter_sync().request_device_sync()
173+
174+
t = device._poller
175+
assert t.is_alive()
176+
del device
177+
gc.collect()
178+
gc.collect()
179+
if is_pypy:
180+
gc.collect()
181+
gc.collect()
182+
time.sleep(0.1)
183+
184+
assert not t.is_alive()
185+
186+
187+
if __name__ == "__main__":
188+
run_tests(globals())

0 commit comments

Comments
 (0)