From 0380121995cf8ac06e8854b3c356fb1b80c8a326 Mon Sep 17 00:00:00 2001 From: Alvaro Lopez Garcia Date: Fri, 21 Jun 2024 12:09:44 +0200 Subject: [PATCH] feat+wip: move models endpoint to FastAPI --- deepaas/api/v2/__init__.py | 3 +- deepaas/api/v2/models.py | 119 ++++++++++++++++++++---------------- deepaas/api/v2/responses.py | 88 ++++++++++++++++++-------- 3 files changed, 133 insertions(+), 77 deletions(-) diff --git a/deepaas/api/v2/__init__.py b/deepaas/api/v2/__init__.py index 085462d6..2460163b 100644 --- a/deepaas/api/v2/__init__.py +++ b/deepaas/api/v2/__init__.py @@ -18,7 +18,7 @@ from oslo_config import cfg from deepaas.api.v2 import debug as v2_debug -# from deepaas.api.v2 import models as v2_model +from deepaas.api.v2 import models as v2_model # from deepaas.api.v2 import predict as v2_predict # from deepaas.api.v2 import responses # from deepaas.api.v2 import train as v2_train @@ -40,6 +40,7 @@ def get_app(enable_train=True, enable_predict=True): v2_debug.setup_debug() APP.include_router(v2_debug.router, tags=["debug"]) + APP.include_router(v2_model.get_router(), tags=["models"]) # APP.router.add_get("/", get_version, name="v2", allow_head=False) # v2_debug.setup_routes(APP) diff --git a/deepaas/api/v2/models.py b/deepaas/api/v2/models.py index 406f2642..d30aa138 100644 --- a/deepaas/api/v2/models.py +++ b/deepaas/api/v2/models.py @@ -14,52 +14,53 @@ # License for the specific language governing permissions and limitations # under the License. -import urllib.parse - -from aiohttp import web -import aiohttp_apispec +import fastapi from deepaas.api.v2 import responses from deepaas import model -@aiohttp_apispec.docs( - tags=["models"], +router = fastapi.APIRouter(prefix="/models") + + +@router.get( + "/", summary="Return loaded models and its information", - description="DEEPaaS can load several models and server them on the same " - "endpoint, making a call to the root of the models namespace " - "will return the loaded models, as long as their basic " - "metadata.", + description="Return list of DEEPaaS loaded models. In previous versions, DEEPaaS " + "could load several models and serve them on the same endpoint.", + tags=["models"], + response_model=responses.ModelList, ) -@aiohttp_apispec.response_schema(responses.ModelMeta(), 200) -async def index(request): - """Return loaded models and its information. +async def index_models( + request: fastapi.Request, +): + """Return loaded models and its information.""" + + name = model.V2_MODEL_NAME + model_obj = model.V2_MODEL + m = { + "id": name, + "name": name, + "links": [ + { + "rel": "self", + "href": str(request.url_for("get_model/" + name)), + } + ], + } + meta = model_obj.get_metadata() + m.update(meta) + return {"models": [m]} - DEEPaaS can load several models and server them on the same endpoint, - making a call to the root of the models namespace will return the - loaded models, as long as their basic metadata. - """ - models = [] - for name, obj in model.V2_MODELS.items(): - m = { - "id": name, - "name": name, - "links": [ - { - "rel": "self", - "href": urllib.parse.urljoin("%s/" % request.path, name), - } - ], - } - meta = obj.get_metadata() - m.update(meta) - models.append(m) - return web.json_response({"models": models}) - - -def _get_handler(model_name, model_obj): +def _get_handler_for_model(model_name, model_obj): + """Auxiliary function to get the handler for a model. + + This function returns a handler for a model that can be used to + register the routes in the router. + """ class Handler(object): + """Class to handle the model metadata endpoints.""" model_name = None model_obj = None @@ -67,36 +68,50 @@ def __init__(self, model_name, model_obj): self.model_name = model_name self.model_obj = model_obj - @aiohttp_apispec.docs( - tags=["models"], - summary="Return model's metadata", - ) - @aiohttp_apispec.response_schema(responses.ModelMeta(), 200) - async def get(self, request): + async def get(self, request: fastapi.Request): + """Return model's metadata.""" m = { "id": self.model_name, "name": self.model_name, "links": [ { "rel": "self", - "href": request.path.rstrip("/"), + "href": str(request.url), } ], } meta = self.model_obj.get_metadata() m.update(meta) - return web.json_response(m) + return m + + def register_routes(self, router): + """Register routes for the model in the router.""" + router.add_api_route( + f"/{self.model_name}", + self.get, + name="get_model/" + self.model_name, + summary="Return model's metadata", + tags=["models"], + response_model=responses.ModelMeta, + ) return Handler(model_name, model_obj) -def setup_routes(app): - app.router.add_get("/models/", index, allow_head=False) +def get_router() -> fastapi.APIRouter: + """Auxiliary function to get the router. + + We use this function to be able to include the router in the main + application and do things before it gets included. + + In this case we explicitly include the model's endpoints. + + """ + model_name = model.V2_MODEL_NAME + model_obj = model.V2_MODEL + + hdlr = _get_handler_for_model(model_name, model_obj) + hdlr.register_routes(router) - # In the next lines we iterate over the loaded models and create the - # different resources for each model. This way we can also load the - # expected parameters if needed (as in the training method). - for model_name, model_obj in model.V2_MODELS.items(): - hdlr = _get_handler(model_name, model_obj) - app.router.add_get("/models/%s/" % model_name, hdlr.get, allow_head=False) + return router diff --git a/deepaas/api/v2/responses.py b/deepaas/api/v2/responses.py index 9bbff1de..6c35ad37 100644 --- a/deepaas/api/v2/responses.py +++ b/deepaas/api/v2/responses.py @@ -14,26 +14,23 @@ # License for the specific language governing permissions and limitations # under the License. +import typing + import marshmallow from marshmallow import fields from marshmallow import validate +import pydantic -class Location(marshmallow.Schema): - rel = fields.Str(required=True) - href = fields.Url(required=True) - type = fields.Str(required=True) - - -class Version(marshmallow.Schema): - version = fields.Str(required="True") - id = fields.Str(required="True") - links = fields.Nested(Location) - type = fields.Str() +# class Version(marshmallow.Schema): +# version = fields.Str(required="True") +# id = fields.Str(required="True") +# # links = fields.Nested(Location) +# type = fields.Str() -class Versions(marshmallow.Schema): - versions = fields.List(fields.Nested(Version)) +# class Versions(marshmallow.Schema): +# versions = fields.List(fields.Nested(Version)) class Failure(marshmallow.Schema): @@ -45,17 +42,6 @@ class Prediction(marshmallow.Schema): predictions = fields.Str(required=True, description="String containing predictions") -class ModelMeta(marshmallow.Schema): - id = fields.Str(required=True, description="Model identifier") # noqa - name = fields.Str(required=True, description="Model name") - description = fields.Str(required=True, description="Model description") - license = fields.Str(required=False, description="Model license") - author = fields.Str(required=False, description="Model author") - version = fields.Str(required=False, description="Model version") - url = fields.Str(required=False, description="Model url") - links = fields.List(fields.Nested(Location)) - - class Training(marshmallow.Schema): uuid = fields.UUID(required=True, description="Training identifier") date = fields.DateTime(required=True, description="Training start time") @@ -70,3 +56,57 @@ class Training(marshmallow.Schema): class TrainingList(marshmallow.Schema): trainings = fields.List(fields.Nested(Training)) + + +# Pydantic models for the API + + +class Location(pydantic.BaseModel): + rel: str + href: pydantic.AnyHttpUrl + type: str = "application/json" + + +class ModelMeta(pydantic.BaseModel): + """"V2 model metadata. + + This class is used to represent the metadata of a model in the V2 API, as we were + doing in previous versions. + """ + id: str = pydantic.Field(..., description="Model identifier") # noqa + name: str = pydantic.Field(..., description="Model name") + description: typing.Optional[str] = pydantic.Field( + description="Model description", + default=None + ) + summary: typing.Optional[str] = pydantic.Field( + description="Model summary", + default=None + ) + license: typing.Optional[str] = pydantic.Field( + description="Model license", + default=None + ) + author: typing.Optional[str] = pydantic.Field( + description="Model author", + default=None + ) + version: typing.Optional[str] = pydantic.Field( + description="Model version", + default=None + ) + url: typing.Optional[str] = pydantic.Field( + description="Model url", + default=None + ) + # Links can be alist of Locations, or an empty list + links: typing.List[Location] = pydantic.Field( + description="Model links", + ) + + +class ModelList(pydantic.BaseModel): + models: typing.List[ModelMeta] = pydantic.Field( + ..., + description="List of loaded models" + )