diff --git a/starlette/requests.py b/starlette/requests.py index 23f8ac70a..ea79db8de 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -19,6 +19,7 @@ if typing.TYPE_CHECKING: + from starlette.applications import Starlette from starlette.routing import Router @@ -175,8 +176,8 @@ def state(self) -> State: return self._state def url_for(self, name: str, /, **path_params: typing.Any) -> URL: - router: Router = self.scope["router"] - url_path = router.url_path_for(name, **path_params) + url_pather: Starlette | Router = self.scope.get("app") or self.scope["router"] + url_path = url_pather.url_path_for(name, **path_params) return url_path.make_absolute_url(base_url=self.base_url) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 225038650..951d2f02d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1132,3 +1132,35 @@ async def send(message: Message) -> None: {"type": "http.response.body", "body": b"good!", "more_body": True}, {"type": "http.response.body", "body": b"", "more_body": False}, ] + + +def test_request_url_for_before_call_next( + test_client_factory: TestClientFactory, +) -> None: + class CallsRequestUrlForMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + if request.url == request.url_for("special"): + return PlainTextResponse("Special") + return await call_next(request) + + def endpoint(request: Request) -> Response: + return PlainTextResponse("OK") + + app = Starlette( + routes=[ + Route("/", endpoint, name="index"), + Route("/special", endpoint, name="special"), + ], + middleware=[Middleware(CallsRequestUrlForMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "OK" + + response = client.get("/special") + assert response.text == "Special"