Skip to content

Commit

Permalink
Fix BackgroundTasks with BaseHTTPMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 6, 2024
1 parent 5cd92cd commit c7a77e5
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
38 changes: 26 additions & 12 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response, StreamingResponse
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -193,21 +192,36 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):
def __init__(
self,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self._info = info
# Disabling early disconnect to allow stacked middleware gracefull termination
super().__init__(content, status_code, headers, media_type, background, early_disconnect=False)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)

async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})

if self.background:
await self.background()
17 changes: 6 additions & 11 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def __init__(
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
early_disconnect: bool = True,
) -> None:
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
Expand All @@ -224,7 +223,6 @@ def __init__(
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background
self.early_disconnect = early_disconnect
self.init_headers(headers)

async def listen_for_disconnect(self, receive: Receive) -> None:
Expand All @@ -249,17 +247,14 @@ async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.early_disconnect:
async with anyio.create_task_group() as task_group:
async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()

task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
else:
await self.stream_response(send)
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))

if self.background is not None:
await self.background()
Expand Down
5 changes: 5 additions & 0 deletions tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,11 @@ def test_streaming_response_memoryview(test_client_factory: TestClientFactory) -

@pytest.mark.anyio
async def test_streaming_response_stops_if_receiving_http_disconnect() -> None:
"""
Tests for:
- https://github.com/encode/starlette/issues/2516
- https://github.com/encode/starlette/pull/2687
"""
streamed = 0

disconnected = anyio.Event()
Expand Down

0 comments on commit c7a77e5

Please sign in to comment.