Skip to content

Commit

Permalink
switch back to middleware as function and remove commented code.
Browse files Browse the repository at this point in the history
  • Loading branch information
recalcitrantsupplant committed Nov 1, 2024
1 parent e2b17b4 commit 6d815b2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 47 deletions.
11 changes: 5 additions & 6 deletions prez/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
PrefixNotFoundException,
NoEndpointNodeshapeException
)
from prez.middleware import ValidateHeaderMiddleware
from prez.middleware import create_validate_header_middleware
from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo
from prez.routers.base_router import router as base_prez_router
from prez.routers.custom_endpoints import create_dynamic_router
Expand Down Expand Up @@ -220,11 +220,10 @@ def assemble_app(
allow_headers=["*"],
expose_headers=["*"],
)
# validate_header_middleware = create_validate_header_middleware(
# settings.required_header
# )
# app.middleware("http")(validate_header_middleware)
app.add_middleware(ValidateHeaderMiddleware, required_header=settings.required_header)
validate_header_middleware = create_validate_header_middleware(
settings.required_header
)
app.middleware("http")(validate_header_middleware)

return app

Expand Down
55 changes: 14 additions & 41 deletions prez/middleware.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,23 @@
# from fastapi import Request
# from fastapi.responses import JSONResponse
#
#
# def create_validate_header_middleware(required_header: dict[str, str] | None):
# async def validate_header(request: Request, call_next):
# if required_header:
# header_name, expected_value = next(iter(required_header.items()))
# if (
# header_name not in request.headers
# or request.headers[header_name] != expected_value
# ):
# return JSONResponse( # attempted to use Exception and although it was caught it did not propagate
# status_code=400,
# content={
# "error": "Header Validation Error",
# "message": f"Missing or invalid header: {header_name}",
# "code": "HEADER_VALIDATION_ERROR",
# },
# )
# return await call_next(request)
#
# return validate_header


from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp


class ValidateHeaderMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, required_header: dict[str, str] | None):
super().__init__(app)
self.required_header = required_header

async def dispatch(self, request: Request, call_next):
if self.required_header:
header_name, expected_value = next(iter(self.required_header.items()))
if header_name not in request.headers or request.headers[header_name] != expected_value:
return JSONResponse(
def create_validate_header_middleware(required_header: dict[str, str] | None):
async def validate_header(request: Request, call_next):
if required_header:
header_name, expected_value = next(iter(required_header.items()))
if (
header_name not in request.headers
or request.headers[header_name] != expected_value
):
return JSONResponse( # attempted to use Exception and although it was caught it did not propagate
status_code=400,
content={
"error": "Header Validation Error",
"message": f"Missing or invalid header: {header_name}",
"code": "HEADER_VALIDATION_ERROR"
}
"code": "HEADER_VALIDATION_ERROR",
},
)
response = await call_next(request)
return response
return await call_next(request)

return validate_header

0 comments on commit 6d815b2

Please sign in to comment.