Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-759): Revamp api handler #3714

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 135 additions & 46 deletions src/ai/backend/common/api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,32 @@
from pydantic_core._pydantic_core import ValidationError

from .exception import (
InvalidAPIHandlerDefinition,
InvalidAPIParameters,
MalformedRequestBody,
MiddlewareParamParsingFailed,
ParameterNotParsedError,
)

T = TypeVar("T", bound=BaseModel)
TModel = TypeVar("TModel", bound=BaseModel)


class BodyParam(Generic[T]):
_model: Type[T]
_parsed: Optional[T]
class _HTTPRequestDataValidator:
@abstractmethod
async def validate_request(self, request: web.Request) -> Self:
pass


class BodyParam(_HTTPRequestDataValidator, Generic[TModel]):
_model: Type[TModel]
_parsed: Optional[TModel]

def __init__(self, model: Type[T]) -> None:
def __init__(self, model: Type[TModel]) -> None:
self._model = model
self._parsed: Optional[T] = None
self._parsed: Optional[TModel] = None

@property
def parsed(self) -> T:
def parsed(self) -> TModel:
if not self._parsed:
raise ParameterNotParsedError(
f"Parameter of type {self._model.__name__} has not been parsed yet"
Expand All @@ -55,17 +62,30 @@ def from_body(self, json_body: str) -> Self:
self._parsed = self._model.model_validate(json_body)
return self

async def validate_request(self, request: web.Request) -> Self:
if not request.can_read_body:
raise MalformedRequestBody(
f"Malformed body - URL: {request.url}, Method: {request.method}"
)
try:
body = await request.json()
except json.decoder.JSONDecodeError as e:
raise MalformedRequestBody(
f"Malformed body - URL: {request.url}, Method: {request.method}, error: {repr(e)}"
)
return self.from_body(body)

class QueryParam(Generic[T]):
_model: Type[T]
_parsed: Optional[T]

def __init__(self, model: Type[T]) -> None:
class QueryParam(_HTTPRequestDataValidator, Generic[TModel]):
_model: Type[TModel]
_parsed: Optional[TModel]

def __init__(self, model: Type[TModel]) -> None:
self._model = model
self._parsed: Optional[T] = None
self._parsed: Optional[TModel] = None

@property
def parsed(self) -> T:
def parsed(self) -> TModel:
if not self._parsed:
raise ParameterNotParsedError(
f"Parameter of type {self._model.__name__} has not been parsed yet"
Expand All @@ -76,17 +96,20 @@ def from_query(self, query: MultiMapping[str]) -> Self:
self._parsed = self._model.model_validate(query)
return self

async def validate_request(self, request: web.Request) -> Self:
return self.from_query(request.query)


class HeaderParam(Generic[T]):
_model: Type[T]
_parsed: Optional[T]
class HeaderParam(_HTTPRequestDataValidator, Generic[TModel]):
_model: Type[TModel]
_parsed: Optional[TModel]

def __init__(self, model: Type[T]) -> None:
def __init__(self, model: Type[TModel]) -> None:
self._model = model
self._parsed: Optional[T] = None
self._parsed: Optional[TModel] = None

@property
def parsed(self) -> T:
def parsed(self) -> TModel:
if not self._parsed:
raise ParameterNotParsedError(
f"Parameter of type {self._model.__name__} has not been parsed yet"
Expand All @@ -97,17 +120,20 @@ def from_header(self, headers: CIMultiDictProxy[str]) -> Self:
self._parsed = self._model.model_validate(headers)
return self

async def validate_request(self, request: web.Request) -> Self:
return self.from_header(request.headers)

class PathParam(Generic[T]):
_model: Type[T]
_parsed: Optional[T]

def __init__(self, model: Type[T]) -> None:
class PathParam(_HTTPRequestDataValidator, Generic[TModel]):
_model: Type[TModel]
_parsed: Optional[TModel]

def __init__(self, model: Type[TModel]) -> None:
self._model = model
self._parsed: Optional[T] = None
self._parsed: Optional[TModel] = None

@property
def parsed(self) -> T:
def parsed(self) -> TModel:
if not self._parsed:
raise ParameterNotParsedError(
f"Parameter of type {self._model.__name__} has not been parsed yet"
Expand All @@ -118,6 +144,9 @@ def from_path(self, match_info: UrlMappingMatchInfo) -> Self:
self._parsed = self._model.model_validate(match_info)
return self

async def validate_request(self, request: web.Request) -> Self:
return self.from_path(request.match_info)


class MiddlewareParam(ABC, BaseModel):
@classmethod
Expand Down Expand Up @@ -160,6 +189,7 @@ def status_code(self) -> int:


_ParamType: TypeAlias = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam
_ValidatorType: TypeAlias = BodyParam | QueryParam | PathParam | HeaderParam | type[MiddlewareParam]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question, what is the difference between two?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can initialize all ...Param types before evaluating request data except MiddlewareParam, which requires input data for initializing. _ValidatorType represents the types that can evaluate and validate input data.



async def _extract_param_value(request: web.Request, input_param_type: Any) -> _ParamType:
Expand Down Expand Up @@ -221,18 +251,15 @@ def get_all(self) -> dict[str, _ParamType]:
return self._params


HandlerT = TypeVar("HandlerT")
HandlerReturn = Awaitable[APIResponse] | Coroutine[Any, Any, APIResponse]

ResponseType = web.Response | APIResponse
AwaitableResponse = Awaitable[ResponseType] | Coroutine[Any, Any, ResponseType]

BaseHandler: TypeAlias = Callable[..., AwaitableResponse]
ParsedRequestHandler: TypeAlias = Callable[..., Awaitable[web.Response]]
BaseHandler: TypeAlias = Callable[..., HandlerReturn]
ParsedRequestHandler: TypeAlias = Callable[..., Awaitable[web.StreamResponse]]


async def _parse_and_execute_handler(
request: web.Request, handler: BaseHandler, signature: Signature
) -> web.Response:
) -> web.StreamResponse:
handler_params = _HandlerParameters()
for name, param in signature.parameters.items():
# If handler has no parameter, for loop is skipped
Expand All @@ -253,13 +280,71 @@ async def _parse_and_execute_handler(

response = await handler(**handler_params.get_all())

if not isinstance(response, APIResponse):
raise InvalidAPIParameters(
f"Only Response wrapped by APIResponse Class can be handle: {type(response)}"
if isinstance(response, APIResponse):
return web.json_response(
response.to_json,
status=response.status_code,
)
return response


def _register_parameter_validator(
signature: Signature,
*,
handler_name: str,
) -> dict[str, _ValidatorType]:
signature_validator_map: dict[str, _ValidatorType] = {}
for name, param in signature.parameters.items():
if param.annotation is inspect.Parameter.empty:
raise InvalidAPIHandlerDefinition(
f"Not allowed signature for API handler function (handler:{handler_name},name:{name})"
)
param_type = param.annotation
original_type = get_origin(param_type)
if original_type is None:
if issubclass(param_type, MiddlewareParam):
signature_validator_map[name] = param_type
continue
else:
raise InvalidAPIHandlerDefinition(
f"Not allowed signature for API handler function. (handler:{handler_name}, name:{name}, type:{original_type})"
)
model_args = get_args(param_type)
try:
validation_model = model_args[0]
except IndexError:
raise InvalidAPIHandlerDefinition(
f"API parameter model got no argument (handler:{handler_name}, name:{name}, type:{original_type})"
)

param_instance = param_type(validation_model)
signature_validator_map[name] = param_instance
return signature_validator_map


async def _serialize_parameter(
request: web.Request, param_instance_or_class: _ValidatorType
) -> _ParamType:
param_instance: _ParamType
match param_instance_or_class:
case PathParam() | BodyParam() | HeaderParam() | QueryParam():
Comment on lines +329 to +330
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to ask why you changed from comparing classes to comparing instances?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is recommended to use pattern matching rather than using multiple isinstance() if it is possible
https://peps.python.org/pep-0636/#adding-a-ui-matching-objects

try:
param_instance = await param_instance_or_class.validate_request(request)
except ValidationError as e:
raise InvalidAPIParameters(str(e))
case _:
try:
param_instance = param_instance_or_class.from_request(request)
except ValidationError as e:
raise MiddlewareParamParsingFailed(
f"Failed while parsing {param_instance_or_class}. (error:{repr(e)})"
)
return param_instance


def _parse_response(response: APIResponse) -> web.Response:
return web.json_response(
response.to_json,
data=response.to_json,
status=response.status_code,
)

Expand Down Expand Up @@ -337,16 +422,20 @@ async def handler(

original_signature: Signature = inspect.signature(handler)

sanitized_signature = original_signature.replace(
parameters=list(original_signature.parameters.values())[1:]
)
signature_validator_map = _register_parameter_validator(
sanitized_signature, handler_name=str(handler)
)

@functools.wraps(handler)
async def wrapped(first_arg: Any, *args, **kwargs) -> web.Response:
instance = first_arg
sanitized_signature = original_signature.replace(
parameters=list(original_signature.parameters.values())[1:]
)
return await _parse_and_execute_handler(
request=args[0],
handler=lambda *a, **kw: handler(instance, *a, **kw),
signature=sanitized_signature,
)
async def wrapped(first_arg: Any, request: web.Request) -> web.Response:
kwargs: dict[str, _ParamType] = {}
for name, param_instance_or_class in signature_validator_map.items():
param_instance = await _serialize_parameter(request, param_instance_or_class)
kwargs[name] = param_instance
response = await handler(first_arg, **kwargs)
return _parse_response(response)

return wrapped
4 changes: 4 additions & 0 deletions src/ai/backend/common/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ def __init__(self, invalid_data: Mapping[str, Any]) -> None:
self.invalid_data = invalid_data


class InvalidAPIHandlerDefinition(Exception):
pass


class UnknownImageReference(ValueError):
"""
Represents an error for invalid/unknown image reference.
Expand Down
Loading