Skip to content

Commit 3071fb4

Browse files
adriangbKludexflorimondmancaabersheeran
authored andcommitted
Add Mount(..., middleware=[...]) (#1649)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com> Co-authored-by: Florimond Manca <florimond.manca@protonmail.com> Co-authored-by: Aber <me@abersheeran.com>
1 parent 57f5aca commit 3071fb4

File tree

3 files changed

+218
-3
lines changed

3 files changed

+218
-3
lines changed

docs/middleware.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,41 @@ to use the `middleware=<List of Middleware instances>` style, as it will:
686686
* Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`.
687687
* Preserves the top-level `app` instance.
688688

689+
## Applying middleware to `Mount`s
690+
691+
Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application:
692+
693+
```python
694+
from starlette.applications import Starlette
695+
from starlette.middleware import Middleware
696+
from starlette.middleware.gzip import GZipMiddleware
697+
from starlette.routing import Mount, Route
698+
699+
700+
routes = [
701+
Mount(
702+
"/",
703+
routes=[
704+
Route(
705+
"/example",
706+
endpoint=...,
707+
)
708+
],
709+
middleware=[Middleware(GZipMiddleware)]
710+
)
711+
]
712+
713+
app = Starlette(routes=routes)
714+
```
715+
716+
Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is.
717+
This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses.
718+
If you do want to apply the middleware logic to error responses only on some routes you have a couple of options:
719+
720+
* Add an `ExceptionMiddleware` onto the `Mount`
721+
* Add a `try/except` block to your middleware and return an error response from there
722+
* Split up marking and processing into two middlewares, one that gets put on `Mount` which marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting.
723+
689724
## Third party middleware
690725

691726
#### [asgi-auth-github](https://github.com/simonw/asgi-auth-github)

starlette/routing.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from starlette.convertors import CONVERTOR_TYPES, Convertor
1515
from starlette.datastructures import URL, Headers, URLPath
1616
from starlette.exceptions import HTTPException
17+
from starlette.middleware import Middleware
1718
from starlette.requests import Request
1819
from starlette.responses import PlainTextResponse, RedirectResponse
1920
from starlette.types import ASGIApp, Receive, Scope, Send
@@ -348,24 +349,30 @@ def __init__(
348349
app: typing.Optional[ASGIApp] = None,
349350
routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
350351
name: typing.Optional[str] = None,
352+
*,
353+
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
351354
) -> None:
352355
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
353356
assert (
354357
app is not None or routes is not None
355358
), "Either 'app=...', or 'routes=' must be specified"
356359
self.path = path.rstrip("/")
357360
if app is not None:
358-
self.app: ASGIApp = app
361+
self._base_app: ASGIApp = app
359362
else:
360-
self.app = Router(routes=routes)
363+
self._base_app = Router(routes=routes)
364+
self.app = self._base_app
365+
if middleware is not None:
366+
for cls, options in reversed(middleware):
367+
self.app = cls(app=self.app, **options)
361368
self.name = name
362369
self.path_regex, self.path_format, self.param_convertors = compile_path(
363370
self.path + "/{path:path}"
364371
)
365372

366373
@property
367374
def routes(self) -> typing.List[BaseRoute]:
368-
return getattr(self.app, "routes", [])
375+
return getattr(self._base_app, "routes", [])
369376

370377
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
371378
if scope["type"] in ("http", "websocket"):

tests/test_routing.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55
import pytest
66

77
from starlette.applications import Starlette
8+
from starlette.exceptions import HTTPException
9+
from starlette.middleware import Middleware
10+
from starlette.requests import Request
811
from starlette.responses import JSONResponse, PlainTextResponse, Response
912
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
13+
from starlette.testclient import TestClient
14+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1015
from starlette.websockets import WebSocket, WebSocketDisconnect
1116

1217

@@ -768,6 +773,115 @@ def test_route_name(endpoint: typing.Callable, expected_name: str):
768773
assert Route(path="/", endpoint=endpoint).name == expected_name
769774

770775

776+
class AddHeadersMiddleware:
777+
def __init__(self, app: ASGIApp) -> None:
778+
self.app = app
779+
780+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
781+
scope["add_headers_middleware"] = True
782+
783+
async def modified_send(msg: Message) -> None:
784+
if msg["type"] == "http.response.start":
785+
msg["headers"].append((b"X-Test", b"Set by middleware"))
786+
await send(msg)
787+
788+
await self.app(scope, receive, modified_send)
789+
790+
791+
def assert_middleware_header_route(request: Request) -> Response:
792+
assert request.scope["add_headers_middleware"] is True
793+
return Response()
794+
795+
796+
mounted_routes_with_middleware = Starlette(
797+
routes=[
798+
Mount(
799+
"/http",
800+
routes=[
801+
Route(
802+
"/",
803+
endpoint=assert_middleware_header_route,
804+
methods=["GET"],
805+
name="route",
806+
),
807+
],
808+
middleware=[Middleware(AddHeadersMiddleware)],
809+
),
810+
Route("/home", homepage),
811+
]
812+
)
813+
814+
815+
mounted_app_with_middleware = Starlette(
816+
routes=[
817+
Mount(
818+
"/http",
819+
app=Route(
820+
"/",
821+
endpoint=assert_middleware_header_route,
822+
methods=["GET"],
823+
name="route",
824+
),
825+
middleware=[Middleware(AddHeadersMiddleware)],
826+
),
827+
Route("/home", homepage),
828+
]
829+
)
830+
831+
832+
@pytest.mark.parametrize(
833+
"app",
834+
[
835+
mounted_routes_with_middleware,
836+
mounted_app_with_middleware,
837+
],
838+
)
839+
def test_mount_middleware(
840+
test_client_factory: typing.Callable[..., TestClient],
841+
app: Starlette,
842+
) -> None:
843+
test_client = test_client_factory(app)
844+
845+
response = test_client.get("/home")
846+
assert response.status_code == 200
847+
assert "X-Test" not in response.headers
848+
849+
response = test_client.get("/http")
850+
assert response.status_code == 200
851+
assert response.headers["X-Test"] == "Set by middleware"
852+
853+
854+
def test_mount_routes_with_middleware_url_path_for() -> None:
855+
"""Checks that url_path_for still works with mounted routes with Middleware"""
856+
assert mounted_routes_with_middleware.url_path_for("route") == "/http/"
857+
858+
859+
def test_mount_asgi_app_with_middleware_url_path_for() -> None:
860+
"""Mounted ASGI apps do not work with url path for,
861+
middleware does not change this
862+
"""
863+
with pytest.raises(NoMatchFound):
864+
mounted_app_with_middleware.url_path_for("route")
865+
866+
867+
def test_add_route_to_app_after_mount(
868+
test_client_factory: typing.Callable[..., TestClient],
869+
) -> None:
870+
"""Checks that Mount will pick up routes
871+
added to the underlying app after it is mounted
872+
"""
873+
inner_app = Router()
874+
app = Mount("/http", app=inner_app)
875+
inner_app.add_route(
876+
"/inner",
877+
endpoint=homepage,
878+
methods=["GET"],
879+
)
880+
client = test_client_factory(app)
881+
response = client.get("/http/inner")
882+
assert response.status_code == 200
883+
884+
771885
def test_exception_on_mounted_apps(test_client_factory):
772886
def exc(request):
773887
raise Exception("Exc")
@@ -779,3 +893,62 @@ def exc(request):
779893
with pytest.raises(Exception) as ctx:
780894
client.get("/sub/")
781895
assert str(ctx.value) == "Exc"
896+
897+
898+
def test_mounted_middleware_does_not_catch_exception(
899+
test_client_factory: typing.Callable[..., TestClient],
900+
) -> None:
901+
# https://github.com/encode/starlette/pull/1649#discussion_r960236107
902+
def exc(request: Request) -> Response:
903+
raise HTTPException(status_code=403, detail="auth")
904+
905+
class NamedMiddleware:
906+
def __init__(self, app: ASGIApp, name: str) -> None:
907+
self.app = app
908+
self.name = name
909+
910+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
911+
async def modified_send(msg: Message) -> None:
912+
if msg["type"] == "http.response.start":
913+
msg["headers"].append((f"X-{self.name}".encode(), b"true"))
914+
await send(msg)
915+
916+
await self.app(scope, receive, modified_send)
917+
918+
app = Starlette(
919+
routes=[
920+
Mount(
921+
"/mount",
922+
routes=[
923+
Route("/err", exc),
924+
Route("/home", homepage),
925+
],
926+
middleware=[Middleware(NamedMiddleware, name="Mounted")],
927+
),
928+
Route("/err", exc),
929+
Route("/home", homepage),
930+
],
931+
middleware=[Middleware(NamedMiddleware, name="Outer")],
932+
)
933+
934+
client = test_client_factory(app)
935+
936+
resp = client.get("/home")
937+
assert resp.status_code == 200, resp.content
938+
assert "X-Outer" in resp.headers
939+
940+
resp = client.get("/err")
941+
assert resp.status_code == 403, resp.content
942+
assert "X-Outer" in resp.headers
943+
944+
resp = client.get("/mount/home")
945+
assert resp.status_code == 200, resp.content
946+
assert "X-Mounted" in resp.headers
947+
948+
# this is the "surprising" behavior bit
949+
# the middleware on the mount never runs because there
950+
# is nothing to catch the HTTPException
951+
# since Mount middlweare is not wrapped by ExceptionMiddleware
952+
resp = client.get("/mount/err")
953+
assert resp.status_code == 403, resp.content
954+
assert "X-Mounted" not in resp.headers

0 commit comments

Comments
 (0)