Commit 2d1b9ba

[Bugfix] Fix request cancellation without polling (#11190)
1 parent: f9ecbb1

File tree

12 files changed (+164, -103 lines)


tests/entrypoints/openai/test_basic.py

Lines changed: 51 additions & 0 deletions

@@ -1,6 +1,8 @@
+import asyncio
 from http import HTTPStatus
 from typing import List

+import openai
 import pytest
 import pytest_asyncio
 import requests
@@ -103,3 +105,52 @@ async def test_check_health(server: RemoteOpenAIServer):
     response = requests.get(server.url_for("health"))

     assert response.status_code == HTTPStatus.OK
+
+
+@pytest.mark.parametrize(
+    "server_args",
+    [
+        pytest.param(["--max-model-len", "10100"],
+                     id="default-frontend-multiprocessing"),
+        pytest.param(
+            ["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
+            id="disable-frontend-multiprocessing")
+    ],
+    indirect=True,
+)
+@pytest.mark.asyncio
+async def test_request_cancellation(server: RemoteOpenAIServer):
+    # clunky test: send an ungodly amount of load in with short timeouts
+    # then ensure that it still responds quickly afterwards
+
+    chat_input = [{"role": "user", "content": "Write a long story"}]
+    client = server.get_async_client(timeout=0.5)
+    tasks = []
+    # Request about 2 million tokens
+    for _ in range(200):
+        task = asyncio.create_task(
+            client.chat.completions.create(messages=chat_input,
+                                           model=MODEL_NAME,
+                                           max_tokens=10000,
+                                           extra_body={"min_tokens": 10000}))
+        tasks.append(task)
+
+    done, pending = await asyncio.wait(tasks,
+                                       return_when=asyncio.ALL_COMPLETED)
+
+    # Make sure all requests were sent to the server and timed out
+    # (We don't want to hide other errors like 400s that would invalidate this
+    # test)
+    assert len(pending) == 0
+    for d in done:
+        with pytest.raises(openai.APITimeoutError):
+            d.result()
+
+    # If the server had not cancelled all the other requests, then it would not
+    # be able to respond to this one within the timeout
+    client = server.get_async_client(timeout=5)
+    response = await client.chat.completions.create(messages=chat_input,
+                                                    model=MODEL_NAME,
+                                                    max_tokens=10)
+
+    assert len(response.choices) == 1
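
For context, a minimal sketch of what the short client-side timeout in this test exercises: the AsyncOpenAI client gives up after `timeout` seconds, raises `openai.APITimeoutError`, and drops the HTTP connection; with this commit the server is expected to notice the disconnect and abort the request rather than keep generating. The base URL, API key, and model name below are placeholders, not values from this diff.

import asyncio

import openai


async def main():
    # Placeholder endpoint and model; mirrors how the test's
    # get_async_client(timeout=0.5) configures the client.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY",
                                max_retries=0,
                                timeout=0.5)
    try:
        await client.chat.completions.create(
            messages=[{"role": "user", "content": "Write a long story"}],
            model="placeholder-model",
            max_tokens=10000)
    except openai.APITimeoutError:
        # The client has disconnected; the server should now abort the
        # in-flight request instead of generating tokens nobody will read.
        print("client timed out; request should be aborted server-side")


asyncio.run(main())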

tests/test_utils.py

Lines changed: 1 addition & 5 deletions

@@ -1,7 +1,6 @@
 import asyncio
 import os
 import socket
-from functools import partial
 from typing import AsyncIterator, Tuple

 import pytest
@@ -26,10 +25,7 @@ async def mock_async_iterator(idx: int):
             print(f"iterator {idx} cancelled")

     iterators = [mock_async_iterator(i) for i in range(3)]
-    merged_iterator = merge_async_iterators(*iterators,
-                                            is_cancelled=partial(asyncio.sleep,
-                                                                 0,
-                                                                 result=False))
+    merged_iterator = merge_async_iterators(*iterators)

     async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
         async for idx, output in generator:
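
As the updated test suggests, merge_async_iterators no longer takes an is_cancelled callback; cancellation is expected to reach the individual iterators when the consuming task is cancelled. A rough, hedged sketch of that behaviour, where the ticker generator, timings, and prints are illustrative rather than taken from this diff:

import asyncio

from vllm.utils import merge_async_iterators


async def ticker(idx: int):
    # Illustrative producer; prints only if cancellation propagates to it.
    try:
        while True:
            await asyncio.sleep(0.01)
            yield f"item from iterator {idx}"
    except asyncio.CancelledError:
        print(f"iterator {idx} cancelled")
        raise


async def main():
    merged = merge_async_iterators(*(ticker(i) for i in range(3)))

    async def consume():
        async for idx, item in merged:
            pass

    task = asyncio.create_task(consume())
    await asyncio.sleep(0.05)
    task.cancel()  # plain task cancellation; no is_cancelled polling
    try:
        await task
    except asyncio.CancelledError:
        print("merged stream cancelled")


asyncio.run(main())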

tests/utils.py

Lines changed: 5 additions & 6 deletions

@@ -163,12 +163,11 @@ def get_client(self):
             api_key=self.DUMMY_API_KEY,
         )

-    def get_async_client(self):
-        return openai.AsyncOpenAI(
-            base_url=self.url_for("v1"),
-            api_key=self.DUMMY_API_KEY,
-            max_retries=0,
-        )
+    def get_async_client(self, **kwargs):
+        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
+                                  api_key=self.DUMMY_API_KEY,
+                                  max_retries=0,
+                                  **kwargs)


 def _test_completion(

vllm/engine/async_llm_engine.py

Lines changed: 27 additions & 19 deletions

@@ -1065,16 +1065,20 @@ async def generate(
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in await self.add_request(
-                request_id,
-                prompt,
-                sampling_params,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                prompt_adapter_request=prompt_adapter_request,
-                priority=priority,
-        ):
-            yield LLMEngine.validate_output(output, RequestOutput)
+        try:
+            async for output in await self.add_request(
+                    request_id,
+                    prompt,
+                    sampling_params,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=priority,
+            ):
+                yield LLMEngine.validate_output(output, RequestOutput)
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise

     async def encode(
         self,
@@ -1147,15 +1151,19 @@ async def encode(
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in await self.add_request(
-                request_id,
-                prompt,
-                pooling_params,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                priority=priority,
-        ):
-            yield LLMEngine.validate_output(output, PoolingRequestOutput)
+        try:
+            async for output in await self.add_request(
+                    request_id,
+                    prompt,
+                    pooling_params,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    priority=priority,
+            ):
+                yield LLMEngine.validate_output(output, PoolingRequestOutput)
+        except asyncio.CancelledError:
+            await self.abort(request_id)
+            raise

     async def abort(self, request_id: str) -> None:
         """Abort a request.

vllm/entrypoints/api_server.py

Lines changed: 7 additions & 4 deletions

@@ -17,11 +17,11 @@
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
-                        random_uuid)
+from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger("vllm.entrypoints.api_server")
@@ -47,15 +47,18 @@ async def generate(request: Request) -> Response:
     - other fields: the sampling parameters (See `SamplingParams` for details).
     """
     request_dict = await request.json()
+    return await _generate(request_dict, raw_request=request)
+
+
+@with_cancellation
+async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()

     assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
-    results_generator = iterate_with_cancellation(
-        results_generator, is_cancelled=request.is_disconnected)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
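
The split into a thin public route plus a decorated private _generate matters because with_cancellation assumes the request body has already been read (see the new vllm/entrypoints/utils.py below): the outer handler awaits request.json() first, so the wrapper's disconnect listener never races with body reads. A hedged sketch of the same shape, with an illustrative route path, sleep, and response:

import asyncio

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from vllm.entrypoints.utils import with_cancellation

app = FastAPI()


@app.post("/generate")
async def generate(request: Request) -> JSONResponse:
    # Read the body first; only then hand off to the cancellable worker.
    request_dict = await request.json()
    return await _generate(request_dict, raw_request=request)


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> JSONResponse:
    # Long-running work; cancelled if the client disconnects while it runs.
    await asyncio.sleep(5)
    return JSONResponse({"echo": request_dict})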

vllm/entrypoints/openai/api_server.py

Lines changed: 8 additions & 0 deletions

@@ -59,6 +59,7 @@
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.utils import with_cancellation
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
@@ -311,6 +312,7 @@ async def health(raw_request: Request) -> Response:


 @router.post("/tokenize")
+@with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)

@@ -325,6 +327,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):


 @router.post("/detokenize")
+@with_cancellation
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)

@@ -353,6 +356,7 @@ async def show_version():


 @router.post("/v1/chat/completions")
+@with_cancellation
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -373,6 +377,7 @@ async def create_chat_completion(request: ChatCompletionRequest,


 @router.post("/v1/completions")
+@with_cancellation
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -390,6 +395,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):


 @router.post("/v1/embeddings")
+@with_cancellation
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -407,6 +413,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):


 @router.post("/score")
+@with_cancellation
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -424,6 +431,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):


 @router.post("/v1/score")
+@with_cancellation
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "

vllm/entrypoints/openai/serving_chat.py

Lines changed: 0 additions & 5 deletions

@@ -32,7 +32,6 @@
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
-from vllm.utils import iterate_with_cancellation

 logger = init_logger(__name__)

@@ -234,10 +233,6 @@ async def create_chat_completion(
         assert len(generators) == 1
         result_generator, = generators

-        if raw_request:
-            result_generator = iterate_with_cancellation(
-                result_generator, raw_request.is_disconnected)
-
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(

vllm/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 2 deletions

@@ -159,8 +159,7 @@ async def create_completion(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(
-            *generators, is_cancelled=raw_request.is_disconnected)
+        result_generator = merge_async_iterators(*generators)

         model_name = self._get_model_name(lora_request)
         num_prompts = len(engine_prompts)

vllm/entrypoints/openai/serving_embedding.py

Lines changed: 1 addition & 4 deletions

@@ -202,10 +202,7 @@ async def create_embedding(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(
-            *generators,
-            is_cancelled=raw_request.is_disconnected if raw_request else None,
-        )
+        result_generator = merge_async_iterators(*generators)

         num_prompts = len(engine_prompts)

vllm/entrypoints/openai/serving_score.py

Lines changed: 1 addition & 4 deletions

@@ -186,10 +186,7 @@ async def create_score(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

-        result_generator = merge_async_iterators(
-            *generators,
-            is_cancelled=raw_request.is_disconnected if raw_request else None,
-        )
+        result_generator = merge_async_iterators(*generators)

         num_prompts = len(engine_prompts)

vllm/entrypoints/utils.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+import asyncio
+import functools
+
+from fastapi import Request
+
+
+async def listen_for_disconnect(request: Request) -> None:
+    """Returns if a disconnect message is received"""
+    while True:
+        message = await request.receive()
+        if message["type"] == "http.disconnect":
+            break
+
+
+def with_cancellation(handler_func):
+    """Decorator that allows a route handler to be cancelled by client
+    disconnections.
+
+    This does _not_ use request.is_disconnected, which does not work with
+    middleware. Instead this follows the pattern from
+    starlette.StreamingResponse, which simultaneously awaits on two tasks: one
+    to wait for an http disconnect message, and the other to do the work that we
+    want done. When the first task finishes, the other is cancelled.
+
+    A core assumption of this method is that the body of the request has already
+    been read. This is a safe assumption to make for fastapi handlers that have
+    already parsed the body of the request into a pydantic model for us.
+    This decorator is unsafe to use elsewhere, as it will consume and throw away
+    all incoming messages for the request while it looks for a disconnect
+    message.
+
+    In the case where a `StreamingResponse` is returned by the handler, this
+    wrapper will stop listening for disconnects and instead the response object
+    will start listening for disconnects.
+    """
+
+    # functools.wraps is required for this wrapper to appear to fastapi as a
+    # normal route handler, with the correct request type hinting.
+    @functools.wraps(handler_func)
+    async def wrapper(*args, **kwargs):
+
+        # The request is either the second positional arg or `raw_request`
+        request = args[1] if len(args) > 1 else kwargs["raw_request"]
+
+        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
+        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
+
+        done, pending = await asyncio.wait([handler_task, cancellation_task],
+                                           return_when=asyncio.FIRST_COMPLETED)
+        for task in pending:
+            task.cancel()
+
+        if handler_task in done:
+            return handler_task.result()
+        return None
+
+    return wrapper
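
A minimal usage sketch of the new helper, assuming a standalone FastAPI app with an illustrative /slow route (not part of this commit); the vLLM endpoints above are decorated the same way, with the route decorator outermost so FastAPI registers the already-wrapped handler:

import asyncio

from fastapi import FastAPI, Request
from pydantic import BaseModel

from vllm.entrypoints.utils import with_cancellation

app = FastAPI()


class SlowRequest(BaseModel):
    seconds: float = 5.0


@app.post("/slow")
@with_cancellation
async def slow_handler(request: SlowRequest, raw_request: Request):
    # FastAPI has already read and parsed the body into `request`, so the
    # wrapper can safely consume further ASGI messages while it listens for
    # http.disconnect. If the client drops, this task is cancelled mid-sleep.
    await asyncio.sleep(request.seconds)
    return {"status": "done"}

If the disconnect listener wins, the handler task is cancelled and the wrapper returns None, which is acceptable because there is no client left to receive a response.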
