@@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:
159
159
160
160
161
161
@pytest .mark .asyncio (scope = "module" )
162
- async def test_asyncio_run (async_engine ):
162
+ @pytest .mark .parametrize ("stop" , [None , ["a stop string" ]])
163
+ async def test_asyncio_run (async_engine , stop ):
163
164
164
165
scheduler_config = await async_engine .get_scheduler_config ()
165
166
num_scheduler_steps = scheduler_config .num_scheduler_steps
@@ -169,6 +170,7 @@ async def run(prompt: str):
169
170
temperature = 0 ,
170
171
max_tokens = 32 ,
171
172
min_tokens = 32 ,
173
+ stop = stop ,
172
174
)
173
175
174
176
output_count = 0
@@ -203,7 +205,8 @@ async def run(prompt: str):
203
205
204
206
205
207
@pytest .mark .asyncio (scope = "module" )
206
- async def test_output_kinds (async_engine ):
208
+ @pytest .mark .parametrize ("stop" , [None , ["a stop string" ]])
209
+ async def test_output_kinds (async_engine , stop ):
207
210
"""Test that output_kind works as expected and that
208
211
results are equivalent across different kinds."""
209
212
@@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
214
217
temperature = 0 ,
215
218
max_tokens = 32 ,
216
219
min_tokens = 32 ,
220
+ stop = stop ,
217
221
)
218
222
219
223
async def run (prompt : str , kind : RequestOutputKind ):
@@ -229,6 +233,8 @@ async def run(prompt: str, kind: RequestOutputKind):
229
233
final_output = output
230
234
231
235
assert final_output is not None
236
+ assert final_output .finished
237
+
232
238
return (final_output .prompt_token_ids ,
233
239
final_output .outputs [0 ].token_ids ,
234
240
final_output .outputs [0 ].text , output_count )
@@ -241,16 +247,18 @@ async def run_deltas(prompt: str):
241
247
output_tokens : List [int ] = []
242
248
output_text = ""
243
249
output_count = 0
250
+ final_output = None
244
251
async for output in async_engine .generate (prompt ,
245
252
params ,
246
253
request_id = uid ()):
247
254
token_ids = output .outputs [0 ].token_ids
248
255
text = output .outputs [0 ].text
256
+ final_output = output
249
257
250
258
# Ensure we get prompt ids iff we haven't yet received output tokens
251
259
if output_tokens :
252
260
assert 1 <= len (token_ids ) <= num_scheduler_steps
253
- assert text
261
+ assert stop or text
254
262
assert not output .prompt_token_ids
255
263
else :
256
264
assert output .prompt_token_ids
@@ -260,6 +268,10 @@ async def run_deltas(prompt: str):
260
268
output_text += text
261
269
262
270
output_count += 1
271
+
272
+ assert final_output is not None
273
+ assert final_output .finished
274
+
263
275
return prompt_tokens , output_tokens , output_text , output_count
264
276
265
277
results = await asyncio .gather (
@@ -291,14 +303,16 @@ async def run_deltas(prompt: str):
291
303
292
304
293
305
@pytest .mark .asyncio (scope = "module" )
294
- async def test_cancellation (async_engine ):
306
+ @pytest .mark .parametrize ("stop" , [None , ["a stop string" ]])
307
+ async def test_cancellation (async_engine , stop ):
295
308
scheduler_config = await async_engine .get_scheduler_config ()
296
309
num_scheduler_steps = scheduler_config .num_scheduler_steps
297
310
298
311
sampling_params = SamplingParams (
299
312
temperature = 0 ,
300
313
min_tokens = 13 ,
301
314
max_tokens = 13 ,
315
+ stop = stop ,
302
316
)
303
317
304
318
stop_at = 5 if num_scheduler_steps == 1 else 1
@@ -319,7 +333,8 @@ async def test_cancellation(async_engine):
319
333
320
334
321
335
@pytest .mark .asyncio (scope = "module" )
322
- async def test_delayed_generator (async_engine ):
336
+ @pytest .mark .parametrize ("stop" , [None , ["a stop string" ]])
337
+ async def test_delayed_generator (async_engine , stop ):
323
338
scheduler_config = await async_engine .get_scheduler_config ()
324
339
325
340
if scheduler_config .num_scheduler_steps != 1 :
@@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
329
344
temperature = 0 ,
330
345
min_tokens = 10 ,
331
346
max_tokens = 10 ,
347
+ stop = stop ,
332
348
)
333
349
334
350
stream = async_engine .generate ("test3" , sampling_params , request_id = uid ())
0 commit comments