diff --git a/README.md b/README.md index 3b646bba..3ba62892 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -[![Documentation Status](https://readthedocs.org/projects/cnaas-nms/badge/?version=latest)](https://cnaas-nms.readthedocs.io/en/latest/?badge=latest) [![codecov](https://codecov.io/gh/SUNET/cnaas-nms/branch/master/graph/badge.svg)](https://codecov.io/gh/SUNET/cnaas-nms) [![Python 3.7](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-370/) +[![Documentation Status](https://readthedocs.org/projects/cnaas-nms/badge/?version=latest)](https://cnaas-nms.readthedocs.io/en/latest/?badge=latest) [![codecov](https://codecov.io/gh/SUNET/cnaas-nms/branch/master/graph/badge.svg)](https://codecov.io/gh/SUNET/cnaas-nms) [![Python 3.11](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/release/python-3110/) # CNaaS-NMS @@ -47,7 +47,7 @@ python3 -m cnaas_nms.confpush.tests.test_get ## Authorization -Currently we can use two styles for the authorization. We can use the original style or use OIDC style. For OIDC we need to define some env variables or add a auth_config.yaml in the config. The needed variables are: OIDC_CONF_WELL_KNOWN_URL, OIDC_CLIENT_SECRET, OIDC_CLIENT_ID, FRONTEND_CALLBACK_URL and OIDC_ENABLED. To use the OIDC style the last variable needs to be set to true. +Currently we can use two styles for the authorization. We can use the original style or use OIDC style. For OIDC we need to define some env variables or add a auth_config.yaml in the config. The needed variables are: OIDC_CONF_WELL_KNOWN_URL, OIDC_CLIENT_SECRET, OIDC_CLIENT_ID, FRONTEND_CALLBACK_URL and OIDC_ENABLED. To use the OIDC style the last variable needs to be set to true. ## License diff --git a/docker/api/Dockerfile b/docker/api/Dockerfile index 0c0e4501..b9fdf2aa 100644 --- a/docker/api/Dockerfile +++ b/docker/api/Dockerfile @@ -1,4 +1,4 @@ -FROM debian:buster +FROM debian:bookworm USER root @@ -16,7 +16,7 @@ RUN apt-get update \ libpq-dev \ libssl-dev \ net-tools \ - netcat \ + netcat-traditional \ netcat-openbsd \ nginx \ procps \ @@ -26,11 +26,17 @@ RUN apt-get update \ python3-venv \ python3-wheel \ python3-yaml \ + python3-full \ psmisc \ supervisor \ uwsgi-plugin-python3 \ ssh-client \ - && pip3 install --no-cache-dir uwsgi + python3-gevent\ + + && pip3 install --no-cache-dir uwsgi --break-system-packages + # && apt-get install -y uwsgi + + #RUN pip3 install uwsgi # Prepare for supervisord, ngninx COPY config/supervisord_app.conf /etc/supervisor/supervisord.conf @@ -80,7 +86,6 @@ USER www-data # Prepare for uwsgi COPY --chown=root:www-data config/uwsgi.ini /opt/cnaas/venv/cnaas-nms/ - # Expose HTTPS EXPOSE 1443 diff --git a/docker/api/cnaas-setup.sh b/docker/api/cnaas-setup.sh index 25365363..541cd8fa 100755 --- a/docker/api/cnaas-setup.sh +++ b/docker/api/cnaas-setup.sh @@ -7,7 +7,7 @@ export DEBIAN_FRONTEND noninteractive # Start venv -python3 -m venv /opt/cnaas/venv +python3.11 -m venv /opt/cnaas/venv cd /opt/cnaas/venv/ source bin/activate @@ -21,4 +21,6 @@ git config --add remote.origin.fetch "+refs/pull/*/head:refs/remotes/origin/pr/* git fetch --all git checkout $2 # install dependencies +python3 -m pip install --no-cache-dir uwsgi +python3 -m pip install --no-cache-dir uwsgi gevent python3 -m pip install --no-cache-dir -r requirements.txt diff --git a/docker/api/config/auth_config.yml b/docker/api/config/auth_config.yml index b5f681d5..f0a28429 100644 --- a/docker/api/config/auth_config.yml +++ b/docker/api/config/auth_config.yml @@ -3,3 +3,5 @@ oidc_client_secret: "xxx" oidc_client_id: "client-id" frontend_callback_url: "http://localhost/callback" oidc_enabled: False +audience: "client-id" +verify_audience: True diff --git a/docs/configuration/index.rst b/docs/configuration/index.rst index fc0c340e..b8b8928b 100644 --- a/docs/configuration/index.rst +++ b/docs/configuration/index.rst @@ -53,11 +53,13 @@ Defines parameters for the API: Define parameters for the authentication: -- oidc_conf_well_known_url: set the url for the oidc -- oidc_client_secret: set the secret of the oidc -- oidc_client_id: set the client_id of the oidc -- frontend_callback_url: set the frontend url the oidc client should link to after the login process -- oidc_enabled: set True to enabled the oidc login. Default: False +- oidc_conf_well_known_url: OIDC well-known URL for metadata +- oidc_client_secret: The client secret for OIDC +- oidc_client_id: The client_id for OIDC +- frontend_callback_url: The frontend URL that the OIDC client should redirect to after the login process +- oidc_enabled: Set True to enabled OIDC login. Defaults to False +- audience: The string to verify the aud attribute in the access token with +- verify_audience: Set to False to disable aud check. Defaults to True /etc/cnaas-nms/repository.yml ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/requirements.txt b/requirements.txt index 018203f5..48a70468 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,14 +13,13 @@ mypy-extensions==0.4.3 markupsafe==2.1.1 nornir==3.3.0 nornir-jinja2==0.2.0 -nornir-napalm==0.3.0 +nornir-napalm==0.4.0 nornir-netmiko==0.2.0 nornir-utils==0.2.0 -napalm==4.0.0 +napalm==4.1.0 nose==1.3.7 pluggy==1.0.0 -psycopg2==2.9.3 -psycopg2-binary==2.9.3 +psycopg2-binary==2.9.9 pytest==7.1.3 pytest-cov==3.0.0 pytest-docker-compose==3.2.1 @@ -30,8 +29,10 @@ Sphinx==5.1.1 SQLAlchemy==1.4.41 sqlalchemy-stubs==0.4 SQLAlchemy-Utils==0.38.3 -pydantic==1.10.2 -Werkzeug==2.2.3 +pydantic==2.3.0 +Werkzeug==3.0.1 greenlet==3.0.1 +pyyaml!=6.0.0,!=5.4.0,!=5.4.1 +pydantic_settings==2.1.0 Authlib==1.0.1 python-jose==3.1.0 diff --git a/src/cnaas_nms/api/app.py b/src/cnaas_nms/api/app.py index 9f3d6468..1437fab0 100644 --- a/src/cnaas_nms/api/app.py +++ b/src/cnaas_nms/api/app.py @@ -1,8 +1,11 @@ import os import re import sys - from typing import Optional + +import werkzeug.exceptions +from authlib.integrations.flask_client import OAuth +from authlib.oauth2.rfc6749 import MissingAuthorizationError from engineio.payload import Payload from flask import Flask, jsonify, request from flask_cors import CORS @@ -11,11 +14,9 @@ from flask_restx import Api from flask_socketio import SocketIO, join_room from jwt import decode -from jwt.exceptions import DecodeError, InvalidSignatureError, InvalidTokenError, ExpiredSignatureError, InvalidKeyError -from authlib.integrations.flask_client import OAuth -from authlib.oauth2.rfc6749 import MissingAuthorizationError - +from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidKeyError, InvalidSignatureError, InvalidTokenError +from cnaas_nms.api.auth import api as auth_api from cnaas_nms.api.device import ( device_api, device_cert_api, @@ -28,7 +29,6 @@ device_update_interfaces_api, devices_api, ) -from cnaas_nms.api.auth import api as auth_api from cnaas_nms.api.firmware import api as firmware_api from cnaas_nms.api.groups import api as groups_api from cnaas_nms.api.interface import api as interfaces_api @@ -40,15 +40,11 @@ from cnaas_nms.api.repository import api as repository_api from cnaas_nms.api.settings import api as settings_api from cnaas_nms.api.system import api as system_api - -from cnaas_nms.app_settings import auth_settings -from cnaas_nms.app_settings import api_settings - +from cnaas_nms.app_settings import api_settings, auth_settings from cnaas_nms.tools.log import get_logger from cnaas_nms.tools.security import get_oauth_userinfo from cnaas_nms.version import __api_version__ - logger = get_logger() @@ -74,10 +70,6 @@ def handle_error(self, e): data = {"status": "error", "message": "Invalid authentication header: {}".format(e)} elif isinstance(e, InvalidSignatureError): data = {"status": "error", "message": "Invalid token signature"} - elif isinstance(e, IndexError): - # We might catch IndexErrors which are not caused by JWT, - # but this is better than nothing. - data = {"status": "error", "message": "JWT token missing?"} elif isinstance(e, NoAuthorizationError): data = {"status": "error", "message": "JWT token missing?"} elif isinstance(e, InvalidHeaderError): @@ -89,6 +81,8 @@ def handle_error(self, e): elif isinstance(e, ConnectionError): data = {"status": "error", "message": "ConnectionError: {}".format(e)} return jsonify(data), 500 + elif isinstance(e, werkzeug.exceptions.HTTPException): + data = {"status": "error", "message": "{}".format(e.name)} else: return super(CnaasApi, self).handle_error(e) return jsonify(data), 401 @@ -104,6 +98,7 @@ def handle_error(self, e): client_id=auth_settings.OIDC_CLIENT_ID, client_secret=auth_settings.OIDC_CLIENT_SECRET, client_kwargs={"scope": auth_settings.OIDC_CLIENT_SCOPE}, + authorize_params={"audience": auth_settings.AUDIENCE}, response_type="code", response_mode="query", ) @@ -167,26 +162,27 @@ def handle_error(self, e): api.add_namespace(plugins_api) api.add_namespace(system_api) + # SocketIO on connect @socketio.on("connect") def socketio_on_connect(): # get te token string - token_string = request.args.get('jwt') + token_string = request.args.get("jwt") if not token_string: return False - #if oidc, get userinfo + # if oidc, get userinfo if auth_settings.OIDC_ENABLED: try: - user = get_oauth_userinfo(token_string)['email'] + user = get_oauth_userinfo(token_string)["email"] except InvalidTokenError as e: - logger.debug('InvalidTokenError: ' + format(e)) + logger.debug("InvalidTokenError: " + format(e)) return False # else decode the token and get the sub there else: try: - user = decode(token_string, app.config["JWT_PUBLIC_KEY"], algorithms=[app.config["JWT_ALGORITHM"]])['sub'] + user = decode(token_string, app.config["JWT_PUBLIC_KEY"], algorithms=[app.config["JWT_ALGORITHM"]])["sub"] except DecodeError as e: - logger.debug('DecodeError: ' + format(e)) + logger.debug("DecodeError: " + format(e)) return False if user: @@ -220,7 +216,7 @@ def log_request(response): try: if auth_settings.OIDC_ENABLED: token_string = request.headers.get("Authorization").split(" ")[-1] - user = "User: {}, ".format(get_oauth_userinfo(token_string)['email']) + user = "User: {}, ".format(get_oauth_userinfo(token_string)["email"]) else: token = request.headers.get("Authorization").split(" ")[-1] user = "User: {}, ".format(decode_token(token).get("sub")) @@ -229,18 +225,14 @@ def log_request(response): try: url = re.sub(jwt_query_r, "", request.url) - if request.headers.get('content-type') == 'application/json': + if request.headers.get("content-type") == "application/json": logger.info( "{}Method: {}, Status: {}, URL: {}, JSON: {}".format( user, request.method, response.status_code, url, request.json ) ) else: - logger.info( - "{}Method: {}, Status: {}, URL: {}".format( - user, request.method, response.status_code, url - ) - ) + logger.info("{}Method: {}, Status: {}, URL: {}".format(user, request.method, response.status_code, url)) except Exception: pass return response diff --git a/src/cnaas_nms/api/device.py b/src/cnaas_nms/api/device.py index 5e69f9b6..7a3e5ff7 100644 --- a/src/cnaas_nms/api/device.py +++ b/src/cnaas_nms/api/device.py @@ -1132,7 +1132,7 @@ def get(self, hostname): @device_api.expect(stackmembers_model) def put(self, hostname): try: - validated_json_data = StackmembersModel(**request.get_json()).dict() + validated_json_data = StackmembersModel(**request.get_json()).model_dump() data = validated_json_data["stackmembers"] except ValidationError as e: errors = DeviceStackmembersApi.format_errors(e.errors()) @@ -1189,7 +1189,7 @@ def get(self): @device_synchistory_api.expect(device_synchistory_api) def post(self): try: - validated_json_data = NewSyncEventModel(**request.get_json()).dict() + validated_json_data = NewSyncEventModel(**request.get_json()).model_dump() except ValidationError as e: return empty_result("error", parse_pydantic_error(e, NewSyncEventModel, request.get_json())), 400 with sqla_session() as session: diff --git a/src/cnaas_nms/api/generic.py b/src/cnaas_nms/api/generic.py index 084724ec..fb1c03ae 100644 --- a/src/cnaas_nms/api/generic.py +++ b/src/cnaas_nms/api/generic.py @@ -180,7 +180,7 @@ def parse_pydantic_error(e: Exception, schema, data: dict) -> List[str]: ) ) try: - pydantic_descr = get_pydantic_field_descr(schema.schema(), loc) + pydantic_descr = get_pydantic_field_descr(schema.model_json_schema(), loc) if pydantic_descr: pydantic_descr_msg = ", field should be: {}".format(pydantic_descr) else: diff --git a/src/cnaas_nms/api/interface.py b/src/cnaas_nms/api/interface.py index 76f20993..7df4983d 100644 --- a/src/cnaas_nms/api/interface.py +++ b/src/cnaas_nms/api/interface.py @@ -176,14 +176,14 @@ def put(self, hostname): ) ) if "enabled" in if_dict["data"]: - if type(if_dict["data"]["enabled"]) == bool: + if type(if_dict["data"]["enabled"]) is bool: intfdata["enabled"] = if_dict["data"]["enabled"] else: errors.append( "Enabled must be a bool, true or false, got: {}".format(if_dict["data"]["enabled"]) ) if "aggregate_id" in if_dict["data"]: - if type(if_dict["data"]["aggregate_id"]) == int: + if type(if_dict["data"]["aggregate_id"]) is int: intfdata["aggregate_id"] = if_dict["data"]["aggregate_id"] elif if_dict["data"]["aggregate_id"] is None: if "aggregate_id" in intfdata: @@ -193,7 +193,7 @@ def put(self, hostname): "Aggregate ID must be an integer: {}".format(if_dict["data"]["aggregate_id"]) ) if "bpdu_filter" in if_dict["data"]: - if type(if_dict["data"]["bpdu_filter"]) == bool: + if type(if_dict["data"]["bpdu_filter"]) is bool: intfdata["bpdu_filter"] = if_dict["data"]["bpdu_filter"] else: errors.append( @@ -202,7 +202,7 @@ def put(self, hostname): ) ) if "redundant_link" in if_dict["data"]: - if type(if_dict["data"]["redundant_link"]) == bool: + if type(if_dict["data"]["redundant_link"]) is bool: intfdata["redundant_link"] = if_dict["data"]["redundant_link"] else: errors.append( diff --git a/src/cnaas_nms/api/linknet.py b/src/cnaas_nms/api/linknet.py index fbb20a30..95bbe189 100644 --- a/src/cnaas_nms/api/linknet.py +++ b/src/cnaas_nms/api/linknet.py @@ -3,8 +3,7 @@ from flask import request from flask_restx import Namespace, Resource, fields -from pydantic import BaseModel, validator -from pydantic.error_wrappers import ValidationError +from pydantic import BaseModel, FieldValidationInfo, ValidationError, field_validator from cnaas_nms.api.generic import empty_result, parse_pydantic_error, update_sqla_object from cnaas_nms.db.device import Device, DeviceType @@ -51,29 +50,31 @@ class f_linknet(BaseModel): device_a_ip: Optional[str] = None device_b_ip: Optional[str] = None - @validator("device_a_ip", "device_b_ip") - def device_ip_validator(cls, v, values, **kwargs): + @field_validator("device_a_ip", "device_b_ip") + @classmethod + def device_ip_validator(cls, v, info: FieldValidationInfo): if not v: return v - if not values["ipv4_network"]: + if not info.data["ipv4_network"]: raise ValueError("ipv4_network must be set") try: addr = IPv4Address(v) - net = IPv4Network(values["ipv4_network"]) + net = IPv4Network(info.data["ipv4_network"]) except Exception: # noqa: S110 raise ValueError("Invalid device IP or ipv4_network") else: if addr not in net.hosts(): raise ValueError("device IP must be part of ipv4_network") - if "device_a_ip" in values and v == values["device_a_ip"]: + if "device_a_ip" in info.data and v == info.data["device_a_ip"]: raise ValueError("device_a_ip and device_b_ip can not be the same") - if "device_b_ip" in values and v == values["device_b_ip"]: + if "device_b_ip" in info.data and v == info.data["device_b_ip"]: raise ValueError("device_a_ip and device_b_ip can not be the same") return v - @validator("ipv4_network") - def ipv4_network_validator(cls, v, values, **kwargs): + @field_validator("ipv4_network") + @classmethod + def ipv4_network_validator(cls, v): if not v: return v try: @@ -258,7 +259,7 @@ def put(self, linknet_id): if instance: try: validate_data = {**instance.as_dict(), **json_data} - f_linknet(**validate_data).dict() + f_linknet(**validate_data).model_dump() except ValidationError as e: errors += parse_pydantic_error(e, f_linknet, validate_data) if errors: diff --git a/src/cnaas_nms/api/mgmtdomain.py b/src/cnaas_nms/api/mgmtdomain.py index 7fca5fed..8e83fd23 100644 --- a/src/cnaas_nms/api/mgmtdomain.py +++ b/src/cnaas_nms/api/mgmtdomain.py @@ -3,8 +3,7 @@ from flask import request from flask_restx import Namespace, Resource, fields -from pydantic import BaseModel, validator -from pydantic.error_wrappers import ValidationError +from pydantic import BaseModel, ValidationError, field_validator from sqlalchemy.exc import IntegrityError from cnaas_nms.api.generic import build_filter, empty_result, limit_results, parse_pydantic_error, update_sqla_object @@ -42,8 +41,9 @@ class f_mgmtdomain(BaseModel): ipv6_gw: Optional[str] = None description: Optional[str] = None - @validator("ipv4_gw") - def ipv4_gw_valid_address(cls, v, values, **kwargs): + @field_validator("ipv4_gw") + @classmethod + def ipv4_gw_valid_address(cls, v): try: addr = IPv4Interface(v) prefix_len = int(addr.network.prefixlen) @@ -59,9 +59,9 @@ def ipv4_gw_valid_address(cls, v, values, **kwargs): return v - @validator("ipv6_gw") + @field_validator("ipv6_gw") @classmethod - def ipv6_gw_valid_address(cls, v, values, **kwargs): + def ipv6_gw_valid_address(cls, v): try: addr = IPv6Interface(v) prefix_len = int(addr.network.prefixlen) @@ -115,7 +115,7 @@ def put(self, mgmtdomain_id): json_data = request.get_json() errors = [] try: - f_mgmtdomain(**json_data).dict() + f_mgmtdomain(**json_data).model_dump() except ValidationError as e: errors += parse_pydantic_error(e, f_mgmtdomain, json_data) @@ -184,7 +184,7 @@ def post(self): data["device_b"] = device_b try: - data = {**data, **f_mgmtdomain(**json_data).dict()} + data = {**data, **f_mgmtdomain(**json_data).model_dump()} except ValidationError as e: errors += parse_pydantic_error(e, f_mgmtdomain, json_data) diff --git a/src/cnaas_nms/api/models/stackmembers_model.py b/src/cnaas_nms/api/models/stackmembers_model.py index 9f97eee7..37049623 100644 --- a/src/cnaas_nms/api/models/stackmembers_model.py +++ b/src/cnaas_nms/api/models/stackmembers_model.py @@ -1,6 +1,6 @@ from typing import List, Optional -from pydantic import BaseModel, conint, validator +from pydantic import BaseModel, conint, field_validator class StackmemberModel(BaseModel): @@ -8,7 +8,8 @@ class StackmemberModel(BaseModel): hardware_id: str priority: Optional[conint(gt=-1)] = None - @validator("hardware_id") + @field_validator("hardware_id") + @classmethod def validate_non_empty_hardware_id(cls, hardware_id): """Validates that hardware_id is not an empty string""" if not hardware_id: @@ -19,7 +20,8 @@ def validate_non_empty_hardware_id(cls, hardware_id): class StackmembersModel(BaseModel): stackmembers: List[StackmemberModel] - @validator("stackmembers") + @field_validator("stackmembers") + @classmethod def validate_unique_member_no(cls, stackmembers): """Validates that all StackmemberModel in stackmembers have unique member_no compared to each other""" member_no_count = len(set([stackmember.member_no for stackmember in stackmembers])) @@ -27,7 +29,8 @@ def validate_unique_member_no(cls, stackmembers): raise ValueError("member_no must be unique for stackmembers belonging to the same device") return stackmembers - @validator("stackmembers") + @field_validator("stackmembers") + @classmethod def validate_unique_hardware_id(cls, stackmembers): """Validates that all StackmemberModel in stackmembers have unique hardware_id compared to each other""" hardware_id_count = len(set([stackmember.hardware_id for stackmember in stackmembers])) diff --git a/src/cnaas_nms/api/settings.py b/src/cnaas_nms/api/settings.py index 65c97d22..88a5ba00 100644 --- a/src/cnaas_nms/api/settings.py +++ b/src/cnaas_nms/api/settings.py @@ -57,7 +57,7 @@ def get(self): class SettingsModelApi(Resource): def get(self): - response = make_response(settings_root_model.schema_json()) + response = make_response(settings_root_model.model_json_schema()) response.headers["Content-Type"] = "application/json" return response @@ -75,7 +75,7 @@ def post(self): class SettingsServerApI(Resource): @login_required def get(self): - ret_dict = {"api": api_settings.dict()} + ret_dict = {"api": api_settings.model_dump()} response = make_response(json.dumps(ret_dict, default=json_dumper)) response.headers["Content-Type"] = "application/json" return response diff --git a/src/cnaas_nms/api/tests/test_settings.py b/src/cnaas_nms/api/tests/test_settings.py index c89ba613..51b62108 100644 --- a/src/cnaas_nms/api/tests/test_settings.py +++ b/src/cnaas_nms/api/tests/test_settings.py @@ -39,7 +39,7 @@ def test_settings_model(testclient: FlaskClient): result = testclient.get("/api/v1.0/settings/model") assert result.status_code == 200 assert result.content_type == "application/json" - assert "definitions" in result.json + assert "$defs" in result.json def test_settings_server(testclient: FlaskClient): diff --git a/src/cnaas_nms/app_settings.py b/src/cnaas_nms/app_settings.py index 96efe6d2..13992ef9 100644 --- a/src/cnaas_nms/app_settings.py +++ b/src/cnaas_nms/app_settings.py @@ -2,7 +2,8 @@ from typing import Optional import yaml -from pydantic import BaseSettings, PostgresDsn, validator +from pydantic import field_validator +from pydantic_settings import BaseSettings class AppSettings(BaseSettings): @@ -14,7 +15,7 @@ class AppSettings(BaseSettings): CNAAS_DB_PORT: int = 5432 REDIS_HOSTNAME: str = "127.0.0.1" REDIS_PORT: int = 6379 - POSTGRES_DSN: PostgresDsn = ( + POSTGRES_DSN: str = ( f"postgresql://{CNAAS_DB_USERNAME}:{CNAAS_DB_PASSWORD}@{CNAAS_DB_HOSTNAME}:{CNAAS_DB_PORT}/{CNAAS_DB_DATABASE}" ) USERNAME_INIT: str = "admin" @@ -54,7 +55,7 @@ class ApiSettings(BaseSettings): COMMIT_CONFIRMED_WAIT: int = 1 SETTINGS_OVERRIDE: Optional[dict] = None - @validator("MGMTDOMAIN_PRIMARY_IP_VERSION") + @field_validator("MGMTDOMAIN_PRIMARY_IP_VERSION") @classmethod def primary_ip_version_is_valid(cls, version: int) -> int: if version not in (4, 6): @@ -71,6 +72,7 @@ class AuthSettings(BaseSettings): OIDC_ENABLED: bool = False OIDC_CLIENT_SCOPE: str = "openid" AUDIENCE: str = OIDC_CLIENT_ID + VERIFY_AUDIENCE: bool = True def construct_api_settings() -> ApiSettings: @@ -155,6 +157,7 @@ def construct_auth_settings() -> AuthSettings: OIDC_CLIENT_ID=config.get("oidc_client_id", AuthSettings().OIDC_CLIENT_ID), OIDC_CLIENT_SCOPE=config.get("oidc_client_scope", AuthSettings().OIDC_CLIENT_SCOPE), AUDIENCE=config.get("audience", AuthSettings().AUDIENCE), + VERIFY_AUDIENCE=config.get("verify_audience", AuthSettings().VERIFY_AUDIENCE), ) else: return AuthSettings() diff --git a/src/cnaas_nms/db/git.py b/src/cnaas_nms/db/git.py index e2837a2a..5e8fdd0e 100644 --- a/src/cnaas_nms/db/git.py +++ b/src/cnaas_nms/db/git.py @@ -6,7 +6,6 @@ from urllib.parse import urldefrag import yaml -from git.exc import GitCommandError, NoSuchPathError from cnaas_nms.app_settings import app_settings from cnaas_nms.db.device import Device, DeviceType @@ -18,6 +17,7 @@ from cnaas_nms.devicehandler.sync_history import add_sync_event from cnaas_nms.tools.log import get_logger from git import InvalidGitRepositoryError, Repo +from git.exc import GitCommandError, NoSuchPathError logger = get_logger() diff --git a/src/cnaas_nms/db/job.py b/src/cnaas_nms/db/job.py index 3dec79bc..23a9cae0 100644 --- a/src/cnaas_nms/db/job.py +++ b/src/cnaas_nms/db/job.py @@ -76,7 +76,7 @@ def as_dict(self) -> dict: continue elif issubclass(value.__class__, datetime.datetime): value = json_dumper(value) - elif type(col.type) == JSONB and value and type(value) == str: + elif type(col.type) is JSONB and value and type(value) is str: value = json.loads(value) d[col.name] = value return d @@ -106,7 +106,7 @@ def finish_success(self, res: dict, next_job_id: Optional[int]): try: if isinstance(res, NornirJobResult) and isinstance(res.nrresult, AggregatedResult): self.result = {"devices": nr_result_serialize(res.nrresult)} - if res.change_score and type(res.change_score) == int: + if res.change_score and type(res.change_score) is int: self.change_score = res.change_score elif isinstance(res, (StrJobResult, DictJobResult)): self.result = res.result @@ -231,11 +231,11 @@ def get_previous_config( .filter(Job.result["devices"].has_key(hostname)) ) - if job_id and type(job_id) == int: + if job_id and type(job_id) is int: query_part = query_part.filter(Job.id == job_id) - elif previous and type(previous) == int: + elif previous and type(previous) is int: query_part = query_part.order_by(Job.id.desc()).offset(previous) - elif before and type(before) == datetime.datetime: + elif before and type(before) is datetime.datetime: query_part = query_part.filter(Job.finish_time < before).order_by(Job.id.desc()) else: query_part = query_part.order_by(Job.id.desc()) diff --git a/src/cnaas_nms/db/settings.py b/src/cnaas_nms/db/settings.py index 7793556e..aae7c458 100644 --- a/src/cnaas_nms/db/settings.py +++ b/src/cnaas_nms/db/settings.py @@ -5,7 +5,7 @@ import pkg_resources import yaml -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError from redis import StrictRedis from redis_lru import RedisLRU @@ -207,7 +207,7 @@ def get_pydantic_field_descr(schema: dict, loc: tuple): ref_to = next_schema["$ref"].split("/")[2] next_schema = schema["definitions"][ref_to]["properties"][loc_part] elif next_schema: - if type(loc_part) == int: + if type(loc_part) is int: next_schema = next_schema["items"] else: next_schema = schema["definitions"][next_schema]["properties"][loc_part] @@ -227,7 +227,7 @@ def check_settings_syntax(settings_dict: dict, settings_metadata_dict: dict) -> """ logger = get_logger() try: - ret_dict = f_root(**settings_dict).dict() + ret_dict = f_root(**settings_dict).model_dump() except ValidationError as validation_error: msg = "" for num, error in enumerate(validation_error.errors()): @@ -247,7 +247,7 @@ def check_settings_syntax(settings_dict: dict, settings_metadata_dict: dict) -> "->".join(str(x) for x in loc), get_pydantic_error_value(settings_dict, loc), origin ) try: - pydantic_descr = get_pydantic_field_descr(f_root.schema(), loc) + pydantic_descr = get_pydantic_field_descr(f_root.model_json_schema(), loc) if pydantic_descr: pydantic_descr_msg = ", field should be: {}".format(pydantic_descr) else: @@ -298,8 +298,8 @@ def get_internal_vlan_range(settings) -> range: if ( "vlan_id_low" in settings["internal_vlans"] and "vlan_id_high" in settings["internal_vlans"] - and type(settings["internal_vlans"]["vlan_id_low"]) == int - and type(settings["internal_vlans"]["vlan_id_high"]) == int + and type(settings["internal_vlans"]["vlan_id_low"]) is int + and type(settings["internal_vlans"]["vlan_id_high"]) is int ): return range(settings["internal_vlans"]["vlan_id_low"], settings["internal_vlans"]["vlan_id_high"] + 1) else: @@ -702,7 +702,7 @@ def get_group_settings(): ) settings["groups"] += default_settings["groups"] check_settings_syntax(settings, settings_origin) - return f_groups(**settings).dict(), settings_origin + return f_groups(**settings).model_dump(), settings_origin @redis_lru_cache diff --git a/src/cnaas_nms/db/settings_fields.py b/src/cnaas_nms/db/settings_fields.py index 5ebcd5d6..0e571106 100644 --- a/src/cnaas_nms/db/settings_fields.py +++ b/src/cnaas_nms/db/settings_fields.py @@ -1,7 +1,8 @@ from ipaddress import AddressValueError, IPv4Interface -from typing import Dict, List, Optional +from typing import Annotated, Dict, List, Optional -from pydantic import BaseModel, Field, conint, validator +from pydantic import BaseModel, Field, FieldValidationInfo, conint, field_validator +from pydantic.functional_validators import AfterValidator # HOSTNAME_REGEX = r'([a-z0-9-]{1,63}\.?)+' IPV4_REGEX = r"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}" r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)" @@ -22,22 +23,24 @@ HOST_REGEX = f"^({IPV4_REGEX}|{IPV6_REGEX}|{FQDN_REGEX})$" HOSTNAME_REGEX = r"^([a-zA-Z0-9-]{1,63})(\.[a-z0-9-]{1,63})*$" DOMAIN_NAME_REGEX = r"^([a-zA-Z0-9-]{1,63})(\.[a-z0-9-]{1,63})+$" -host_schema = Field(..., regex=HOST_REGEX, max_length=253, description="Hostname, FQDN or IP address") -hostname_schema = Field(..., regex=HOSTNAME_REGEX, max_length=253, description="Hostname or FQDN") -domain_name_schema = Field(None, regex=DOMAIN_NAME_REGEX, max_length=251, description="DNS domain name") -ipv4_schema = Field(..., regex=f"^{IPV4_REGEX}$", description="IPv4 address") +host_schema = Field(..., pattern=HOST_REGEX, max_length=253, description="Hostname, FQDN or IP address") +hostname_schema = Field(..., pattern=HOSTNAME_REGEX, max_length=253, description="Hostname or FQDN") +domain_name_schema = Field(None, pattern=DOMAIN_NAME_REGEX, max_length=251, description="DNS domain name") +ipv4_schema = Field(..., pattern=f"^{IPV4_REGEX}$", description="IPv4 address") IPV4_IF_REGEX = f"{IPV4_REGEX}" + r"\/[0-9]{1,2}" -ipv4_if_schema = Field(None, regex=f"^{IPV4_IF_REGEX}$", description="IPv4 address in CIDR/prefix notation (0.0.0.0/0)") -ipv6_schema = Field(..., regex=f"^{IPV6_REGEX}$", description="IPv6 address") +ipv4_if_schema = Field( + None, pattern=f"^{IPV4_IF_REGEX}$", description="IPv4 address in CIDR/prefix notation (0.0.0.0/0)" +) +ipv6_schema = Field(..., pattern=f"^{IPV6_REGEX}$", description="IPv6 address") IPV6_IF_REGEX = f"{IPV6_REGEX}" + r"\/[0-9]{1,3}" -ipv6_if_schema = Field(None, regex=f"^{IPV6_IF_REGEX}$", description="IPv6 address in CIDR/prefix notation (::/0)") -ipv4_or_ipv6_if_schema = Field(None, regex=f"({IPV4_IF_REGEX}|{IPV6_IF_REGEX})", description="IPv4 or IPv6 prefix") +ipv6_if_schema = Field(None, pattern=f"^{IPV6_IF_REGEX}$", description="IPv6 address in CIDR/prefix notation (::/0)") +ipv4_or_ipv6_if_schema = Field(None, pattern=f"({IPV4_IF_REGEX}|{IPV6_IF_REGEX})", description="IPv4 or IPv6 prefix") # VLAN name is alphanumeric max 32 chars on Cisco # should not start with number according to some Juniper doc VLAN_NAME_REGEX = r"^[a-zA-Z][a-zA-Z0-9-_]{0,31}$" vlan_name_schema = Field( - None, regex=VLAN_NAME_REGEX, description="Max 32 alphanumeric chars, " + "beginning with a non-numeric character" + None, pattern=VLAN_NAME_REGEX, description="Max 32 alphanumeric chars, " + "beginning with a non-numeric character" ) vlan_id_schema = Field(..., gt=0, lt=4096, description="Numeric 802.1Q VLAN ID, 1-4095") vlan_id_schema_optional = Field(None, gt=0, lt=4096, description="Numeric 802.1Q VLAN ID, 1-4095") @@ -47,24 +50,24 @@ as_num_schema = Field(None, description="BGP Autonomous System number, 1-4294967295 (asdot notation not supported)") as_num_type = conint(strict=True, gt=0, lt=4294967296) IFNAME_REGEX = r"([a-zA-Z0-9\/\.:-])+" -ifname_schema = Field(None, regex=f"^{IFNAME_REGEX}$", description="Interface name") +ifname_schema = Field(None, pattern=f"^{IFNAME_REGEX}$", description="Interface name") IFNAME_RANGE_REGEX = r"([a-zA-Z0-9\/\.:\-\[\]])+" ifname_range_schema = Field( - None, regex=f"^{IFNAME_RANGE_REGEX}$", description="Interface range pattern or interface name" + None, pattern=f"^{IFNAME_RANGE_REGEX}$", description="Interface range pattern or interface name" ) IFCLASS_REGEX = r"(custom|downlink|fabric|port_template_[a-zA-Z0-9_]+)" -ifclass_schema = Field(None, regex=f"^{IFCLASS_REGEX}$", description="Interface class: custom, downlink or uplink") +ifclass_schema = Field(None, pattern=f"^{IFCLASS_REGEX}$", description="Interface class: custom, downlink or uplink") ifdescr_schema = Field(None, max_length=64, description="Interface description, 0-64 characters") tcpudp_port_schema = Field(None, ge=0, lt=65536, description="TCP or UDP port number, 0-65535") ebgp_multihop_schema = Field(None, ge=1, le=255, description="Numeric IP TTL, 1-255") maximum_routes_schema = Field(None, ge=0, le=4294967294, description="Maximum number of routes to receive from peer") -accept_or_reject_schema = Field(..., regex=r"^(accept|reject)$", description="Value has to be 'accept' or 'reject'") +accept_or_reject_schema = Field(..., pattern=r"^(accept|reject)$", description="Value has to be 'accept' or 'reject'") prefix_size_or_range_schema = Field( - None, regex=r"^[0-9]{1,3}([-][0-9]{1,3})?$", description="Prefix size or range 0-128" + None, pattern=r"^[0-9]{1,3}([-][0-9]{1,3})?$", description="Prefix size or range 0-128" ) GROUP_NAME = r"^([a-zA-Z0-9_-]{1,63}\.?)+$" -group_name = Field(..., regex=GROUP_NAME, max_length=253) +group_name = Field(..., pattern=GROUP_NAME, max_length=253) group_priority_schema = Field( 0, ge=0, le=100, description="Group priority 0-100, default 0, higher value means higher priority" ) @@ -83,6 +86,7 @@ def validate_ipv4_if(ipv4if: str): raise ValueError("Invalid IPv4 interface: {}".format(e)) except AssertionError as e: raise ValueError("Invalid IPv4 interface: {}".format(e)) + return ipv4if # Note: If specifying a list of a BaseModel derived class anywhere else except @@ -136,7 +140,7 @@ class f_interface(BaseModel): description: Optional[str] = ifdescr_schema enabled: Optional[bool] = None untagged_vlan: Optional[int] = vlan_id_schema_optional - tagged_vlan_list: Optional[List[int]] = None + tagged_vlan_list: Optional[List[Annotated[int, Field(ge=1, le=4094)]]] = None aggregate_id: Optional[int] = None tags: Optional[List[str]] = None vrf: Optional[str] = vlan_name_schema @@ -149,19 +153,15 @@ class f_interface(BaseModel): acl_ipv6_out: Optional[str] = None cli_append_str: str = "" - @validator("ipv4_address") - def vrf_required_if_ipv4_address_set(cls, v, values, **kwargs): + @field_validator("ipv4_address") + @classmethod + def vrf_required_if_ipv4_address_set(cls, v: str, info: FieldValidationInfo): if v: validate_ipv4_if(v) - if "vrf" not in values or not values["vrf"]: + if "vrf" not in info.data or not info.data["vrf"]: raise ValueError("VRF is required when specifying ipv4_gw") return v - @validator("tagged_vlan_list", each_item=True) - def check_valid_vlan_ids(cls, v): - assert 0 < v < 4096 - return v - class f_vrf(BaseModel): name: str = None @@ -191,8 +191,8 @@ class f_ipv6_static_route(BaseModel): class f_extroute_static_vrf(BaseModel): name: str - ipv4: Optional[List[f_ipv4_static_route]] - ipv6: Optional[List[f_ipv6_static_route]] + ipv4: Optional[List[f_ipv4_static_route]] = None + ipv6: Optional[List[f_ipv6_static_route]] = None class f_extroute_static(BaseModel): @@ -261,24 +261,25 @@ class f_internal_vlans(BaseModel): vlan_id_high: int = vlan_id_schema allocation_order: str = "ascending" - @validator("vlan_id_high") - def vlan_id_high_greater_than_low(cls, v, values, **kwargs): + @field_validator("vlan_id_high") + @classmethod + def vlan_id_high_greater_than_low(cls, v: int, info: FieldValidationInfo): if v: - if values["vlan_id_low"] >= v: + if info.data["vlan_id_low"] >= v: raise ValueError("vlan_id_high must be greater than vlan_id_low") return v class f_vxlan(BaseModel): - description: str = None + description: Optional[str] = None vni: int = vxlan_vni_schema vrf: Optional[str] = vlan_name_schema vlan_id: int = vlan_id_schema vlan_name: str = vlan_name_schema ipv4_gw: Optional[str] = None - ipv4_secondaries: Optional[List[str]] + ipv4_secondaries: Optional[List[Annotated[str, AfterValidator(validate_ipv4_if)]]] = None ipv6_gw: Optional[str] = ipv6_if_schema - dhcp_relays: Optional[List[f_dhcp_relay]] + dhcp_relays: Optional[List[f_dhcp_relay]] = None mtu: Optional[int] = mtu_schema vxlan_host_route: bool = True acl_ipv4_in: Optional[str] = None @@ -290,23 +291,20 @@ class f_vxlan(BaseModel): devices: List[str] = [] tags: List[str] = [] - @validator("ipv4_secondaries", each_item=True) - def ipv4_secondaries_regex(cls, v): - validate_ipv4_if(v) - return v - - @validator("ipv4_gw") - def vrf_required_if_ipv4_gw_set(cls, v, values, **kwargs): + @field_validator("ipv4_gw") + @classmethod + def vrf_required_if_ipv4_gw_set(cls, v: str, info: FieldValidationInfo): if v: validate_ipv4_if(v) - if "vrf" not in values or not values["vrf"]: + if "vrf" not in info.data or not info.data["vrf"]: raise ValueError("VRF is required when specifying ipv4_gw") return v - @validator("ipv6_gw") - def vrf_required_if_ipv6_gw_set(cls, v, values, **kwargs): + @field_validator("ipv6_gw") + @classmethod + def vrf_required_if_ipv6_gw_set(cls, v: str, info: FieldValidationInfo): if v: - if "vrf" not in values or not values["vrf"]: + if "vrf" not in info.data or not info.data["vrf"]: raise ValueError("VRF is required when specifying ipv6_gw") return v @@ -333,7 +331,7 @@ class f_user(BaseModel): class f_prefixset_item(BaseModel): prefix: str = ipv4_or_ipv6_if_schema - masklength_range: Optional[str] = prefix_size_or_range_schema + masklength_range: Optional[Annotated[int, Field(ge=0, le=128)] | Annotated[str, prefix_size_or_range_schema]] = None class f_prefixset(BaseModel): @@ -362,16 +360,16 @@ class f_root(BaseModel): snmp_servers: List[f_snmp_server] = [] dns_servers: List[f_dns_server] = [] flow_collectors: List[f_flow_collector] = [] - dhcp_relays: Optional[List[f_dhcp_relay]] + dhcp_relays: Optional[List[f_dhcp_relay]] = None interfaces: List[f_interface] = [] vrfs: List[f_vrf] = [] vxlans: Dict[str, f_vxlan] = {} underlay: f_underlay = None evpn_peers: List[f_evpn_peer] = [] - extroute_static: Optional[f_extroute_static] - extroute_ospfv3: Optional[f_extroute_ospfv3] - extroute_bgp: Optional[f_extroute_bgp] - internal_vlans: Optional[f_internal_vlans] + extroute_static: Optional[f_extroute_static] = None + extroute_ospfv3: Optional[f_extroute_ospfv3] = None + extroute_bgp: Optional[f_extroute_bgp] = None + internal_vlans: Optional[f_internal_vlans] = None dot1x_fail_vlan: Optional[int] = vlan_id_schema_optional cli_prepend_str: str = "" cli_append_str: str = "" @@ -389,16 +387,17 @@ class f_group_item(BaseModel): regex: str = "" group_priority: int = group_priority_schema - @validator("group_priority") - def reserved_priority(cls, v, values, **kwargs): - if v and v == 1 and values["name"] != "DEFAULT": + @field_validator("group_priority") + @classmethod + def reserved_priority(cls, v: int, info: FieldValidationInfo): + if v and v == 1 and info.data["name"] != "DEFAULT": raise ValueError("group_priority 1 is reserved for built-in group DEFAULT") return v class f_group(BaseModel): - group: Optional[f_group_item] + group: Optional[f_group_item] = None class f_groups(BaseModel): - groups: Optional[List[f_group]] + groups: Optional[List[f_group]] = None diff --git a/src/cnaas_nms/devicehandler/get.py b/src/cnaas_nms/devicehandler/get.py index 12b2ed7f..da4b812a 100644 --- a/src/cnaas_nms/devicehandler/get.py +++ b/src/cnaas_nms/devicehandler/get.py @@ -139,7 +139,7 @@ def get_uplinks( local_ifs = dev.get_neighbor_ifnames(session, neighbor_d, linknets) dl_intf_names = [] - intf: Interface + dl_intf: Interface for dl_intf in dl_intfs: dl_intf_names.append(dl_intf.name) diff --git a/src/cnaas_nms/devicehandler/sync_devices.py b/src/cnaas_nms/devicehandler/sync_devices.py index b0b93cf9..68626bf7 100644 --- a/src/cnaas_nms/devicehandler/sync_devices.py +++ b/src/cnaas_nms/devicehandler/sync_devices.py @@ -66,7 +66,7 @@ def get_evpn_peers(session, settings: dict): def resolve_vlanid(vlan_name: str, vxlans: dict) -> Optional[int]: logger = get_logger() - if type(vlan_name) == int: + if type(vlan_name) is int: return int(vlan_name) if not isinstance(vlan_name, str): return None diff --git a/src/cnaas_nms/scheduler/thread_data.py b/src/cnaas_nms/scheduler/thread_data.py index a4fdda5f..b8dc10c6 100644 --- a/src/cnaas_nms/scheduler/thread_data.py +++ b/src/cnaas_nms/scheduler/thread_data.py @@ -4,5 +4,5 @@ def set_thread_data(job_id): - if job_id and type(job_id) == int: + if job_id and type(job_id) is int: thread_data.job_id = job_id diff --git a/src/cnaas_nms/scheduler/wrapper.py b/src/cnaas_nms/scheduler/wrapper.py index 1e4aedbb..75948b7f 100644 --- a/src/cnaas_nms/scheduler/wrapper.py +++ b/src/cnaas_nms/scheduler/wrapper.py @@ -49,7 +49,7 @@ def job_wrapper(func): """Decorator to save job status in job tracker database.""" def wrapper(job_id: int, scheduled_by: str, kwargs={}): - if not job_id or not type(job_id) == int: + if not job_id or type(job_id) is not int: errmsg = "Missing job_id when starting job for {}".format(func.__name__) logger.error(errmsg) raise ValueError(errmsg) diff --git a/src/cnaas_nms/tools/log.py b/src/cnaas_nms/tools/log.py index 1ffd66b1..8c443f4e 100644 --- a/src/cnaas_nms/tools/log.py +++ b/src/cnaas_nms/tools/log.py @@ -16,7 +16,7 @@ def emit(self, record): def get_logger(): - if hasattr(thread_data, "job_id") and type(thread_data.job_id) == int: + if hasattr(thread_data, "job_id") and type(thread_data.job_id) is int: logger = logging.getLogger("cnaas-nms-{}".format(thread_data.job_id)) if not logger.handlers: formatter = logging.Formatter( diff --git a/src/cnaas_nms/tools/security.py b/src/cnaas_nms/tools/security.py index 44536ec5..8aefbc00 100644 --- a/src/cnaas_nms/tools/security.py +++ b/src/cnaas_nms/tools/security.py @@ -1,3 +1,4 @@ +import json from typing import Any, Mapping import requests @@ -7,7 +8,7 @@ from flask_jwt_extended import jwt_required as jwt_orig from jose import exceptions, jwt from jwt.exceptions import ExpiredSignatureError, InvalidKeyError, InvalidTokenError -import json + from cnaas_nms.app_settings import api_settings, auth_settings from cnaas_nms.tools.log import get_logger @@ -40,7 +41,7 @@ def get_oauth_userinfo(token_string): We get the right info from there and return this to the user. Returns: - resp.json(): Object of the user info + resp.json(): Object of the user info """ # For now unnecersary, useful when we only use one log in method @@ -50,9 +51,9 @@ def get_oauth_userinfo(token_string): try: metadata = requests.get(auth_settings.OIDC_CONF_WELL_KNOWN_URL) metadata.raise_for_status() - except requests.exceptions.HTTPError as e: + except requests.exceptions.HTTPError: raise ConnectionError("Can't reach the OIDC URL") - except requests.exceptions.ConnectionError as e: + except requests.exceptions.ConnectionError: raise ConnectionError("OIDC metadata unavailable") user_info_endpoint = metadata.json()["userinfo_endpoint"] @@ -63,10 +64,11 @@ def get_oauth_userinfo(token_string): resp.raise_for_status() except requests.exceptions.HTTPError as e: body = json.loads(e.response.content) - logger.debug("Request not successful: " + body['error_description']) - raise InvalidTokenError(body['error_description']) + logger.debug("Request not successful: " + body["error_description"]) + raise InvalidTokenError(body["error_description"]) return resp.json() + class MyBearerTokenValidator(BearerTokenValidator): keys: Mapping[str, Any] = {} @@ -77,15 +79,14 @@ def get_keys(self): keys_endpoint = metadata.json()["jwks_uri"] response = requests.get(url=keys_endpoint) self.keys = response.json()["keys"] - except KeyError as e: + except KeyError as e: raise InvalidKeyError(e) except requests.exceptions.HTTPError as e: raise InvalidKeyError(e) - def get_key(self, kid): """Get the key based on the kid""" - key = [k for k in self.keys if k['kid'] == kid] + key = [k for k in self.keys if k["kid"] == kid] if len(key) == 0: logger.debug("Key not found. Get the keys.") self.get_keys() @@ -93,8 +94,8 @@ def get_key(self, kid): logger.error("Keys not downloaded") raise InvalidKeyError() try: - key = [k for k in self.keys if k['kid'] == kid] - except KeyError as e: + key = [k for k in self.keys if k["kid"] == kid] + except KeyError as e: logger.error("Keys in different format?") raise InvalidKeyError(e) if len(key) == 0: @@ -129,12 +130,10 @@ def authenticate_token(self, token_string: str): unverified_header = jwt.get_unverified_header(token_string) except exceptions.JWSError as e: raise InvalidTokenError(e) - except exceptions.JWTError as e: + except exceptions.JWTError: # check if we can still get the user info get_oauth_userinfo(token_string) - token = { - "access_token": token_string - } + token = {"access_token": token_string} return token # get the key @@ -144,11 +143,15 @@ def authenticate_token(self, token_string: str): algorithm = unverified_header.get("alg") try: decoded_token = jwt.decode( - token_string, key, algorithms=algorithm, audience=auth_settings.AUDIENCE + token_string, + key, + algorithms=algorithm, + audience=auth_settings.AUDIENCE, + options={"verify_aud": auth_settings.VERIFY_AUDIENCE}, ) except exceptions.ExpiredSignatureError as e: raise ExpiredSignatureError(e) - except exceptions.JWKError: + except exceptions.JWKError as e: logger.error("Invalid Key") raise InvalidKeyError(e) except exceptions.JWTError as e: @@ -205,4 +208,3 @@ def get_oauth_identity(): else: login_required = jwt_required get_identity = get_jwt_identity - \ No newline at end of file