From 05441a6041ebce147a45a9078d1c6fb5346e46cf Mon Sep 17 00:00:00 2001 From: Keming Date: Sat, 7 Oct 2023 14:01:35 +0800 Subject: [PATCH] fix: return original resp when no basemodel found in the resp (#353) Signed-off-by: Keming --- spectree/_pydantic.py | 18 +++++++ spectree/plugins/base.py | 12 +++-- .../test_plugin_spec[falcon][full_spec].json | 52 +++++++++++++++++++ .../test_plugin_spec[flask][full_spec].json | 52 +++++++++++++++++++ ...ugin_spec[flask_blueprint][full_spec].json | 52 +++++++++++++++++++ ...st_plugin_spec[flask_view][full_spec].json | 52 +++++++++++++++++++ ...est_plugin_spec[starlette][full_spec].json | 52 +++++++++++++++++++ tests/common.py | 8 ++- tests/flask_imports/__init__.py | 2 + tests/flask_imports/dry_plugin_flask.py | 6 +++ tests/test_plugin_falcon.py | 19 ++++++- tests/test_plugin_flask.py | 7 +++ tests/test_plugin_flask_blueprint.py | 7 +++ tests/test_plugin_flask_view.py | 11 ++++ tests/test_plugin_starlette.py | 14 ++++- tests/test_pydantic.py | 18 +++++++ 16 files changed, 375 insertions(+), 7 deletions(-) diff --git a/spectree/_pydantic.py b/spectree/_pydantic.py index a1f96579..513f4a69 100644 --- a/spectree/_pydantic.py +++ b/spectree/_pydantic.py @@ -59,6 +59,24 @@ def is_base_model_instance(value: Any) -> bool: return is_base_model(type(value)) +def is_partial_base_model_instance(instance: Any) -> bool: + """Check if it's a Pydantic BaseModel instance or [BaseModel] + or {key: BaseModel} instance. + """ + if not instance: + return False + if is_base_model_instance(instance): + return True + if isinstance(instance, dict): + return any( + is_partial_base_model_instance(key) or is_partial_base_model_instance(value) + for key, value in instance.items() + ) + if isinstance(instance, list) or isinstance(instance, tuple): + return any(is_partial_base_model_instance(value) for value in instance) + return False + + def is_root_model(t: Any) -> bool: """Check whether a type is a Pydantic RootModel.""" return is_base_model(t) and ROOT_FIELD in t.__fields__ diff --git a/spectree/plugins/base.py b/spectree/plugins/base.py index 47f501c7..f3e31f7e 100644 --- a/spectree/plugins/base.py +++ b/spectree/plugins/base.py @@ -12,7 +12,7 @@ Union, ) -from .._pydantic import serialize_model_instance +from .._pydantic import is_partial_base_model_instance, serialize_model_instance from .._types import JsonType, ModelType, OptionalModelType from ..config import Configuration from ..response import Response @@ -162,6 +162,7 @@ def validate_response( skip_validation = True final_response_payload = serialize_model_instance(response_payload) else: + # non-BaseModel response or partial BaseModel response final_response_payload = response_payload if not skip_validation: @@ -170,8 +171,11 @@ def validate_response( if isinstance(final_response_payload, bytes) else validation_model.parse_obj ) - final_response_payload = serialize_model_instance( - validator(final_response_payload) - ) + validated_instance = validator(final_response_payload) + # in case the response model contains (alias, default_none, unset fields) which + # might not be the what the users want, we only return the validated dict when + # the response contains BaseModel + if is_partial_base_model_instance(final_response_payload): + final_response_payload = serialize_model_instance(validated_instance) return ResponseValidationResult(payload=final_response_payload) diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json index 3e521647..af3a7aac 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json @@ -117,6 +117,27 @@ "title": "JSON", "type": "object" }, + "OptionalAliasResp.7068f62": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + }, + "schema": { + "title": "Schema", + "type": "string" + } + }, + "required": [ + "schema" + ], + "title": "OptionalAliasResp", + "type": "object" + }, "Query.7068f62": { "properties": { "order": { @@ -415,6 +436,37 @@ "tags": [] } }, + "/api/return_optional_alias": { + "get": { + "description": "", + "operationId": "get__api_return_optional_alias", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptionalAliasResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "on_get ", + "tags": [] + } + }, "/api/return_root": { "get": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json index d0210257..605b181c 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json @@ -135,6 +135,27 @@ "title": "JSON", "type": "object" }, + "OptionalAliasResp.7068f62": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + }, + "schema": { + "title": "Schema", + "type": "string" + } + }, + "required": [ + "schema" + ], + "title": "OptionalAliasResp", + "type": "object" + }, "Query.7068f62": { "properties": { "order": { @@ -477,6 +498,37 @@ "tags": [] } }, + "/api/return_optional_alias": { + "get": { + "description": "", + "operationId": "get__api_return_optional_alias", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptionalAliasResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_optional_alias_resp ", + "tags": [] + } + }, "/api/return_root": { "get": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json index e44934b1..411485b4 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json @@ -135,6 +135,27 @@ "title": "JSON", "type": "object" }, + "OptionalAliasResp.7068f62": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + }, + "schema": { + "title": "Schema", + "type": "string" + } + }, + "required": [ + "schema" + ], + "title": "OptionalAliasResp", + "type": "object" + }, "Query.7068f62": { "properties": { "order": { @@ -477,6 +498,37 @@ "tags": [] } }, + "/api/return_optional_alias": { + "get": { + "description": "", + "operationId": "get__api_return_optional_alias", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptionalAliasResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_optional_alias ", + "tags": [] + } + }, "/api/return_root": { "get": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json index 09a059f7..fd499d9c 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json @@ -135,6 +135,27 @@ "title": "JSON", "type": "object" }, + "OptionalAliasResp.7068f62": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + }, + "schema": { + "title": "Schema", + "type": "string" + } + }, + "required": [ + "schema" + ], + "title": "OptionalAliasResp", + "type": "object" + }, "Query.7068f62": { "properties": { "order": { @@ -482,6 +503,37 @@ "tags": [] } }, + "/api/return_optional_alias": { + "get": { + "description": "", + "operationId": "get__api_return_optional_alias", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptionalAliasResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "get ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json index aca73a73..f4364744 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json @@ -117,6 +117,27 @@ "title": "JSON", "type": "object" }, + "OptionalAliasResp.7068f62": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + }, + "schema": { + "title": "Schema", + "type": "string" + } + }, + "required": [ + "schema" + ], + "title": "OptionalAliasResp", + "type": "object" + }, "Query.7068f62": { "properties": { "order": { @@ -378,6 +399,37 @@ "tags": [] } }, + "/api/return_optional_alias": { + "get": { + "description": "", + "operationId": "get__api_return_optional_alias", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OptionalAliasResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_optional_alias ", + "tags": [] + } + }, "/api/return_root": { "get": { "description": "", diff --git a/tests/common.py b/tests/common.py index 4290d186..943a70ad 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,7 +1,7 @@ import xml.etree.ElementTree as ET from dataclasses import dataclass from enum import Enum, IntEnum -from typing import Any, Dict, List, Union, cast +from typing import Any, Dict, List, Optional, Union, cast from spectree import BaseFile, ExternalDocs, SecurityScheme, SecuritySchemeData, Tag from spectree._pydantic import BaseModel, Field, root_validator @@ -43,6 +43,12 @@ class StrDict(BaseModel): __root__: Dict[str, str] +class OptionalAliasResp(BaseModel): + alias_schema: str = Field(alias="schema") + name: Optional[str] + limit: Optional[int] = None + + class Resp(BaseModel): name: str score: List[int] diff --git a/tests/flask_imports/__init__.py b/tests/flask_imports/__init__.py index 609c18b8..77c4de3b 100644 --- a/tests/flask_imports/__init__.py +++ b/tests/flask_imports/__init__.py @@ -4,6 +4,7 @@ test_flask_make_response_get, test_flask_make_response_post, test_flask_no_response, + test_flask_optional_alias_response, test_flask_return_list_request, test_flask_return_model, test_flask_skip_validation, @@ -17,6 +18,7 @@ "test_flask_skip_validation", "test_flask_validation_error_response_status_code", "test_flask_doc", + "test_flask_optional_alias_response", "test_flask_validate_post_data", "test_flask_no_response", "test_flask_upload_file", diff --git a/tests/flask_imports/dry_plugin_flask.py b/tests/flask_imports/dry_plugin_flask.py index d14697bc..668d9cea 100644 --- a/tests/flask_imports/dry_plugin_flask.py +++ b/tests/flask_imports/dry_plugin_flask.py @@ -251,3 +251,9 @@ def test_flask_upload_file(client): ) assert resp.status_code == 200, resp.data assert resp.json["content"] == file_content + + +def test_flask_optional_alias_response(client): + resp = client.get("/api/return_optional_alias") + assert resp.status_code == 200 + assert resp.json == {"schema": "test"}, resp.json diff --git a/tests/test_plugin_falcon.py b/tests/test_plugin_falcon.py index 48f81088..10a9118e 100644 --- a/tests/test_plugin_falcon.py +++ b/tests/test_plugin_falcon.py @@ -13,6 +13,7 @@ FormFileUpload, Headers, ListJSON, + OptionalAliasResp, Query, Resp, RootResp, @@ -29,7 +30,8 @@ def before_handler(req, resp, err, instance): def after_handler(req, resp, err, instance): - resp.set_header("X-Name", instance.name) + if hasattr(instance, "name"): + resp.set_header("X-Name", instance.name) api = SpecTree("falcon", before=before_handler, after=after_handler, annotations=True) @@ -229,6 +231,12 @@ def on_get(self, req, resp): ) +class ReturnOptionalAliasView: + @api.validate(resp=Response(HTTP_200=OptionalAliasResp)) + def on_get(self, req, resp): + resp.media = {"schema": "test"} + + class ViewWithCustomSerializer: name = "view with custom serializer" @@ -257,6 +265,7 @@ def on_post(self, req, resp): app.add_route("/api/list_json", ListJsonView()) app.add_route("/api/return_list", ReturnListView()) app.add_route("/api/return_root", ReturnRootView()) +app.add_route("/api/return_optional_alias", ReturnOptionalAliasView()) app.add_route("/api/custom_serializer", ViewWithCustomSerializer()) api.register(app) @@ -530,3 +539,11 @@ def test_falcon_custom_serializer(client): assert resp.status_code == 200 assert resp.json["name"] == "falcon" assert resp.json["score"] == [1, 2, 3] + + +def test_falcon_optional_alias_response(client): + resp = client.simulate_get( + "/api/return_optional_alias", + ) + assert resp.status_code == 200 + assert resp.json == {"schema": "test"}, resp.json diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index 22e57eee..44a23628 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -14,6 +14,7 @@ FormFileUpload, Headers, ListJSON, + OptionalAliasResp, Order, Query, Resp, @@ -230,6 +231,12 @@ def return_root(): ) +@app.route("/api/return_optional_alias", methods=["GET"]) +@api.validate(resp=Response(HTTP_200=OptionalAliasResp)) +def return_optional_alias_resp(): + return {"schema": "test"} + + # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` # without app_context. diff --git a/tests/test_plugin_flask_blueprint.py b/tests/test_plugin_flask_blueprint.py index 64c30e61..2ce44ce1 100644 --- a/tests/test_plugin_flask_blueprint.py +++ b/tests/test_plugin_flask_blueprint.py @@ -14,6 +14,7 @@ FormFileUpload, Headers, ListJSON, + OptionalAliasResp, Order, Query, Resp, @@ -217,6 +218,12 @@ def return_root(): ) +@app.route("/api/return_optional_alias", methods=["GET"]) +@api.validate(resp=Response(HTTP_200=OptionalAliasResp)) +def return_optional_alias(): + return {"schema": "test"} + + api.register(app) flask_app = Flask(__name__) diff --git a/tests/test_plugin_flask_view.py b/tests/test_plugin_flask_view.py index 8746ff6e..5d7924f2 100644 --- a/tests/test_plugin_flask_view.py +++ b/tests/test_plugin_flask_view.py @@ -14,6 +14,7 @@ FormFileUpload, Headers, ListJSON, + OptionalAliasResp, Order, Query, Resp, @@ -235,6 +236,12 @@ def get(self): ) +class ReturnOptionalAlias(MethodView): + @api.validate(resp=Response(HTTP_200=OptionalAliasResp)) + def get(self): + return {"schema": "test"} + + app.add_url_rule("/ping", view_func=Ping.as_view("ping")) app.add_url_rule("/api/user/", view_func=User.as_view("user"), methods=["POST"]) app.add_url_rule( @@ -277,6 +284,10 @@ def get(self): "/api/return_make_response", view_func=ReturnMakeResponseView.as_view("return_make_response"), ) +app.add_url_rule( + "/api/return_optional_alias", + view_func=ReturnOptionalAlias.as_view("return_optional_alias"), +) # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` diff --git a/tests/test_plugin_starlette.py b/tests/test_plugin_starlette.py index 39b3deda..fa4b0268 100644 --- a/tests/test_plugin_starlette.py +++ b/tests/test_plugin_starlette.py @@ -20,6 +20,7 @@ FormFileUpload, Headers, ListJSON, + OptionalAliasResp, Order, Query, Resp, @@ -175,6 +176,11 @@ async def return_root(request): ) +@api.validate(resp=Response(HTTP_200=OptionalAliasResp)) +async def return_optional_alias(request): + return JSONResponse({"schema": "test"}) + + app = Starlette( routes=[ Route("/ping", Ping), @@ -210,6 +216,7 @@ async def return_root(request): Route("/list_json", list_json, methods=["POST"]), Route("/return_list", return_list, methods=["GET"]), Route("/return_root", return_root, methods=["GET"]), + Route("/return_optional_alias", return_optional_alias, methods=["GET"]), ], ), Mount("/static", app=StaticFiles(directory="docs"), name="static"), @@ -433,7 +440,6 @@ def test_starlette_return_list_request(client, pre_serialize: bool): def test_starlette_return_root_request_sync(client, return_what: str): resp = client.get(f"/api/return_root?pre_serialize=0&return_what={return_what}") assert resp.status_code == 200 - assert resp.status_code == 200 if return_what in ("RootResp_JSON", "JSON"): assert resp.json() == {"name": "user1", "limit": 1} elif return_what in ("RootResp_List", "List"): @@ -449,3 +455,9 @@ def test_starlette_upload_file(client): ) assert resp.status_code == 200, resp.data assert resp.json()["file"] == file_content + + +def test_starlette_return_optional_alias(client): + resp = client.get("/api/return_optional_alias") + assert resp.status_code == 200 + assert resp.json() == {"schema": "test"} diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index ae7b595a..c8e26ec3 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -7,6 +7,7 @@ BaseModel, is_base_model, is_base_model_instance, + is_partial_base_model_instance, is_root_model, is_root_model_instance, serialize_model_instance, @@ -126,6 +127,23 @@ def test_is_base_model_instance(value, expected): assert is_base_model_instance(value) is expected +@pytest.mark.parametrize( + "value, expected", + [ + (SimpleModel(user_id=1), True), + ([0, SimpleModel(user_id=1)], True), + ([1, 2, 3], False), + ((0, SimpleModel(user_id=1)), True), + ((0, 1), False), + ({"test": SimpleModel(user_id=1)}, True), + ({"test": [SimpleModel(user_id=1)]}, True), + ([0, [1, SimpleModel(user_id=1)]], True), + ], +) +def test_is_partial_base_model_instance(value, expected): + assert is_partial_base_model_instance(value) is expected, value + + @pytest.mark.parametrize( "value, expected", [