diff --git a/fastagency/api/openapi/client.py b/fastagency/api/openapi/client.py index 3a0f41827..80251827c 100644 --- a/fastagency/api/openapi/client.py +++ b/fastagency/api/openapi/client.py @@ -48,7 +48,7 @@ def __init__( self.registered_funcs: list[Callable[..., Any]] = [] self.globals: dict[str, Any] = {} - self.security: dict[str, BaseSecurity] = {} + self.security: dict[str, list[BaseSecurity]] = {} self.security_params: dict[Optional[str], BaseSecurityParameters] = {} @staticmethod @@ -106,18 +106,34 @@ def set_security_params( if security is None: raise ValueError(f"Security is not set for '{name}'") - if not security.accept(security_params): + for match_security in security: + if match_security.accept(security_params): + break + else: raise ValueError( f"Security parameters {security_params} do not match security {security}" ) self.security_params[name] = security_params - def _get_security_params(self, name: str) -> Optional[BaseSecurityParameters]: + def _get_matching_security( + self, security: list[BaseSecurity], security_params: BaseSecurityParameters + ) -> BaseSecurity: + # check if security matches security parameters + for match_security in security: + if match_security.accept(security_params): + return match_security + raise ValueError( + f"Security parameters {security_params} does not match any given security {security}" + ) + + def _get_security_params( + self, name: str + ) -> tuple[Optional[BaseSecurityParameters], Optional[BaseSecurity]]: # check if security is set for the method security = self.security.get(name) - if security is None: - return None + if not security: + return None, None security_params = self.security_params.get(name) if security_params is None: @@ -128,20 +144,16 @@ def _get_security_params(self, name: str) -> Optional[BaseSecurityParameters]: f"Security parameters are not set for {name} and there are no default security parameters" ) - # check if security matches security parameters - if not security.accept(security_params): - raise ValueError( - f"Security parameters {security_params} do not match security {security}" - ) + match_security = self._get_matching_security(security, security_params) - return security_params + return security_params, match_security def _request( self, method: Literal["put", "get", "post", "delete"], path: str, description: Optional[str] = None, - security: Optional[BaseSecurity] = None, + security: Optional[list[BaseSecurity]] = None, **kwargs: Any, ) -> Callable[..., dict[str, Any]]: def decorator(func: Callable[..., Any]) -> Callable[..., dict[str, Any]]: @@ -156,13 +168,13 @@ def wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]: security = self.security.get(name) if security is not None: - security_params = self._get_security_params(name) + security_params, matched_security = self._get_security_params(name) if security_params is None: raise ValueError( f"Security parameters are not set for '{name}'" ) else: - security_params.apply(params, body_dict, security) + security_params.apply(params, body_dict, matched_security) # type: ignore [arg-type] response = getattr(requests, method)(url, params=params, **body_dict) return response.json() # type: ignore [no-any-return] diff --git a/templates/main.jinja2 b/templates/main.jinja2 index 74462f6c6..aaa052df3 100644 --- a/templates/main.jinja2 +++ b/templates/main.jinja2 @@ -38,11 +38,11 @@ app = OpenAPI( , tags={{operation.tags}} {% endif %} {% if operation.security %} - {% for security in operation.security %} + , security=[{% for security in operation.security %} {% for key, value in security.items() %} - , security={{security_parameters[key]}} - {% endfor %} + {{security_parameters[key]}}, {% endfor %} + {% endfor %}] {% endif %} ) def {{operation.function_name}}({{operation.snake_case_arguments}} diff --git a/tests/api/openapi/security/conftest.py b/tests/api/openapi/security/conftest.py index a32d432aa..804e60abb 100644 --- a/tests/api/openapi/security/conftest.py +++ b/tests/api/openapi/security/conftest.py @@ -5,9 +5,9 @@ import pytest import uvicorn -from fastapi import Depends, FastAPI, Query +from fastapi import Depends, FastAPI, HTTPException, Query, status from fastapi.responses import JSONResponse -from fastapi.security import APIKeyHeader +from fastapi.security import APIKeyCookie, APIKeyHeader from ....conftest import Server, find_free_port @@ -15,18 +15,32 @@ def create_secure_fastapi_app(host: str, port: int) -> FastAPI: app = FastAPI(servers=[{"url": f"http://{host}:{port}"}]) - header_scheme = APIKeyHeader(name="x-key") - - universal_api_key = "super secret key" # pragma: allowlist secret + api_key = "super secret key" # pragma: allowlist secret + api_key_name = "access_token" # pragma: allowlist secret + + header_scheme = APIKeyHeader(name=api_key_name, auto_error=False) + cookie_scheme = APIKeyCookie(name=api_key_name, auto_error=False) + + async def get_api_key( + api_key_header: str = Depends(header_scheme), + api_key_cookie: str = Depends(cookie_scheme), + ) -> str: + if api_key_header == api_key: + return api_key_header + elif api_key_cookie == api_key: + return api_key_cookie + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing API Key", + ) @app.get("/items/") async def read_items( city: Annotated[str, Query(description="city for which forecast is requested")], - key: str = Depends(header_scheme), + api_key: str = Depends(get_api_key), ) -> JSONResponse: - is_authenticated = key == universal_api_key - content = {"is_authenticated": is_authenticated} - status_code = 200 if is_authenticated else 403 + content = {"api_key": api_key} + status_code = 200 return JSONResponse(status_code=status_code, content=content) return app diff --git a/tests/api/openapi/security/expected_main_gen.txt b/tests/api/openapi/security/expected_main_gen.txt index c288bd22a..14bedd60b 100644 --- a/tests/api/openapi/security/expected_main_gen.txt +++ b/tests/api/openapi/security/expected_main_gen.txt @@ -8,7 +8,7 @@ from typing import * from typing import Any, Union from fastagency.api.openapi import OpenAPI -from fastagency.api.openapi.security import APIKeyHeader +from fastagency.api.openapi.security import APIKeyCookie, APIKeyHeader from models_gen import HTTPValidationError @@ -23,7 +23,10 @@ app = OpenAPI( '/items/', response_model=Any, responses={'422': {'model': HTTPValidationError}}, - security=APIKeyHeader(name="x-key"), + security=[ + APIKeyHeader(name="access_token"), + APIKeyCookie(name="access_token"), + ], ) def read_items_items__get( city: Annotated[str, """city for which forecast is requested"""] diff --git a/tests/api/openapi/security/expected_openapi.json b/tests/api/openapi/security/expected_openapi.json index 6c25bd7a6..2bc550c34 100644 --- a/tests/api/openapi/security/expected_openapi.json +++ b/tests/api/openapi/security/expected_openapi.json @@ -17,6 +17,9 @@ "security": [ { "APIKeyHeader": [] + }, + { + "APIKeyCookie": [] } ], "parameters": [ @@ -108,7 +111,12 @@ "APIKeyHeader": { "type": "apiKey", "in": "header", - "name": "x-key" + "name": "access_token" + }, + "APIKeyCookie": { + "type": "apiKey", + "in": "cookie", + "name": "access_token" } } } diff --git a/tests/api/openapi/security/test_security.py b/tests/api/openapi/security/test_security.py index 0f099a68c..732839554 100644 --- a/tests/api/openapi/security/test_security.py +++ b/tests/api/openapi/security/test_security.py @@ -3,10 +3,11 @@ import tempfile from pathlib import Path +import pytest import requests from fastagency.api.openapi import OpenAPI -from fastagency.api.openapi.security import APIKeyHeader +from fastagency.api.openapi.security import APIKeyCookie, APIKeyHeader, APIKeyQuery def test_secure_app_openapi(secure_fastapi_url: str) -> None: @@ -56,6 +57,7 @@ def test_generate_client(secure_fastapi_url: str) -> None: with expected_models_gen_path.open() as f: expected_models_gen = f.readlines()[4:] + # print(actual_main_gen_txt) assert actual_main_gen_txt == expected_main_gen_txt assert actual_models_gen == expected_models_gen @@ -81,15 +83,65 @@ def test_import_and_call_generate_client(secure_fastapi_url: str) -> None: assert generated_client_app.security != {}, generated_client_app.security + api_key = "super secret key" # pragma: allowlist secret + # set global security params for all methods - # generated_client_app.set_security_params(APIKeyHeader.Parameters(value="super secret key")) + # generated_client_app.set_security_params(APIKeyHeader.Parameters(value=api_key)) # or set security params for a specific method generated_client_app.set_security_params( - APIKeyHeader.Parameters(value="super secret key"), + APIKeyHeader.Parameters(value=api_key), "read_items_items__get", ) - # no security params added to the signature of the method client_resp = read_items_items__get(city="New York") - assert client_resp == {"is_authenticated": True} + assert client_resp == {"api_key": api_key} + + # Test with cookie security + generated_client_app.set_security_params( + APIKeyCookie.Parameters(value=api_key), + "read_items_items__get", + ) + client_resp = read_items_items__get(city="New York") + assert client_resp == {"api_key": api_key} + + +def test__get_matching_security(secure_fastapi_url: str) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + td = Path(temp_dir) / "gen" + + resp = requests.get(f"{secure_fastapi_url}/openapi.json") + assert resp.status_code == 200 + openapi_json = resp.json() + + main_name = OpenAPI.generate_code( + input_text=json.dumps(openapi_json), + output_dir=td, + ) + assert main_name == "main_gen" + + sys.path.insert(1, str(td)) + from main_gen import app as generated_client_app + + api_key_header = APIKeyHeader(name="access_token") + api_key_cookie = APIKeyCookie(name="access_token") + security = [ + api_key_header, + api_key_cookie, + ] + security_params = APIKeyHeader.Parameters(value="super secret key") + actual_matching_security = generated_client_app._get_matching_security( + security, security_params + ) + + assert actual_matching_security == api_key_header + + with pytest.raises(ValueError) as e: # noqa: PT011 + generated_client_app._get_matching_security( + security, APIKeyQuery.Parameters(value="super secret key") + ) + + assert ( + str(e.value) + == f"Security parameters {security_params} does not match any given security {security}" + )