Skip to content

Commit 18e9e1f

Browse files
authored
[HotFix] Fix final output truncation with stop string + streaming (#8468)
1 parent f57092c commit 18e9e1f

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

tests/async_engine/test_async_llm_engine.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:
159159

160160

161161
@pytest.mark.asyncio(scope="module")
162-
async def test_asyncio_run(async_engine):
162+
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
163+
async def test_asyncio_run(async_engine, stop):
163164

164165
scheduler_config = await async_engine.get_scheduler_config()
165166
num_scheduler_steps = scheduler_config.num_scheduler_steps
@@ -169,6 +170,7 @@ async def run(prompt: str):
169170
temperature=0,
170171
max_tokens=32,
171172
min_tokens=32,
173+
stop=stop,
172174
)
173175

174176
output_count = 0
@@ -203,7 +205,8 @@ async def run(prompt: str):
203205

204206

205207
@pytest.mark.asyncio(scope="module")
206-
async def test_output_kinds(async_engine):
208+
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
209+
async def test_output_kinds(async_engine, stop):
207210
"""Test that output_kind works as expected and that
208211
results are equivalent across different kinds."""
209212

@@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
214217
temperature=0,
215218
max_tokens=32,
216219
min_tokens=32,
220+
stop=stop,
217221
)
218222

219223
async def run(prompt: str, kind: RequestOutputKind):
@@ -229,6 +233,8 @@ async def run(prompt: str, kind: RequestOutputKind):
229233
final_output = output
230234

231235
assert final_output is not None
236+
assert final_output.finished
237+
232238
return (final_output.prompt_token_ids,
233239
final_output.outputs[0].token_ids,
234240
final_output.outputs[0].text, output_count)
@@ -241,16 +247,18 @@ async def run_deltas(prompt: str):
241247
output_tokens: List[int] = []
242248
output_text = ""
243249
output_count = 0
250+
final_output = None
244251
async for output in async_engine.generate(prompt,
245252
params,
246253
request_id=uid()):
247254
token_ids = output.outputs[0].token_ids
248255
text = output.outputs[0].text
256+
final_output = output
249257

250258
# Ensure we get prompt ids iff we haven't yet received output tokens
251259
if output_tokens:
252260
assert 1 <= len(token_ids) <= num_scheduler_steps
253-
assert text
261+
assert stop or text
254262
assert not output.prompt_token_ids
255263
else:
256264
assert output.prompt_token_ids
@@ -260,6 +268,10 @@ async def run_deltas(prompt: str):
260268
output_text += text
261269

262270
output_count += 1
271+
272+
assert final_output is not None
273+
assert final_output.finished
274+
263275
return prompt_tokens, output_tokens, output_text, output_count
264276

265277
results = await asyncio.gather(
@@ -291,14 +303,16 @@ async def run_deltas(prompt: str):
291303

292304

293305
@pytest.mark.asyncio(scope="module")
294-
async def test_cancellation(async_engine):
306+
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
307+
async def test_cancellation(async_engine, stop):
295308
scheduler_config = await async_engine.get_scheduler_config()
296309
num_scheduler_steps = scheduler_config.num_scheduler_steps
297310

298311
sampling_params = SamplingParams(
299312
temperature=0,
300313
min_tokens=13,
301314
max_tokens=13,
315+
stop=stop,
302316
)
303317

304318
stop_at = 5 if num_scheduler_steps == 1 else 1
@@ -319,7 +333,8 @@ async def test_cancellation(async_engine):
319333

320334

321335
@pytest.mark.asyncio(scope="module")
322-
async def test_delayed_generator(async_engine):
336+
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
337+
async def test_delayed_generator(async_engine, stop):
323338
scheduler_config = await async_engine.get_scheduler_config()
324339

325340
if scheduler_config.num_scheduler_steps != 1:
@@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
329344
temperature=0,
330345
min_tokens=10,
331346
max_tokens=10,
347+
stop=stop,
332348
)
333349

334350
stream = async_engine.generate("test3", sampling_params, request_id=uid())

vllm/sequence.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -477,7 +477,9 @@ def get_output_text_to_return(self, buffer_length: int,
477477
if not delta:
478478
return self.output_text[:-buffer_length] if truncate else (
479479
self.output_text)
480-
length = len(self.output_text) - buffer_length
480+
length = len(self.output_text)
481+
if truncate:
482+
length -= buffer_length
481483
last_offset = self._last_output_text_offset
482484
if last_offset < length:
483485
self._last_output_text_offset = length

0 commit comments

Comments (0)