-
Notifications
You must be signed in to change notification settings - Fork 159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(BA-759): Revamp api handler #3714
base: main
Are you sure you want to change the base?
Changes from all commits
3be99e7
6dc6470
89e95d0
94c2c9c
79aff67
00e8db9
32357ed
4ee5524
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,25 +26,32 @@ | |
from pydantic_core._pydantic_core import ValidationError | ||
|
||
from .exception import ( | ||
InvalidAPIHandlerDefinition, | ||
InvalidAPIParameters, | ||
MalformedRequestBody, | ||
MiddlewareParamParsingFailed, | ||
ParameterNotParsedError, | ||
) | ||
|
||
T = TypeVar("T", bound=BaseModel) | ||
TModel = TypeVar("TModel", bound=BaseModel) | ||
|
||
|
||
class BodyParam(Generic[T]): | ||
_model: Type[T] | ||
_parsed: Optional[T] | ||
class _HTTPRequestDataValidator: | ||
@abstractmethod | ||
async def validate_request(self, request: web.Request) -> Self: | ||
pass | ||
|
||
|
||
class BodyParam(_HTTPRequestDataValidator, Generic[TModel]): | ||
_model: Type[TModel] | ||
_parsed: Optional[TModel] | ||
|
||
def __init__(self, model: Type[T]) -> None: | ||
def __init__(self, model: Type[TModel]) -> None: | ||
self._model = model | ||
self._parsed: Optional[T] = None | ||
self._parsed: Optional[TModel] = None | ||
|
||
@property | ||
def parsed(self) -> T: | ||
def parsed(self) -> TModel: | ||
if not self._parsed: | ||
raise ParameterNotParsedError( | ||
f"Parameter of type {self._model.__name__} has not been parsed yet" | ||
|
@@ -55,17 +62,30 @@ def from_body(self, json_body: str) -> Self: | |
self._parsed = self._model.model_validate(json_body) | ||
return self | ||
|
||
async def validate_request(self, request: web.Request) -> Self: | ||
if not request.can_read_body: | ||
raise MalformedRequestBody( | ||
f"Malformed body - URL: {request.url}, Method: {request.method}" | ||
) | ||
try: | ||
body = await request.json() | ||
except json.decoder.JSONDecodeError as e: | ||
raise MalformedRequestBody( | ||
f"Malformed body - URL: {request.url}, Method: {request.method}, error: {repr(e)}" | ||
) | ||
return self.from_body(body) | ||
|
||
class QueryParam(Generic[T]): | ||
_model: Type[T] | ||
_parsed: Optional[T] | ||
|
||
def __init__(self, model: Type[T]) -> None: | ||
class QueryParam(_HTTPRequestDataValidator, Generic[TModel]): | ||
_model: Type[TModel] | ||
_parsed: Optional[TModel] | ||
|
||
def __init__(self, model: Type[TModel]) -> None: | ||
self._model = model | ||
self._parsed: Optional[T] = None | ||
self._parsed: Optional[TModel] = None | ||
|
||
@property | ||
def parsed(self) -> T: | ||
def parsed(self) -> TModel: | ||
if not self._parsed: | ||
raise ParameterNotParsedError( | ||
f"Parameter of type {self._model.__name__} has not been parsed yet" | ||
|
@@ -76,17 +96,20 @@ def from_query(self, query: MultiMapping[str]) -> Self: | |
self._parsed = self._model.model_validate(query) | ||
return self | ||
|
||
async def validate_request(self, request: web.Request) -> Self: | ||
return self.from_query(request.query) | ||
|
||
|
||
class HeaderParam(Generic[T]): | ||
_model: Type[T] | ||
_parsed: Optional[T] | ||
class HeaderParam(_HTTPRequestDataValidator, Generic[TModel]): | ||
_model: Type[TModel] | ||
_parsed: Optional[TModel] | ||
|
||
def __init__(self, model: Type[T]) -> None: | ||
def __init__(self, model: Type[TModel]) -> None: | ||
self._model = model | ||
self._parsed: Optional[T] = None | ||
self._parsed: Optional[TModel] = None | ||
|
||
@property | ||
def parsed(self) -> T: | ||
def parsed(self) -> TModel: | ||
if not self._parsed: | ||
raise ParameterNotParsedError( | ||
f"Parameter of type {self._model.__name__} has not been parsed yet" | ||
|
@@ -97,17 +120,20 @@ def from_header(self, headers: CIMultiDictProxy[str]) -> Self: | |
self._parsed = self._model.model_validate(headers) | ||
return self | ||
|
||
async def validate_request(self, request: web.Request) -> Self: | ||
return self.from_header(request.headers) | ||
|
||
class PathParam(Generic[T]): | ||
_model: Type[T] | ||
_parsed: Optional[T] | ||
|
||
def __init__(self, model: Type[T]) -> None: | ||
class PathParam(_HTTPRequestDataValidator, Generic[TModel]): | ||
_model: Type[TModel] | ||
_parsed: Optional[TModel] | ||
|
||
def __init__(self, model: Type[TModel]) -> None: | ||
self._model = model | ||
self._parsed: Optional[T] = None | ||
self._parsed: Optional[TModel] = None | ||
|
||
@property | ||
def parsed(self) -> T: | ||
def parsed(self) -> TModel: | ||
if not self._parsed: | ||
raise ParameterNotParsedError( | ||
f"Parameter of type {self._model.__name__} has not been parsed yet" | ||
|
@@ -118,6 +144,9 @@ def from_path(self, match_info: UrlMappingMatchInfo) -> Self: | |
self._parsed = self._model.model_validate(match_info) | ||
return self | ||
|
||
async def validate_request(self, request: web.Request) -> Self: | ||
return self.from_path(request.match_info) | ||
|
||
|
||
class MiddlewareParam(ABC, BaseModel): | ||
@classmethod | ||
|
@@ -160,6 +189,7 @@ def status_code(self) -> int: | |
|
||
|
||
_ParamType: TypeAlias = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam | ||
_ValidatorType: TypeAlias = BodyParam | QueryParam | PathParam | HeaderParam | type[MiddlewareParam] | ||
|
||
|
||
async def _extract_param_value(request: web.Request, input_param_type: Any) -> _ParamType: | ||
|
@@ -221,18 +251,15 @@ def get_all(self) -> dict[str, _ParamType]: | |
return self._params | ||
|
||
|
||
HandlerT = TypeVar("HandlerT") | ||
HandlerReturn = Awaitable[APIResponse] | Coroutine[Any, Any, APIResponse] | ||
|
||
ResponseType = web.Response | APIResponse | ||
AwaitableResponse = Awaitable[ResponseType] | Coroutine[Any, Any, ResponseType] | ||
|
||
BaseHandler: TypeAlias = Callable[..., AwaitableResponse] | ||
ParsedRequestHandler: TypeAlias = Callable[..., Awaitable[web.Response]] | ||
BaseHandler: TypeAlias = Callable[..., HandlerReturn] | ||
ParsedRequestHandler: TypeAlias = Callable[..., Awaitable[web.StreamResponse]] | ||
|
||
|
||
async def _parse_and_execute_handler( | ||
request: web.Request, handler: BaseHandler, signature: Signature | ||
) -> web.Response: | ||
) -> web.StreamResponse: | ||
handler_params = _HandlerParameters() | ||
for name, param in signature.parameters.items(): | ||
# If handler has no parameter, for loop is skipped | ||
|
@@ -253,13 +280,71 @@ async def _parse_and_execute_handler( | |
|
||
response = await handler(**handler_params.get_all()) | ||
|
||
if not isinstance(response, APIResponse): | ||
raise InvalidAPIParameters( | ||
f"Only Response wrapped by APIResponse Class can be handle: {type(response)}" | ||
if isinstance(response, APIResponse): | ||
return web.json_response( | ||
response.to_json, | ||
status=response.status_code, | ||
) | ||
return response | ||
|
||
|
||
def _register_parameter_validator( | ||
signature: Signature, | ||
*, | ||
handler_name: str, | ||
) -> dict[str, _ValidatorType]: | ||
signature_validator_map: dict[str, _ValidatorType] = {} | ||
for name, param in signature.parameters.items(): | ||
if param.annotation is inspect.Parameter.empty: | ||
raise InvalidAPIHandlerDefinition( | ||
f"Not allowed signature for API handler function (handler:{handler_name},name:{name})" | ||
) | ||
param_type = param.annotation | ||
original_type = get_origin(param_type) | ||
if original_type is None: | ||
if issubclass(param_type, MiddlewareParam): | ||
signature_validator_map[name] = param_type | ||
continue | ||
else: | ||
raise InvalidAPIHandlerDefinition( | ||
f"Not allowed signature for API handler function. (handler:{handler_name}, name:{name}, type:{original_type})" | ||
) | ||
model_args = get_args(param_type) | ||
try: | ||
validation_model = model_args[0] | ||
except IndexError: | ||
raise InvalidAPIHandlerDefinition( | ||
f"API parameter model got no argument (handler:{handler_name}, name:{name}, type:{original_type})" | ||
) | ||
|
||
param_instance = param_type(validation_model) | ||
signature_validator_map[name] = param_instance | ||
return signature_validator_map | ||
|
||
|
||
async def _serialize_parameter( | ||
request: web.Request, param_instance_or_class: _ValidatorType | ||
) -> _ParamType: | ||
param_instance: _ParamType | ||
match param_instance_or_class: | ||
case PathParam() | BodyParam() | HeaderParam() | QueryParam(): | ||
Comment on lines
+329
to
+330
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to ask why you changed from comparing classes to comparing instances? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is recommended to use pattern matching rather than using multiple |
||
try: | ||
param_instance = await param_instance_or_class.validate_request(request) | ||
except ValidationError as e: | ||
raise InvalidAPIParameters(str(e)) | ||
case _: | ||
try: | ||
param_instance = param_instance_or_class.from_request(request) | ||
except ValidationError as e: | ||
raise MiddlewareParamParsingFailed( | ||
f"Failed while parsing {param_instance_or_class}. (error:{repr(e)})" | ||
) | ||
return param_instance | ||
|
||
|
||
def _parse_response(response: APIResponse) -> web.Response: | ||
return web.json_response( | ||
response.to_json, | ||
data=response.to_json, | ||
status=response.status_code, | ||
) | ||
|
||
|
@@ -337,16 +422,20 @@ async def handler( | |
|
||
original_signature: Signature = inspect.signature(handler) | ||
|
||
sanitized_signature = original_signature.replace( | ||
parameters=list(original_signature.parameters.values())[1:] | ||
) | ||
signature_validator_map = _register_parameter_validator( | ||
sanitized_signature, handler_name=str(handler) | ||
) | ||
|
||
@functools.wraps(handler) | ||
async def wrapped(first_arg: Any, *args, **kwargs) -> web.Response: | ||
instance = first_arg | ||
sanitized_signature = original_signature.replace( | ||
parameters=list(original_signature.parameters.values())[1:] | ||
) | ||
return await _parse_and_execute_handler( | ||
request=args[0], | ||
handler=lambda *a, **kw: handler(instance, *a, **kw), | ||
signature=sanitized_signature, | ||
) | ||
async def wrapped(first_arg: Any, request: web.Request) -> web.Response: | ||
kwargs: dict[str, _ParamType] = {} | ||
for name, param_instance_or_class in signature_validator_map.items(): | ||
param_instance = await _serialize_parameter(request, param_instance_or_class) | ||
kwargs[name] = param_instance | ||
response = await handler(first_arg, **kwargs) | ||
return _parse_response(response) | ||
|
||
return wrapped |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a question, what is the difference between two?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can initialize all
...Param
types before evaluating request data exceptMiddlewareParam
, which requires input data for initializing._ValidatorType
represents the types that can evaluate and validate input data.