Skip to content

Commit

Permalink
Add support for ASGI pathsend extension in BaseHTTPMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed Sep 2, 2024
1 parent 8fa5837 commit 9286c83
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 3 deletions.
1 change: 0 additions & 1 deletion docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ around explicitly, rather than mutating the middleware instance.
Currently, the `BaseHTTPMiddleware` has some known limitations:

- Using `BaseHTTPMiddleware` will prevent changes to [`contextlib.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior).
- Using `BaseHTTPMiddleware` will prevent [ASGI pathsend extension](https://asgi.readthedocs.io/en/latest/extensions.html#path-send) to work properly. Thus, if you run your Starlette application with a server implementing this extension, routes returning [FileResponse](responses.md#fileresponse) should avoid the usage of this middleware.

To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below.

Expand Down
6 changes: 5 additions & 1 deletion starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
BodyStreamGenerator = typing.AsyncGenerator[typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None]
T = typing.TypeVar("T")


Expand Down Expand Up @@ -165,9 +166,12 @@ async def coro() -> None:

assert message["type"] == "http.response.start"

async def body_stream() -> typing.AsyncGenerator[bytes, None]:
async def body_stream() -> BodyStreamGenerator:
async with recv_stream:
async for message in recv_stream:
if message["type"] == "http.response.pathsend":
yield message
break
assert message["type"] == "http.response.body"
body = message.get("body", b"")
if body:
Expand Down
54 changes: 53 additions & 1 deletion tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextvars
from contextlib import AsyncExitStack
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Expand All @@ -18,7 +19,7 @@
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -1132,3 +1133,54 @@ async def send(message: Message) -> None:
{"type": "http.response.body", "body": b"good!", "more_body": True},
{"type": "http.response.body", "body": b"", "more_body": False},
]


@pytest.mark.anyio
async def test_asgi_pathsend_events(tmpdir: Path) -> None:
path = tmpdir / "example.txt"
with path.open("w") as file:
file.write("<file content>")

request_body_sent = False
response_complete = anyio.Event()
events: list[Message] = []

async def endpoint_with_pathsend(_: Request) -> FileResponse:
return FileResponse(path)

async def passthrough(
request: Request, call_next: RequestResponseEndpoint
) -> Response:
return await call_next(request)

app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_pathsend)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
"extensions": {"http.response.pathsend": {}},
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message: Message) -> None:
events.append(message)
if message["type"] == "http.response.pathsend":
response_complete.set()

await app(scope, receive, send)

assert len(events) == 2
assert events[0]["type"] == "http.response.start"
assert events[1]["type"] == "http.response.pathsend"

0 comments on commit 9286c83

Please sign in to comment.