Skip to content

Commit

Permalink
Support request.url_for in BaseMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
Jdsleppy committed Sep 2, 2024
1 parent 8e1fc9b commit c80f9be
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
5 changes: 3 additions & 2 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


if typing.TYPE_CHECKING:
from starlette.applications import Starlette
from starlette.routing import Router


Expand Down Expand Up @@ -175,8 +176,8 @@ def state(self) -> State:
return self._state

def url_for(self, name: str, /, **path_params: typing.Any) -> URL:
router: Router = self.scope["router"]
url_path = router.url_path_for(name, **path_params)
url_pather: Starlette | Router = self.scope.get("app") or self.scope["router"]
url_path = url_pather.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.base_url)


Expand Down
32 changes: 32 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,3 +1132,35 @@ async def send(message: Message) -> None:
{"type": "http.response.body", "body": b"good!", "more_body": True},
{"type": "http.response.body", "body": b"", "more_body": False},
]


def test_request_url_for_before_call_next(
test_client_factory: TestClientFactory,
) -> None:
class CallsRequestUrlForMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
if request.url == request.url_for("special"):
return PlainTextResponse("Special")
return await call_next(request)

def endpoint(request: Request) -> Response:
return PlainTextResponse("OK")

app = Starlette(
routes=[
Route("/", endpoint, name="index"),
Route("/special", endpoint, name="special"),
],
middleware=[Middleware(CallsRequestUrlForMiddleware)],
)

client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"

response = client.get("/special")
assert response.text == "Special"

0 comments on commit c80f9be

Please sign in to comment.