Skip to content

Commit

Permalink
Streaming response early disconnect mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitry Maliuga committed Sep 6, 2024
1 parent 8e1fc9b commit 7d0bfbe
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 39 deletions.
35 changes: 12 additions & 23 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response
from starlette.responses import AsyncContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -192,33 +193,21 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(Response):
class _StreamingResponse(StreamingResponse):
def __init__(
self,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)

async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})
self._info = info
# Disabling early disconnect to allow stacked middleware gracefull termination
super().__init__(content, status_code, headers, media_type, background, early_disconnect=False)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
22 changes: 14 additions & 8 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __init__(
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
early_disconnect: bool = True,
) -> None:
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
Expand All @@ -223,6 +224,7 @@ def __init__(
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background
self.early_disconnect = early_disconnect
self.init_headers(headers)

async def listen_for_disconnect(self, receive: Receive) -> None:
Expand All @@ -240,21 +242,25 @@ async def stream_response(self, send: Send) -> None:
}
)
async for chunk in self.body_iterator:
if not isinstance(chunk, (bytes, memoryview)):
chunk = chunk.encode(self.charset)
if self.early_disconnect:
if not isinstance(chunk, (bytes, memoryview)):
chunk = chunk.encode(self.charset)
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
if self.early_disconnect:
async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()

task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
else:
await self.stream_response(send)

if self.background is not None:
await self.background()
Expand Down
36 changes: 28 additions & 8 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,16 +1006,23 @@ async def endpoint(request: Request) -> Response:

@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
ordered_events: list[str] = []
unordered_events: list[str] = []

class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
def __init__(self, app: ASGIApp, version: int) -> None:
self.version = version
self.events = events
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
self.events.append(f"{self.version}:STARTED")
ordered_events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
ordered_events.append(f"{self.version}:COMPLETED")

def background() -> None:
unordered_events.append(f"{self.version}:BACKGROUND")

res.background = BackgroundTask(background)
return res

async def sleepy(request: Request) -> Response:
Expand All @@ -1027,11 +1034,9 @@ async def sleepy(request: Request) -> Response:
raise AssertionError("Should have raised ClientDisconnect")
return Response(b"")

events: list[str] = []

app = Starlette(
routes=[Route("/", sleepy)],
middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)],
)

scope = {
Expand All @@ -1051,7 +1056,7 @@ async def send(message: Message) -> None:

await app(scope, receive().__anext__, send)

assert events == [
assert ordered_events == [
"1:STARTED",
"2:STARTED",
"3:STARTED",
Expand All @@ -1074,6 +1079,21 @@ async def send(message: Message) -> None:
"1:COMPLETED",
]

assert sorted(unordered_events) == sorted(
[
"1:BACKGROUND",
"2:BACKGROUND",
"3:BACKGROUND",
"4:BACKGROUND",
"5:BACKGROUND",
"6:BACKGROUND",
"7:BACKGROUND",
"8:BACKGROUND",
"9:BACKGROUND",
"10:BACKGROUND",
]
)

assert sent == [
{
"type": "http.response.start",
Expand Down

0 comments on commit 7d0bfbe

Please sign in to comment.