fix(middleware): allow reading JSON requests without blocking
Francisco Aranda committed May 26, 2021
1 parent b536acf commit 42948e3
Showing 2 changed files with 46 additions and 5 deletions.
39 changes: 36 additions & 3 deletions src/rubrix/client/asgi.py
@@ -11,6 +11,7 @@
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.types import Message, Receive

_logger = logging.getLogger(__name__)
_spaces_regex = re.compile(r"\s+")
@@ -45,6 +46,28 @@ def text_classification_mapper(inputs, outputs):
)


class CachedJsonRequest(Request):
"""
We need a cached version of incoming requests, since the request body cannot be read from a middleware directly.
See <https://github.com/encode/starlette/issues/847> for more information.
TODO: Remove usage of CachedJsonRequest when https://github.com/encode/starlette/pull/848 is released
"""

@property
def receive(self) -> Receive:
body = None
if hasattr(self, "_body"):
body = self._body
if body is not None:

async def cached_receive() -> Message:
return dict(type="http.request", body=body)

return cached_receive
return self._receive
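
To make the intent of the new class a bit more concrete, here is a minimal, self-contained sketch (not part of this commit; the class and middleware names are made up) of the same body-caching trick: the middleware wraps the incoming request, reads and caches the JSON body, then passes the wrapped request to call_next so the downstream endpoint can still read that body. Whether call_next honors the wrapped request's receive depends on the Starlette version in use.

    from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
    from starlette.requests import Request
    from starlette.responses import Response
    from starlette.types import Message, Receive


    class ReplayableRequest(Request):
        """Illustrative counterpart of CachedJsonRequest: replays a cached body."""

        @property
        def receive(self) -> Receive:
            body = getattr(self, "_body", None)
            if body is not None:

                async def cached_receive() -> Message:
                    # Hand the already-read body back to the downstream application
                    return {"type": "http.request", "body": body}

                return cached_receive
            return self._receive


    class EchoJsonMiddleware(BaseHTTPMiddleware):
        """Hypothetical middleware that inspects a JSON body without consuming it."""

        async def dispatch(
            self, request: Request, call_next: RequestResponseEndpoint
        ) -> Response:
            cached = ReplayableRequest(request.scope, request.receive)
            payload = await cached.json()   # reads and caches the body
            _ = payload                     # e.g. inspect or log the payload here
            return await call_next(cached)  # the endpoint still sees the full body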


class RubrixLogHTTPMiddleware(BaseHTTPMiddleware):
"""An standard starlette middleware that enables rubrix logs for http prediction requests"""

@@ -72,12 +95,22 @@ async def dispatch(
if self._endpoint != request.url.path: # Filtering endpoint path
return await call_next(request)

response: Response = await call_next(request)
content_type = request.headers.get("Content-type", "")
if "application/json" not in content_type:
return await call_next(request)

cached_request = CachedJsonRequest(
request.scope, request.receive, request._send
)
inputs = await cached_request.json()
response: Response = await call_next(cached_request)
try:
if not isinstance(response, (JSONResponse, StreamingResponse)):
if (
not isinstance(response, (JSONResponse, StreamingResponse))
or response.status_code >= 400
):
return response

inputs = await request.json()
new_response, outputs = await self._extract_response_content(response)
self._queue.put_nowait((inputs, outputs, str(request.url)))
return new_response
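
For reference, a hedged usage sketch of the middleware (the app, endpoint path, and dataset name below are illustrative, mirroring the test changes further down): mount RubrixLogHTTPMiddleware on a FastAPI app so JSON POSTs to the monitored endpoint are queued for logging to a Rubrix dataset while the endpoint's response is returned unchanged.

    from typing import Any, Dict, List

    from fastapi import FastAPI

    from rubrix.client.asgi import RubrixLogHTTPMiddleware

    app = FastAPI()
    app.add_middleware(
        RubrixLogHTTPMiddleware,
        api_endpoint="/predict",          # only this path is monitored
        dataset="mlmodel_v3_monitor_ds",  # Rubrix dataset receiving the logged records
    )


    @app.post("/predict")
    def predict(data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # The middleware has already read (and cached) this JSON body;
        # the endpoint still receives it thanks to CachedJsonRequest.
        return [{"labels": ["A", "B"], "probabilities": [0.9, 0.1]}]

A POST with a JSON body to /predict is then logged in the background; requests to other paths, non-JSON requests, and error responses are passed through untouched, per the checks added in dispatch above.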
12 changes: 10 additions & 2 deletions tests/client/test_asgi.py
@@ -1,4 +1,8 @@
import time
from typing import Any, Dict

import rubrix
from fastapi import FastAPI
from rubrix import TextClassificationRecord, TokenClassificationRecord
from rubrix.client.asgi import RubrixLogHTTPMiddleware, token_classification_mapper
from starlette.applications import Starlette
@@ -11,15 +15,15 @@ def test_rubrix_middleware_for_text_classification(monkeypatch):
expected_endpoint = "/predict"
expected_dataset_name = "mlmodel_v3_monitor_ds"

app = Starlette()
app = FastAPI()
app.add_middleware(
RubrixLogHTTPMiddleware,
api_endpoint=expected_endpoint,
dataset=expected_dataset_name,
)

@app.route(expected_endpoint, methods=["POST"])
def mock_predict(request):
def mock_predict(data: Dict[str, Any]):
return JSONResponse(
content=[
{"labels": ["A", "B"], "probabilities": [0.9, 0.1]},
@@ -53,7 +57,11 @@ def __call__(self, records, name: str, **kwargs):
],
)

assert mock_log.was_called
time.sleep(0.200)
mock_log.was_called = False
mock.get("/another/predict/route")
assert not mock_log.was_called
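
The added sleep and the second assertion make sense because dispatch never calls rubrix.log inline: it only does self._queue.put_nowait(...), and the queued records are presumably drained by a background consumer elsewhere in the module. A self-contained illustration of that pattern (the queue, worker thread, and record below are made up for this sketch, not the repo's actual worker):

    import queue
    import threading
    import time

    log_queue: "queue.Queue" = queue.Queue()
    logged_calls = []


    def consumer() -> None:
        # Stand-in for a background worker forwarding queued records to rubrix.log
        while True:
            inputs, outputs, url = log_queue.get()
            logged_calls.append((inputs, outputs, url))
            log_queue.task_done()


    threading.Thread(target=consumer, daemon=True).start()

    # What dispatch() enqueues for a monitored JSON request
    log_queue.put_nowait(({"text": "hello"}, [{"labels": ["A", "B"]}], "http://testserver/predict"))
    time.sleep(0.2)       # give the worker a moment, just like the test does
    assert logged_calls   # the record has been picked up in the background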


def test_rubrix_middleware_for_token_classification(monkeypatch):