Skip to content

Commit

Permalink
feat+wip: move predict method to FastAPI
Browse files Browse the repository at this point in the history
This requires that we change all model args/responses from Marshmallow
to Pydantic. Most of the code is in this change, we can split it later
on two different changes (marshmallow + pydantic and FastAPI for
predict).
  • Loading branch information
alvarolopez committed Aug 8, 2024
1 parent e005e71 commit fb89404
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 60 deletions.
3 changes: 2 additions & 1 deletion deepaas/api/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from deepaas.api.v2 import debug as v2_debug
from deepaas.api.v2 import models as v2_model
# from deepaas.api.v2 import predict as v2_predict
from deepaas.api.v2 import predict as v2_predict
# from deepaas.api.v2 import responses
# from deepaas.api.v2 import train as v2_train
from deepaas import log
Expand All @@ -41,6 +41,7 @@ def get_app(enable_train=True, enable_predict=True):

APP.include_router(v2_debug.router, tags=["debug"])
APP.include_router(v2_model.get_router(), tags=["models"])
APP.include_router(v2_predict.get_router(), tags=["predict"])

# APP.router.add_get("/", get_version, name="v2", allow_head=False)
# v2_debug.setup_routes(APP)
Expand Down
112 changes: 64 additions & 48 deletions deepaas/api/v2/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
# License for the specific language governing permissions and limitations
# under the License.

from aiohttp import web
import aiohttp_apispec
from webargs import aiohttpparser
import webargs.core
# from aiohttp import web
# import aiohttp_apispec
# from webargs import aiohttpparser
# import webargs.core

import fastapi
import fastapi.encoders
import fastapi.exceptions

from deepaas.api.v2 import responses
from deepaas.api.v2 import utils
Expand All @@ -33,68 +37,80 @@ def _get_model_response(model_name, model_obj):
return responses.Prediction


def _get_handler(model_name, model_obj):
aux = model_obj.get_predict_args()
accept = aux.get("accept", None)
if accept:
accept.validate.choices.append("*/*")
accept.load_default = accept.validate.choices[0]
accept.location = "headers"
router = fastapi.APIRouter(prefix="/models")


def _get_handler_for_model(model_name, model_obj):
"""Auxiliary function to get the handler for a model.
This function returns a handler for a model that can be used to
register the routes in the router.
handler_args = webargs.core.dict2schema(aux)
handler_args.opts.ordered = True
"""

response = _get_model_response(model_name, model_obj)
user_declared_args = model_obj.get_predict_args()
pydantic_schema = utils.get_pydantic_schema_from_marshmallow_fields(
"PydanticSchema",
user_declared_args,
)

class Handler(object):
"""Class to handle the model metadata endpoints."""

model_name = None
model_obj = None

def __init__(self, model_name, model_obj):
self.model_name = model_name
self.model_obj = model_obj

@aiohttp_apispec.docs(
tags=["models"],
summary="Make a prediction given the input data",
produces=accept.validate.choices if accept else None,
)
@aiohttp_apispec.querystring_schema(handler_args)
@aiohttp_apispec.response_schema(response(), 200)
@aiohttp_apispec.response_schema(responses.Failure(), 400)
async def post(self, request):
args = await aiohttpparser.parser.parse(handler_args, request)
task = self.model_obj.predict(**args)
await task

ret = task.result()["output"]
async def predict(self, args: pydantic_schema = fastapi.Depends()):
"""Make a prediction given the input data."""
dict_args = args.model_dump(by_alias=True)

ret = await self.model_obj.predict(**args.model_dump(by_alias=True))

if isinstance(ret, model.v2.wrapper.ReturnedFile):
ret = open(ret.filename, "rb")

accept = args.get("accept", "application/json")
if accept not in ["application/json", "*/*"]:
response = web.Response(
body=ret,
content_type=accept,
)
return response
if self.model_obj.has_schema:
self.model_obj.validate_response(ret)
return web.json_response(ret)
# FIXME(aloga): Validation does not work, as we are converting from
# Marshmallow to Pydantic, check this as son as possible.
# self.model_obj.validate_response(ret)
return fastapi.responses.JSONResponse(ret)

return fastapi.responses.JSONResponse(
content={"status": "OK", "predictions": ret}
)

return web.json_response({"status": "OK", "predictions": ret})
def register_routes(self, router):
"""Register the routes in the router."""

response = _get_model_response(self.model_name, self.model_obj)

router.add_api_route(
f"/{self.model_name}/predict",
self.predict,
methods=["POST"],
response_model=response,
)

return Handler(model_name, model_obj)


def setup_routes(app, enable=True):
# In the next lines we iterate over the loaded models and create the
# different resources for each model. This way we can also load the
# expected parameters if needed (as in the training method).
for model_name, model_obj in model.V2_MODELS.items():
if enable:
hdlr = _get_handler(model_name, model_obj)
else:
hdlr = utils.NotEnabledHandler()
app.router.add_post("/models/%s/predict/" % model_name, hdlr.post)
def get_router():
"""Auxiliary function to get the router.
We use this function to be able to include the router in the main
application and do things before it gets included.
In this case we explicitly include the model precit endpoint.
"""
model_name = model.V2_MODEL_NAME
model_obj = model.V2_MODEL

hdlr = _get_handler_for_model(model_name, model_obj)
hdlr.register_routes(router)

return router
18 changes: 9 additions & 9 deletions deepaas/api/v2/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,6 @@
# versions = fields.List(fields.Nested(Version))


class Failure(marshmallow.Schema):
message = fields.Str(required=True, description="Failure message")


class Prediction(marshmallow.Schema):
status = fields.String(required=True, description="Response status message")
predictions = fields.Str(required=True, description="String containing predictions")


class Training(marshmallow.Schema):
uuid = fields.UUID(required=True, description="Training identifier")
date = fields.DateTime(required=True, description="Training start time")
Expand Down Expand Up @@ -110,3 +101,12 @@ class ModelList(pydantic.BaseModel):
...,
description="List of loaded models"
)


class Prediction(pydantic.BaseModel):
status: str = pydantic.Field(description="Response status message")
predictions: str = pydantic.Field(description="String containing predictions")


class Failure(pydantic.BaseModel):
message: str = pydantic.Field(description="Failure message")
Loading

0 comments on commit fb89404

Please sign in to comment.