diff --git a/deepaas/api/v2/__init__.py b/deepaas/api/v2/__init__.py index 2460163b..c0106a70 100644 --- a/deepaas/api/v2/__init__.py +++ b/deepaas/api/v2/__init__.py @@ -19,7 +19,7 @@ from deepaas.api.v2 import debug as v2_debug from deepaas.api.v2 import models as v2_model -# from deepaas.api.v2 import predict as v2_predict +from deepaas.api.v2 import predict as v2_predict # from deepaas.api.v2 import responses # from deepaas.api.v2 import train as v2_train from deepaas import log @@ -41,6 +41,7 @@ def get_app(enable_train=True, enable_predict=True): APP.include_router(v2_debug.router, tags=["debug"]) APP.include_router(v2_model.get_router(), tags=["models"]) + APP.include_router(v2_predict.get_router(), tags=["predict"]) # APP.router.add_get("/", get_version, name="v2", allow_head=False) # v2_debug.setup_routes(APP) diff --git a/deepaas/api/v2/predict.py b/deepaas/api/v2/predict.py index 8141469e..b7278faa 100644 --- a/deepaas/api/v2/predict.py +++ b/deepaas/api/v2/predict.py @@ -14,10 +14,14 @@ # License for the specific language governing permissions and limitations # under the License. -from aiohttp import web -import aiohttp_apispec -from webargs import aiohttpparser -import webargs.core +# from aiohttp import web +# import aiohttp_apispec +# from webargs import aiohttpparser +# import webargs.core + +import fastapi +import fastapi.encoders +import fastapi.exceptions from deepaas.api.v2 import responses from deepaas.api.v2 import utils @@ -33,20 +37,26 @@ def _get_model_response(model_name, model_obj): return responses.Prediction -def _get_handler(model_name, model_obj): - aux = model_obj.get_predict_args() - accept = aux.get("accept", None) - if accept: - accept.validate.choices.append("*/*") - accept.load_default = accept.validate.choices[0] - accept.location = "headers" +router = fastapi.APIRouter(prefix="/models") + + +def _get_handler_for_model(model_name, model_obj): + """Auxiliary function to get the handler for a model. + + This function returns a handler for a model that can be used to + register the routes in the router. - handler_args = webargs.core.dict2schema(aux) - handler_args.opts.ordered = True + """ - response = _get_model_response(model_name, model_obj) + user_declared_args = model_obj.get_predict_args() + pydantic_schema = utils.get_pydantic_schema_from_marshmallow_fields( + "PydanticSchema", + user_declared_args, + ) class Handler(object): + """Class to handle the model metadata endpoints.""" + model_name = None model_obj = None @@ -54,47 +64,53 @@ def __init__(self, model_name, model_obj): self.model_name = model_name self.model_obj = model_obj - @aiohttp_apispec.docs( - tags=["models"], - summary="Make a prediction given the input data", - produces=accept.validate.choices if accept else None, - ) - @aiohttp_apispec.querystring_schema(handler_args) - @aiohttp_apispec.response_schema(response(), 200) - @aiohttp_apispec.response_schema(responses.Failure(), 400) - async def post(self, request): - args = await aiohttpparser.parser.parse(handler_args, request) - task = self.model_obj.predict(**args) - await task - - ret = task.result()["output"] + async def predict(self, args: pydantic_schema = fastapi.Depends()): + """Make a prediction given the input data.""" + dict_args = args.model_dump(by_alias=True) + + ret = await self.model_obj.predict(**args.model_dump(by_alias=True)) if isinstance(ret, model.v2.wrapper.ReturnedFile): ret = open(ret.filename, "rb") - accept = args.get("accept", "application/json") - if accept not in ["application/json", "*/*"]: - response = web.Response( - body=ret, - content_type=accept, - ) - return response if self.model_obj.has_schema: - self.model_obj.validate_response(ret) - return web.json_response(ret) + # FIXME(aloga): Validation does not work, as we are converting from + # Marshmallow to Pydantic, check this as son as possible. + # self.model_obj.validate_response(ret) + return fastapi.responses.JSONResponse(ret) + + return fastapi.responses.JSONResponse( + content={"status": "OK", "predictions": ret} + ) - return web.json_response({"status": "OK", "predictions": ret}) + def register_routes(self, router): + """Register the routes in the router.""" + + response = _get_model_response(self.model_name, self.model_obj) + + router.add_api_route( + f"/{self.model_name}/predict", + self.predict, + methods=["POST"], + response_model=response, + ) return Handler(model_name, model_obj) -def setup_routes(app, enable=True): - # In the next lines we iterate over the loaded models and create the - # different resources for each model. This way we can also load the - # expected parameters if needed (as in the training method). - for model_name, model_obj in model.V2_MODELS.items(): - if enable: - hdlr = _get_handler(model_name, model_obj) - else: - hdlr = utils.NotEnabledHandler() - app.router.add_post("/models/%s/predict/" % model_name, hdlr.post) +def get_router(): + """Auxiliary function to get the router. + + We use this function to be able to include the router in the main + application and do things before it gets included. + + In this case we explicitly include the model precit endpoint. + + """ + model_name = model.V2_MODEL_NAME + model_obj = model.V2_MODEL + + hdlr = _get_handler_for_model(model_name, model_obj) + hdlr.register_routes(router) + + return router diff --git a/deepaas/api/v2/responses.py b/deepaas/api/v2/responses.py index 6c35ad37..b2aeb1a6 100644 --- a/deepaas/api/v2/responses.py +++ b/deepaas/api/v2/responses.py @@ -33,15 +33,6 @@ # versions = fields.List(fields.Nested(Version)) -class Failure(marshmallow.Schema): - message = fields.Str(required=True, description="Failure message") - - -class Prediction(marshmallow.Schema): - status = fields.String(required=True, description="Response status message") - predictions = fields.Str(required=True, description="String containing predictions") - - class Training(marshmallow.Schema): uuid = fields.UUID(required=True, description="Training identifier") date = fields.DateTime(required=True, description="Training start time") @@ -110,3 +101,12 @@ class ModelList(pydantic.BaseModel): ..., description="List of loaded models" ) + + +class Prediction(pydantic.BaseModel): + status: str = pydantic.Field(description="Response status message") + predictions: str = pydantic.Field(description="String containing predictions") + + +class Failure(pydantic.BaseModel): + message: str = pydantic.Field(description="Failure message") diff --git a/deepaas/api/v2/utils.py b/deepaas/api/v2/utils.py index b66b5cc1..dadbae84 100644 --- a/deepaas/api/v2/utils.py +++ b/deepaas/api/v2/utils.py @@ -14,7 +14,16 @@ # License for the specific language governing permissions and limitations # under the License. +import datetime +import decimal +import typing + +import fastapi from aiohttp import web +import marshmallow +import marshmallow.fields +import pydantic +import pydantic.utils class NotEnabledHandler(object): @@ -23,3 +32,205 @@ async def f(*args, **kwargs): raise web.HTTPPaymentRequired() return f + + +# Convert marshmallow fields to pydantic fields + + +CUSTOM_FIELD_DEFAULT = typing.Any + + +def get_dict_type(x): + """For dicts we need to look at the key and value type""" + key_type = get_pydantic_type(x.key_field) + if x.value_field: + value_type = get_pydantic_type(x.value_field) + return typing.Dict[key_type, value_type] + return typing.Dict[key_type, typing.Any] + + +def get_list_type(x): + """For lists we need to look at the value type""" + if x.inner: + c_type = get_pydantic_type(x.inner, optional=False) + return typing.List[c_type] + return typing.List + + +# def get_nested_model(x): +# """Return a model from a nested marshmallow schema""" +# return pydantic_from_marshmallow(x.schema) + + +FIELD_CONVERTERS = { + marshmallow.fields.Bool: bool, + marshmallow.fields.Boolean: bool, + marshmallow.fields.Date: datetime.date, + marshmallow.fields.DateTime: datetime.datetime, + marshmallow.fields.Decimal: decimal.Decimal, + marshmallow.fields.Dict: get_dict_type, + marshmallow.fields.Email: pydantic.EmailStr, + marshmallow.fields.Float: float, + marshmallow.fields.Function: typing.Callable, + marshmallow.fields.Int: int, + marshmallow.fields.Integer: int, + marshmallow.fields.List: get_list_type, + marshmallow.fields.Mapping: typing.Mapping, + marshmallow.fields.Method: typing.Callable, + # marshmallow.fields.Nested: get_nested_model, + marshmallow.fields.Number: typing.Union[pydantic.StrictFloat, pydantic.StrictInt], + marshmallow.fields.Str: str, + marshmallow.fields.String: str, + marshmallow.fields.Time: datetime.time, + marshmallow.fields.TimeDelta: datetime.timedelta, + marshmallow.fields.URL: pydantic.AnyUrl, + marshmallow.fields.Url: pydantic.AnyUrl, + marshmallow.fields.UUID: str, +} + + +def is_custom_field(field): + """If this is a subclass of marshmallow's Field and not in our list, we + assume its a custom field""" + ftype = type(field) + if issubclass(ftype, marshmallow.fields.Field) and ftype not in FIELD_CONVERTERS: + print(" Custom field") + return True + return False + + +def is_file_field(field): + """If this is a file field, we need to handle it differently.""" + if field is not None and field.metadata.get("type") == "file": + print(" File field") + return True + return False + + +def get_pydantic_type(field, optional=True): + """Get pydantic type from a marshmallow field""" + if field is None: + return typing.Any + elif is_file_field(field): + conv = fastapi.UploadFile + elif is_custom_field(field): + conv = typing.Any + else: + conv = FIELD_CONVERTERS[type(field)] + + # TODO: Is there a cleaner way to check for annotation types? + if isinstance(conv, type) or conv.__module__ == "typing": + pyd_type = conv + else: + pyd_type = conv(field) + + if optional and not field.required: + if is_file_field(field): + # If we have a file field, do not wrap with Optional, as FastAPI does not + # handle it correctly. Instead, we put None as default value later in the + # outer function. + pass + else: + pyd_type = typing.Optional[pyd_type] + + # FIXME(aloga): we need to handle enums + return pyd_type + + +def sanitize_field_name(field_name): + field_name = field_name.replace("-", "_") + field_name = field_name.replace(" ", "_") + field_name = field_name.replace(".", "_") + field_name = field_name.replace(":", "_") + field_name = field_name.replace("/", "_") + field_name = field_name.replace("\\", "_") + return field_name + + +def check_for_file_fields(fields): + for field_name, field in fields.items(): + if is_file_field(field): + return True + return False + + +def pydantic_from_marshmallow( + name: str, + schema: marshmallow.Schema +) -> pydantic.BaseModel: + """Convert a marshmallow schema to a pydantic model. + + May only work for fairly simple cases. Barely tested. Enjoy. + """ + + pyd_fields = {} + have_file_fields = check_for_file_fields(schema._declared_fields) + + for field_name, field in schema._declared_fields.items(): + pyd_type = get_pydantic_type(field) + + description = field.metadata.get("description") + + if field.default: + default = field.default + elif field.missing: + default = field.missing + else: + default = None + + if is_file_field(field): + field_cls = fastapi.File + elif have_file_fields: + field_cls = fastapi.Form + else: + field_cls = pydantic.Field + + if field.required and not default: + default = field_cls( + ..., + description=description, + title=field_name, + serialization_alias=field_name, + ) + elif default is None: + if is_file_field(field): + # If we have a file field, it is not wraped with Optional, as FastAPI + # does not handle it correctly (c.f. get_pydantic_type function above). + # Instead, we put None as default value here, and FastAPI will handle it + # correctly. + default = None + else: + default = field_cls( + description=description, + title=field_name, + serialization_alias=field_name, + ) + else: + default = field_cls( + description=description, + default=default, + title=field_name, + serialization_alias=field_name, + ) + + field_name = sanitize_field_name(field_name) + + pyd_fields[field_name] = (pyd_type, default) + + ret = pydantic.create_model( + name, + **pyd_fields, + ) + return ret + + +def get_pydantic_schema_from_marshmallow_fields( + name: str, + fields: dict, +) -> pydantic.BaseModel: + + model = marshmallow.Schema.from_dict(fields) + + pydantic_model = pydantic_from_marshmallow(name, model()) + + return pydantic_model diff --git a/deepaas/model/v2/wrapper.py b/deepaas/model/v2/wrapper.py index d4b2cce4..e6d09676 100644 --- a/deepaas/model/v2/wrapper.py +++ b/deepaas/model/v2/wrapper.py @@ -26,6 +26,7 @@ import marshmallow from oslo_config import cfg +from deepaas.api.v2 import utils from deepaas import log LOG = log.getLogger(__name__) @@ -110,6 +111,7 @@ def __init__(self, name, model_obj): self.has_schema = True except Exception as e: LOG.exception(e) + # FIXME(aloga): do not use web exception here raise web.HTTPInternalServerError( reason=("Model defined schema is invalid, " "check server logs.") ) @@ -118,13 +120,21 @@ def __init__(self, name, model_obj): if issubclass(schema, marshmallow.Schema): self.has_schema = True except TypeError: + # FIXME(aloga): do not use web exception here raise web.HTTPInternalServerError( reason=("Model defined schema is invalid, " "check server logs.") ) else: self.has_schema = False - self.response_schema = schema + # Now convert to pydantic schema... + # FIXME(aloga): use try except + if schema is not None: + self.response_schema = utils.pydantic_from_marshmallow( + "ModelPredictionResponse", schema + ) + else: + self.response_schema = None @contextlib.contextmanager def _catch_error(self): @@ -245,7 +255,7 @@ def predict_wrap(predict_func, *args, **kwargs): return ret - def predict(self, *args, **kwargs): + async def predict(self, *args, **kwargs): """Perform a prediction on wrapped model's ``predict`` method. :raises HTTPNotImplemented: If the method is not