Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple security schemes in OpenAPI Client #235

Merged
merged 5 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions fastagency/api/openapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
self.registered_funcs: list[Callable[..., Any]] = []
self.globals: dict[str, Any] = {}

self.security: dict[str, BaseSecurity] = {}
self.security: dict[str, list[BaseSecurity]] = {}
self.security_params: dict[Optional[str], BaseSecurityParameters] = {}

@staticmethod
Expand Down Expand Up @@ -106,18 +106,34 @@
if security is None:
raise ValueError(f"Security is not set for '{name}'")

if not security.accept(security_params):
for match_security in security:
if match_security.accept(security_params):
break
else:
raise ValueError(
f"Security parameters {security_params} do not match security {security}"
)

self.security_params[name] = security_params

def _get_security_params(self, name: str) -> Optional[BaseSecurityParameters]:
def _get_matching_security(
self, security: list[BaseSecurity], security_params: BaseSecurityParameters
) -> BaseSecurity:
# check if security matches security parameters
for match_security in security:
if match_security.accept(security_params):
return match_security
raise ValueError(
f"Security parameters {security_params} does not match any given security {security}"
)

def _get_security_params(
self, name: str
) -> tuple[Optional[BaseSecurityParameters], Optional[BaseSecurity]]:
# check if security is set for the method
security = self.security.get(name)
if security is None:
return None
if not security:
return None, None

Check warning on line 136 in fastagency/api/openapi/client.py

View check run for this annotation

Codecov / codecov/patch

fastagency/api/openapi/client.py#L136

Added line #L136 was not covered by tests

security_params = self.security_params.get(name)
if security_params is None:
Expand All @@ -128,20 +144,16 @@
f"Security parameters are not set for {name} and there are no default security parameters"
)

# check if security matches security parameters
if not security.accept(security_params):
raise ValueError(
f"Security parameters {security_params} do not match security {security}"
)
match_security = self._get_matching_security(security, security_params)

return security_params
return security_params, match_security

def _request(
self,
method: Literal["put", "get", "post", "delete"],
path: str,
description: Optional[str] = None,
security: Optional[BaseSecurity] = None,
security: Optional[list[BaseSecurity]] = None,
**kwargs: Any,
) -> Callable[..., dict[str, Any]]:
def decorator(func: Callable[..., Any]) -> Callable[..., dict[str, Any]]:
Expand All @@ -156,13 +168,13 @@

security = self.security.get(name)
if security is not None:
security_params = self._get_security_params(name)
security_params, matched_security = self._get_security_params(name)
if security_params is None:
raise ValueError(
f"Security parameters are not set for '{name}'"
)
else:
security_params.apply(params, body_dict, security)
security_params.apply(params, body_dict, matched_security) # type: ignore [arg-type]

response = getattr(requests, method)(url, params=params, **body_dict)
return response.json() # type: ignore [no-any-return]
Expand Down
6 changes: 3 additions & 3 deletions templates/main.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ app = OpenAPI(
, tags={{operation.tags}}
{% endif %}
{% if operation.security %}
{% for security in operation.security %}
, security=[{% for security in operation.security %}
{% for key, value in security.items() %}
, security={{security_parameters[key]}}
{% endfor %}
{{security_parameters[key]}},
{% endfor %}
{% endfor %}]
{% endif %}
)
def {{operation.function_name}}({{operation.snake_case_arguments}}
Expand Down
32 changes: 23 additions & 9 deletions tests/api/openapi/security/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,42 @@

import pytest
import uvicorn
from fastapi import Depends, FastAPI, Query
from fastapi import Depends, FastAPI, HTTPException, Query, status
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from fastapi.security import APIKeyCookie, APIKeyHeader

from ....conftest import Server, find_free_port


def create_secure_fastapi_app(host: str, port: int) -> FastAPI:
app = FastAPI(servers=[{"url": f"http://{host}:{port}"}])

header_scheme = APIKeyHeader(name="x-key")

universal_api_key = "super secret key" # pragma: allowlist secret
api_key = "super secret key" # pragma: allowlist secret
api_key_name = "access_token" # pragma: allowlist secret

header_scheme = APIKeyHeader(name=api_key_name, auto_error=False)
cookie_scheme = APIKeyCookie(name=api_key_name, auto_error=False)

async def get_api_key(
api_key_header: str = Depends(header_scheme),
api_key_cookie: str = Depends(cookie_scheme),
) -> str:
if api_key_header == api_key:
return api_key_header
elif api_key_cookie == api_key:
return api_key_cookie
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or missing API Key",
)

@app.get("/items/")
async def read_items(
city: Annotated[str, Query(description="city for which forecast is requested")],
key: str = Depends(header_scheme),
api_key: str = Depends(get_api_key),
) -> JSONResponse:
is_authenticated = key == universal_api_key
content = {"is_authenticated": is_authenticated}
status_code = 200 if is_authenticated else 403
content = {"api_key": api_key}
status_code = 200
return JSONResponse(status_code=status_code, content=content)

return app
Expand Down
7 changes: 5 additions & 2 deletions tests/api/openapi/security/expected_main_gen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ from typing import *
from typing import Any, Union

from fastagency.api.openapi import OpenAPI
from fastagency.api.openapi.security import APIKeyHeader
from fastagency.api.openapi.security import APIKeyCookie, APIKeyHeader

from models_gen import HTTPValidationError

Expand All @@ -23,7 +23,10 @@ app = OpenAPI(
'/items/',
response_model=Any,
responses={'422': {'model': HTTPValidationError}},
security=APIKeyHeader(name="x-key"),
security=[
APIKeyHeader(name="access_token"),
APIKeyCookie(name="access_token"),
],
)
def read_items_items__get(
city: Annotated[str, """city for which forecast is requested"""]
Expand Down
10 changes: 9 additions & 1 deletion tests/api/openapi/security/expected_openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
"security": [
{
"APIKeyHeader": []
},
{
"APIKeyCookie": []
}
],
"parameters": [
Expand Down Expand Up @@ -108,7 +111,12 @@
"APIKeyHeader": {
"type": "apiKey",
"in": "header",
"name": "x-key"
"name": "access_token"
},
"APIKeyCookie": {
"type": "apiKey",
"in": "cookie",
"name": "access_token"
}
}
}
Expand Down
62 changes: 57 additions & 5 deletions tests/api/openapi/security/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import tempfile
from pathlib import Path

import pytest
import requests

from fastagency.api.openapi import OpenAPI
from fastagency.api.openapi.security import APIKeyHeader
from fastagency.api.openapi.security import APIKeyCookie, APIKeyHeader, APIKeyQuery


def test_secure_app_openapi(secure_fastapi_url: str) -> None:
Expand Down Expand Up @@ -56,6 +57,7 @@ def test_generate_client(secure_fastapi_url: str) -> None:
with expected_models_gen_path.open() as f:
expected_models_gen = f.readlines()[4:]

# print(actual_main_gen_txt)
assert actual_main_gen_txt == expected_main_gen_txt
assert actual_models_gen == expected_models_gen

Expand All @@ -81,15 +83,65 @@ def test_import_and_call_generate_client(secure_fastapi_url: str) -> None:

assert generated_client_app.security != {}, generated_client_app.security

api_key = "super secret key" # pragma: allowlist secret

# set global security params for all methods
# generated_client_app.set_security_params(APIKeyHeader.Parameters(value="super secret key"))
# generated_client_app.set_security_params(APIKeyHeader.Parameters(value=api_key))

# or set security params for a specific method
generated_client_app.set_security_params(
APIKeyHeader.Parameters(value="super secret key"),
APIKeyHeader.Parameters(value=api_key),
"read_items_items__get",
)

# no security params added to the signature of the method
client_resp = read_items_items__get(city="New York")
assert client_resp == {"is_authenticated": True}
assert client_resp == {"api_key": api_key}

# Test with cookie security
generated_client_app.set_security_params(
APIKeyCookie.Parameters(value=api_key),
"read_items_items__get",
)
client_resp = read_items_items__get(city="New York")
assert client_resp == {"api_key": api_key}


def test__get_matching_security(secure_fastapi_url: str) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
td = Path(temp_dir) / "gen"

resp = requests.get(f"{secure_fastapi_url}/openapi.json")
assert resp.status_code == 200
openapi_json = resp.json()

main_name = OpenAPI.generate_code(
input_text=json.dumps(openapi_json),
output_dir=td,
)
assert main_name == "main_gen"

sys.path.insert(1, str(td))
from main_gen import app as generated_client_app

api_key_header = APIKeyHeader(name="access_token")
api_key_cookie = APIKeyCookie(name="access_token")
security = [
api_key_header,
api_key_cookie,
]
security_params = APIKeyHeader.Parameters(value="super secret key")
actual_matching_security = generated_client_app._get_matching_security(
security, security_params
)

assert actual_matching_security == api_key_header

with pytest.raises(ValueError) as e: # noqa: PT011
generated_client_app._get_matching_security(
security, APIKeyQuery.Parameters(value="super secret key")
)

assert (
str(e.value)
== f"Security parameters {security_params} does not match any given security {security}"
)