From 0f2b349d0f355e9efb16c06fdbffc4f21dd50b68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Fri, 19 Apr 2024 15:25:31 +0200 Subject: [PATCH 01/16] Remove flask-httpauth --- pyproject.toml | 3 +- requirements/install.txt | 3 -- src/bemserver_api/exceptions.py | 9 ++++ .../extensions/authentication.py | 51 +++++++++++-------- src/bemserver_api/extensions/smorest.py | 4 -- tests/extensions/test_authentication.py | 35 +++++-------- 6 files changed, 53 insertions(+), 52 deletions(-) create mode 100644 src/bemserver_api/exceptions.py diff --git a/pyproject.toml b/pyproject.toml index 3a2937b..c45c07a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "marshmallow-sqlalchemy>=0.29.0,<0.30", "flask_smorest>=0.43.0,<0.44", "apispec>=6.1.0,<7.0", - "flask-httpauth>=4.7.0,<5.0", "bemserver-core>=0.17.1,<0.18", ] @@ -77,7 +76,7 @@ section-order = ["future", "standard-library", "testing", "db", "pallets", "mars [tool.ruff.lint.isort.sections] testing = ["pytest", "pytest_postgresql"] db = ["psycopg", "sqlalchemy", "alembic"] -pallets = ["werkzeug", "flask", "flask_httpauth"] +pallets = ["werkzeug", "flask"] marshmallow = ["marshmallow", "marshmallow_sqlalchemy", "webargs", "apispec", "flask_smorest"] science = ["numpy", "pandas"] core = ["bemserver_core"] diff --git a/requirements/install.txt b/requirements/install.txt index 86dca56..a6cf47c 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -51,10 +51,7 @@ click-repl==0.3.0 flask==3.0.2 # via # bemserver-api (pyproject.toml) - # flask-httpauth # flask-smorest -flask-httpauth==4.8.0 - # via bemserver-api (pyproject.toml) flask-smorest==0.43.0 # via bemserver-api (pyproject.toml) greenlet==3.0.3 diff --git a/src/bemserver_api/exceptions.py b/src/bemserver_api/exceptions.py new file mode 100644 index 0000000..f19e5b1 --- /dev/null +++ b/src/bemserver_api/exceptions.py @@ -0,0 +1,9 @@ +"""Exceptions""" + + +class BEMServerAPIError(Exception): + """Base BEMServer API exception""" + + +class BEMServerAPIAuthenticationError(BEMServerAPIError): + """AuthenticationError error""" diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index bd59eea..e47b198 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -1,10 +1,11 @@ """Authentication""" +import base64 from functools import wraps import sqlalchemy as sqla -from flask_httpauth import HTTPBasicAuth +import flask from flask_smorest import abort @@ -12,11 +13,31 @@ from bemserver_core.model.users import User from bemserver_api.database import db +from bemserver_api.exceptions import BEMServerAPIAuthenticationError -class Auth(HTTPBasicAuth): +class Auth: """Authentication and authorization management""" + @staticmethod + def get_user_http_basic_auth(): + """Check password and return User instance""" + if (auth_header := flask.request.headers.get("Authorization")) is None: + raise (BEMServerAPIAuthenticationError) + try: + _, creds = auth_header.encode("utf-8").split(b" ", maxsplit=1) + enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1) + user_email = enc_email.decode() + password = enc_password.decode() + except (ValueError, TypeError) as exc: + raise (BEMServerAPIAuthenticationError) from exc + user = db.session.execute( + sqla.select(User).where(User.email == user_email) + ).scalar() + if user is None or not user.check_password(password): + raise (BEMServerAPIAuthenticationError) + return user + def login_required(self, f=None, **kwargs): """Decorator providing authentication and authorization @@ -28,16 +49,18 @@ def login_required(self, f=None, **kwargs): def decorator(func): @wraps(func) def wrapper(*args, **func_kwargs): - with CurrentUser(self.current_user()): + try: + user = self.get_user_http_basic_auth() + except BEMServerAPIAuthenticationError: + abort(401, "Authentication error") + with CurrentUser(user): try: resp = func(*args, **func_kwargs) except BEMServerAuthorizationError: abort(403, message="Authorization error") return resp - # Wrap this inside HTTPAuth.login_required - # to get authenticated user - return super(Auth, self).login_required(**kwargs)(wrapper) + return wrapper if f: return decorator(f) @@ -45,19 +68,3 @@ def wrapper(*args, **func_kwargs): auth = Auth() - - -@auth.verify_password -def verify_password(username, password): - """Check password and return User instance""" - user = db.session.execute(sqla.select(User).where(User.email == username)).scalar() - if user is not None and user.check_password(password): - return user - return None - - -@auth.error_handler -def auth_error(status): - """Authentication error handler""" - # Call abort to trigger error handler and get consistent JSON output - abort(status, message="Authentication error") diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index d21e563..809039d 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -70,10 +70,6 @@ def _prepare_auth_doc(doc, doc_info, *, app, **kwargs): doc["security"] = [{"BasicAuthentication": []}] return doc - @staticmethod - def current_user(): - return auth.current_user() - @staticmethod def catch_integrity_error(func=None): """Catch DB integrity errors""" diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 2c2700c..0fa7743 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -1,6 +1,6 @@ """Test authentication extension""" -from flask import jsonify +import base64 from bemserver_core.authorization import get_current_user @@ -36,6 +36,19 @@ def no_auth(): resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 + # Broken auth headers + headers = {"Authorization": "Basic Dummy"} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + creds = base64.b64encode(b"Dummy").decode() + headers = {"Authorization": "Basic " + creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + # Inactive user headers = {"Authorization": "Basic " + inactive_user_creds} resp = client.get("/auth_test/auth", headers=headers) @@ -71,23 +84,3 @@ def no_auth(): no_auth_spec = spec["paths"]["/auth_test/no_auth"] assert "401" not in no_auth_spec["get"]["responses"] assert "security" not in no_auth_spec["get"] - - def test_auth_current_user(self, app, users): - active_user_creds = users["Active"]["creds"] - api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] - - blp = Blueprint("AuthTest", __name__, url_prefix="/auth_test") - - @blp.route("/user") - @blp.login_required - @blp.response(200) - def user(): - return jsonify(blp.current_user().name) - - api.register_blueprint(blp) - client = app.test_client() - - # Active user - headers = {"Authorization": "Basic " + active_user_creds} - resp = client.get("/auth_test/user", headers=headers) - assert resp.json == "Active" From 91268068903eae522dc4ba53cbe30d68f407cca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Fri, 19 Apr 2024 15:54:16 +0200 Subject: [PATCH 02/16] Add JWT authentication --- pyproject.toml | 1 + requirements/install.txt | 5 + src/bemserver_api/__init__.py | 1 + .../extensions/authentication.py | 81 ++++++++++- src/bemserver_api/settings.py | 6 + tests/common.py | 7 +- tests/conftest.py | 20 ++- tests/extensions/test_authentication.py | 128 +++++++++++++++++- 8 files changed, 230 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c45c07a..67e2d59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "marshmallow-sqlalchemy>=0.29.0,<0.30", "flask_smorest>=0.43.0,<0.44", "apispec>=6.1.0,<7.0", + "joserfc>=0.9.0,<0.10", "bemserver-core>=0.17.1,<0.18", ] diff --git a/requirements/install.txt b/requirements/install.txt index a6cf47c..c54498f 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -31,6 +31,7 @@ certifi==2024.2.2 cffi==1.16.0 # via # argon2-cffi-bindings + # cryptography # oso charset-normalizer==3.3.2 # via requests @@ -48,6 +49,8 @@ click-plugins==1.1.1 # via celery click-repl==0.3.0 # via celery +cryptography==42.0.5 + # via joserfc flask==3.0.2 # via # bemserver-api (pyproject.toml) @@ -62,6 +65,8 @@ itsdangerous==2.1.2 # via flask jinja2==3.1.4 # via flask +joserfc==0.9.0 + # via bemserver-api (pyproject.toml) kombu==5.3.5 # via celery mako==1.3.2 diff --git a/src/bemserver_api/__init__.py b/src/bemserver_api/__init__.py index a3a904b..75ed76b 100644 --- a/src/bemserver_api/__init__.py +++ b/src/bemserver_api/__init__.py @@ -36,6 +36,7 @@ def create_app(): } ) api.init_app(app) + authentication.auth.init_app(app) register_blueprints(api) BEMServerCore() diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index e47b198..f038b28 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -1,6 +1,7 @@ """Authentication""" import base64 +import datetime as dt from functools import wraps import sqlalchemy as sqla @@ -9,6 +10,10 @@ from flask_smorest import abort +from joserfc import jwt +from joserfc.errors import JoseError +from joserfc.jwk import OctKey + from bemserver_core.authorization import BEMServerAuthorizationError, CurrentUser from bemserver_core.model.users import User @@ -19,13 +24,64 @@ class Auth: """Authentication and authorization management""" + HEADER = {"alg": "HS256"} + TOKEN_LIFETIME = 900 + + GET_USER_FUNCS = { + "Bearer": "get_user_jwt", + "Basic": "get_user_http_basic_auth", + } + + def __init__(self, app=None): + self.key = None + if app is not None: + self.init_app(app) + + def init_app(self, app): + self.key = OctKey.import_key(app.config["SECRET_KEY"]) + self.get_user_funcs = { + k: getattr(self, v) + for k, v in self.GET_USER_FUNCS.items() + if k in app.config["AUTH_METHODS"] + } + + def encode(self, user): + claims = { + "email": user.email, + "exp": dt.datetime.now(tz=dt.timezone.utc) + + dt.timedelta(seconds=self.TOKEN_LIFETIME), + } + return jwt.encode(self.HEADER, claims, self.key) + + def decode(self, text): + return jwt.decode(text, self.key) + + def validate_token(self, token): + claims_requests = jwt.JWTClaimsRegistry(email={"essential": True}) + claims_requests.validate(token.claims) + + def get_user_jwt(self, creds): + try: + token = self.decode(creds) + except (ValueError, JoseError) as exc: + raise (BEMServerAPIAuthenticationError) from exc + try: + self.validate_token(token) + except JoseError as exc: + raise (BEMServerAPIAuthenticationError) from exc + + user_email = token.claims["email"] + user = db.session.execute( + sqla.select(User).where(User.email == user_email) + ).scalar() + if user is None: + raise (BEMServerAPIAuthenticationError) + return user + @staticmethod - def get_user_http_basic_auth(): + def get_user_http_basic_auth(creds): """Check password and return User instance""" - if (auth_header := flask.request.headers.get("Authorization")) is None: - raise (BEMServerAPIAuthenticationError) try: - _, creds = auth_header.encode("utf-8").split(b" ", maxsplit=1) enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1) user_email = enc_email.decode() password = enc_password.decode() @@ -38,10 +94,23 @@ def get_user_http_basic_auth(): raise (BEMServerAPIAuthenticationError) return user + def get_user(self): + if (auth_header := flask.request.headers.get("Authorization")) is None: + raise BEMServerAPIAuthenticationError + try: + scheme, creds = auth_header.split(" ", maxsplit=1) + except ValueError as exc: + raise BEMServerAPIAuthenticationError from exc + try: + func = self.get_user_funcs[scheme] + except KeyError as exc: + raise BEMServerAPIAuthenticationError from exc + return func(creds.encode("utf-8")) + def login_required(self, f=None, **kwargs): """Decorator providing authentication and authorization - Uses HTTPBasicAuth.login_required authenticate user + Uses JWT or HTTPBasicAuth.login_required to authenticate user Sets CurrentUser context variable to authenticated user for the request Catches Authorization error and aborts accordingly """ @@ -50,7 +119,7 @@ def decorator(func): @wraps(func) def wrapper(*args, **func_kwargs): try: - user = self.get_user_http_basic_auth() + user = self.get_user() except BEMServerAPIAuthenticationError: abort(401, "Authentication error") with CurrentUser(user): diff --git a/src/bemserver_api/settings.py b/src/bemserver_api/settings.py index 96a9d3f..1ff5a1d 100644 --- a/src/bemserver_api/settings.py +++ b/src/bemserver_api/settings.py @@ -4,6 +4,12 @@ class Config: """Default configuration""" + # Authentication + SECRET_KEY = "" + AUTH_METHODS = [ + "Bearer", + ] + # API parameters API_TITLE = "BEMServer API" OPENAPI_JSON_PATH = "api-spec.json" diff --git a/tests/common.py b/tests/common.py index 07c83ce..0a664e4 100644 --- a/tests/common.py +++ b/tests/common.py @@ -6,6 +6,11 @@ class TestConfig(Config): TESTING = True + SECRET_KEY = "Test secret" + AUTH_METHODS = [ + "Bearer", + "Basic", + ] AUTH_HEADER = ContextVar("auth_header", default=None) @@ -16,7 +21,7 @@ def __init__(self, creds): self.creds = creds def __enter__(self): - self.token = AUTH_HEADER.set("Basic " + self.creds) + self.token = AUTH_HEADER.set(self.creds) def __exit__(self, *args, **kwargs): AUTH_HEADER.reset(self.token) diff --git a/tests/conftest.py b/tests/conftest.py index 930d464..6d01ee0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,12 +9,15 @@ import flask.testing +from joserfc import jwt + from bemserver_core import common, model, scheduled_tasks from bemserver_core.authorization import OpenBar from bemserver_core.commands import setup_db from bemserver_core.database import db -from bemserver_api import create_app +import bemserver_api +from bemserver_api.extensions.authentication import auth from tests.common import AUTH_HEADER, TestConfig @@ -65,9 +68,10 @@ def open(self, *args, **kwargs): @pytest.fixture(params=(TestConfig,)) -def app(request, bsc_config): - application = create_app() - application.config.from_object(TestConfig) +def app(request, bsc_config, monkeypatch): + with monkeypatch.context() as mp_ctx: + mp_ctx.setattr(bemserver_api.settings, "Config", request.param) + application = bemserver_api.create_app() application.test_client_class = TestClient setup_db() yield application @@ -92,8 +96,12 @@ def users(app, request): ) user.set_password(password) creds = base64.b64encode(f"{email}:{password}".encode()).decode() - invalid_creds = base64.b64encode(f"{email}:bad_pwd".encode()).decode() - ret[name] = {"user": user, "creds": creds, "invalid_creds": invalid_creds} + ret[name] = { + "user": user, + "creds": "Basic " + creds, + "jwt": "Bearer " + + jwt.encode(auth.HEADER, {"email": user.email}, TestConfig.SECRET_KEY), + } db.session.commit() # Set id after commit for user in ret.values(): diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 0fa7743..417ca54 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -1,16 +1,53 @@ """Test authentication extension""" import base64 +import datetime as dt + +import pytest + +from joserfc import jwt +from joserfc.errors import ExpiredTokenError, MissingClaimError from bemserver_core.authorization import get_current_user from bemserver_api import Blueprint +from bemserver_api.extensions.authentication import auth class TestAuthentication: - def test_auth_login_required(self, app, users): + def test_auth_encode_decode(self, app, users): + user_1 = users["Active"]["user"] + + text = auth.encode(user_1) + token = auth.decode(text) + + assert token.header == {"typ": "JWT", "alg": "HS256"} + assert token.claims["email"] == "active@test.com" + assert "exp" in token.claims + auth.validate_token(token) + + def test_auth_decode_error(self, app): + with pytest.raises(ValueError): + auth.decode("dummy") + + def test_auth_validation_error(self, app): + text = jwt.encode(auth.HEADER, {"email": "test@test.com", "exp": 0}, auth.key) + token = auth.decode(text) + with pytest.raises(ExpiredTokenError): + auth.validate_token(token) + + text = jwt.encode( + auth.HEADER, {"exp": dt.datetime.now(tz=dt.timezone.utc)}, auth.key + ) + token = auth.decode(text) + with pytest.raises(MissingClaimError): + auth.validate_token(token) + + def test_auth_login_required_http_basic_auth(self, app, users): active_user_creds = users["Active"]["creds"] - active_user_invalid_creds = users["Active"]["invalid_creds"] + active_user_invalid_creds = base64.b64encode( + f'{users["Active"]["user"].email}:bad_pwd'.encode() + ).decode() inactive_user_creds = users["Inactive"]["creds"] api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] blp = Blueprint("AuthTest", __name__, url_prefix="/auth_test") @@ -18,7 +55,7 @@ def test_auth_login_required(self, app, users): @blp.route("/auth") @blp.login_required @blp.response(200) - def auth(): + def auth_func(): return get_current_user().name @blp.route("/no_auth") @@ -50,21 +87,100 @@ def no_auth(): assert resp.status_code == 204 # Inactive user - headers = {"Authorization": "Basic " + inactive_user_creds} + headers = {"Authorization": inactive_user_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 403 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Active user with invalid creds (bad password) - headers = {"Authorization": "Basic " + active_user_invalid_creds} + headers = {"Authorization": active_user_invalid_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Active user + headers = {"Authorization": active_user_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 200 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Check OpenAPI spec + spec = api.spec.to_dict() + assert spec["components"]["securitySchemes"]["BasicAuthentication"] == { + "type": "http", + "scheme": "basic", + } + auth_spec = spec["paths"]["/auth_test/auth"] + assert auth_spec["get"]["responses"]["401"] == { + "$ref": "#/components/responses/UNAUTHORIZED" + } + assert auth_spec["get"]["security"] == [{"BasicAuthentication": []}] + no_auth_spec = spec["paths"]["/auth_test/no_auth"] + assert "401" not in no_auth_spec["get"]["responses"] + assert "security" not in no_auth_spec["get"] + + def test_auth_login_required_jwt(self, app, users): + active_user_jwt = users["Active"]["jwt"] + active_user_invalid_jwt = jwt.encode( + auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy" + ) + inactive_user_jwt = users["Inactive"]["jwt"] + api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] + blp = Blueprint("AuthTest", __name__, url_prefix="/auth_test") + + @blp.route("/auth") + @blp.login_required + @blp.response(200) + def auth_func(): + return get_current_user().name + + @blp.route("/no_auth") + @blp.response(204) + def no_auth(): + return None + + api.register_blueprint(blp) + client = app.test_client() + + # Anonymous user + headers = {} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Broken auth headers + headers = {"Authorization": "Bearer Dummy"} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + creds = jwt.encode(auth.HEADER, {}, app.config["SECRET_KEY"]) + headers = {"Authorization": "Bearer " + creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Inactive user + headers = {"Authorization": inactive_user_jwt} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 403 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Active user with invalid jwt (bad password) + headers = {"Authorization": active_user_invalid_jwt} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Active user - headers = {"Authorization": "Basic " + active_user_creds} + headers = {"Authorization": active_user_jwt} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 200 resp = client.get("/auth_test/no_auth", headers=headers) From 2f9cfd037b06d1aa0a1b63765f45317d9ccf7d3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Fri, 19 Apr 2024 16:09:25 +0200 Subject: [PATCH 03/16] Use JWT in route tests --- tests/conftest.py | 6 +-- tests/extensions/test_authentication.py | 59 +++++++++++++++++++------ 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6d01ee0..6cb2017 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -95,11 +95,11 @@ def users(app, request): name=name, email=email, is_admin=is_admin, is_active=is_active ) user.set_password(password) - creds = base64.b64encode(f"{email}:{password}".encode()).decode() ret[name] = { "user": user, - "creds": "Basic " + creds, - "jwt": "Bearer " + "hba_creds": "Basic " + + base64.b64encode(f"{email}:{password}".encode()).decode(), + "creds": "Bearer " + jwt.encode(auth.HEADER, {"email": user.email}, TestConfig.SECRET_KEY), } db.session.commit() diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 417ca54..b9589cf 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -7,6 +7,7 @@ from joserfc import jwt from joserfc.errors import ExpiredTokenError, MissingClaimError +from tests.common import TestConfig from bemserver_core.authorization import get_current_user @@ -14,6 +15,18 @@ from bemserver_api.extensions.authentication import auth +class HBATestConfig(TestConfig): + AUTH_METHODS = [ + "Basic", + ] + + +class JWTTestConfig(TestConfig): + AUTH_METHODS = [ + "Bearer", + ] + + class TestAuthentication: def test_auth_encode_decode(self, app, users): user_1 = users["Active"]["user"] @@ -43,12 +56,14 @@ def test_auth_validation_error(self, app): with pytest.raises(MissingClaimError): auth.validate_token(token) + @pytest.mark.parametrize("app", (HBATestConfig,), indirect=True) def test_auth_login_required_http_basic_auth(self, app, users): - active_user_creds = users["Active"]["creds"] - active_user_invalid_creds = base64.b64encode( + active_user_hba_creds = users["Active"]["hba_creds"] + active_user_invalid_hba_creds = base64.b64encode( f'{users["Active"]["user"].email}:bad_pwd'.encode() ).decode() - inactive_user_creds = users["Inactive"]["creds"] + inactive_user_hba_creds = users["Inactive"]["hba_creds"] + active_user_jwt = users["Active"]["creds"] api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] blp = Blueprint("AuthTest", __name__, url_prefix="/auth_test") @@ -79,29 +94,36 @@ def no_auth(): assert resp.status_code == 401 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 - creds = base64.b64encode(b"Dummy").decode() - headers = {"Authorization": "Basic " + creds} + hba_creds = base64.b64encode(b"Dummy").decode() + headers = {"Authorization": "Basic " + hba_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Wrong scheme + headers = {"Authorization": active_user_jwt} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Inactive user - headers = {"Authorization": inactive_user_creds} + headers = {"Authorization": inactive_user_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 403 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Active user with invalid creds (bad password) - headers = {"Authorization": active_user_invalid_creds} + headers = {"Authorization": active_user_invalid_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Active user - headers = {"Authorization": active_user_creds} + headers = {"Authorization": active_user_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 200 resp = client.get("/auth_test/no_auth", headers=headers) @@ -122,12 +144,14 @@ def no_auth(): assert "401" not in no_auth_spec["get"]["responses"] assert "security" not in no_auth_spec["get"] + @pytest.mark.parametrize("app", (JWTTestConfig,), indirect=True) def test_auth_login_required_jwt(self, app, users): - active_user_jwt = users["Active"]["jwt"] - active_user_invalid_jwt = jwt.encode( + active_user_jwt_creds = users["Active"]["creds"] + active_user_invalid_jwt_creds = jwt.encode( auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy" ) - inactive_user_jwt = users["Inactive"]["jwt"] + inactive_user_jwt_creds = users["Inactive"]["creds"] + active_user_hba_creds = users["Active"]["hba_creds"] api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] blp = Blueprint("AuthTest", __name__, url_prefix="/auth_test") @@ -165,22 +189,29 @@ def no_auth(): resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 + # Wrong scheme + headers = {"Authorization": active_user_hba_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + # Inactive user - headers = {"Authorization": inactive_user_jwt} + headers = {"Authorization": inactive_user_jwt_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 403 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Active user with invalid jwt (bad password) - headers = {"Authorization": active_user_invalid_jwt} + headers = {"Authorization": active_user_invalid_jwt_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Active user - headers = {"Authorization": active_user_jwt} + headers = {"Authorization": active_user_jwt_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 200 resp = client.get("/auth_test/no_auth", headers=headers) From 25b60e82492fba726de8859c64be7bfebc200059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Fri, 19 Apr 2024 16:21:20 +0200 Subject: [PATCH 04/16] conftest.py: compute credentials at import time --- tests/conftest.py | 50 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6cb2017..8674831 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,29 +78,53 @@ def app(request, bsc_config, monkeypatch): db.session.remove() -USERS = ( - ("Chuck", "N0rris", "chuck@test.com", True, True), - ("Active", "@ctive", "active@test.com", False, True), - ("Inactive", "in@ctive", "inactive@test.com", False, False), -) +USERS = { + "Chuck": { + "email": "chuck@test.com", + "password": "N0rris", + "is_admin": True, + "is_active": True, + }, + "Active": { + "email": "active@test.com", + "password": "@ctive", + "is_admin": False, + "is_active": True, + }, + "Inactive": { + "email": "inactive@test.com", + "password": "in@ctive", + "is_admin": False, + "is_active": False, + }, +} + +for user in USERS.values(): + user["hba_creds"] = ( + "Basic " + + base64.b64encode(f'{user["email"]}:{user["password"]}'.encode()).decode() + ) + user["creds"] = "Bearer " + jwt.encode( + auth.HEADER, {"email": user["email"]}, TestConfig.SECRET_KEY + ) @pytest.fixture(params=(USERS,)) def users(app, request): with OpenBar(): ret = {} - for user in request.param: - name, password, email, is_admin, is_active = user + for name, elems in request.param.items(): user = model.User.new( - name=name, email=email, is_admin=is_admin, is_active=is_active + name=name, + email=elems["email"], + is_admin=elems["is_admin"], + is_active=elems["is_active"], ) - user.set_password(password) + user.set_password(elems["password"]) ret[name] = { "user": user, - "hba_creds": "Basic " - + base64.b64encode(f"{email}:{password}".encode()).decode(), - "creds": "Bearer " - + jwt.encode(auth.HEADER, {"email": user.email}, TestConfig.SECRET_KEY), + "hba_creds": elems["hba_creds"], + "creds": elems["creds"], } db.session.commit() # Set id after commit From d771aa89acdad0b8b7dba2af0805f6419bd944aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Fri, 19 Apr 2024 17:07:44 +0200 Subject: [PATCH 05/16] Add Auth.get_user_by_email --- .../extensions/authentication.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index f038b28..2719c34 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -64,34 +64,34 @@ def get_user_jwt(self, creds): try: token = self.decode(creds) except (ValueError, JoseError) as exc: - raise (BEMServerAPIAuthenticationError) from exc + raise BEMServerAPIAuthenticationError from exc try: self.validate_token(token) except JoseError as exc: - raise (BEMServerAPIAuthenticationError) from exc + raise BEMServerAPIAuthenticationError from exc user_email = token.claims["email"] - user = db.session.execute( - sqla.select(User).where(User.email == user_email) - ).scalar() - if user is None: - raise (BEMServerAPIAuthenticationError) - return user + return self.get_user_by_email(user_email) - @staticmethod - def get_user_http_basic_auth(creds): + def get_user_http_basic_auth(self, creds): """Check password and return User instance""" try: enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1) user_email = enc_email.decode() password = enc_password.decode() except (ValueError, TypeError) as exc: - raise (BEMServerAPIAuthenticationError) from exc + raise BEMServerAPIAuthenticationError from exc + user = self.get_user_by_email(user_email) + if not user.check_password(password): + raise BEMServerAPIAuthenticationError + return user + + def get_user_by_email(self, user_email): user = db.session.execute( sqla.select(User).where(User.email == user_email) ).scalar() - if user is None or not user.check_password(password): - raise (BEMServerAPIAuthenticationError) + if user is None: + raise BEMServerAPIAuthenticationError return user def get_user(self): From 7491d5971153366e037448f7d2cb65430fc9acdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 23 Apr 2024 11:37:21 +0200 Subject: [PATCH 06/16] Add get_token route --- src/bemserver_api/extensions/smorest.py | 32 +++++++++++++++++++++++++ tests/extensions/test_smorest.py | 24 +++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 tests/extensions/test_smorest.py diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index 809039d..3fe5eea 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -9,6 +9,9 @@ import marshmallow_sqlalchemy as msa from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow.common import resolve_schema_cls +from flask_smorest import abort + +from bemserver_api.exceptions import BEMServerAPIAuthenticationError from . import integrity_error from .authentication import auth @@ -40,6 +43,7 @@ def __init__(self, app=None, *, spec_kwargs=None): def init_app(self, app, *, spec_kwargs=None): super().init_app(app, spec_kwargs=spec_kwargs) self.register_field(Timezone, "string", "iana-tz") + self.register_blueprint(auth_blp) self.spec.components.security_scheme( "BasicAuthentication", {"type": "http", "scheme": "basic"} ) @@ -132,3 +136,31 @@ class SQLCursorPage(flask_smorest.Page): @property def item_count(self): return self.collection.count() + + +auth_blp = Blueprint( + "Auth", __name__, url_prefix="/auth", description="Authentication operations" +) + + +class GetJWTArgsSchema(Schema): + email = ma.fields.Email(required=True) + password = ma.fields.String(validate=ma.validate.Length(1, 80), required=True) + + +class GetJWTRespSchema(Schema): + token = ma.fields.String() + + +@auth_blp.route("/token", methods=["POST"]) +@auth_blp.arguments(GetJWTArgsSchema) +@auth_blp.response(201, GetJWTRespSchema) +def get_token(creds): + """Get an authentication token""" + try: + user = auth.get_user_by_email(creds["email"]) + except BEMServerAPIAuthenticationError: + abort(401, "Authentication error") + if not user.check_password(creds["password"]): + abort(401, "Authentication error") + return {"token": auth.encode(user)} diff --git a/tests/extensions/test_smorest.py b/tests/extensions/test_smorest.py new file mode 100644 index 0000000..84df560 --- /dev/null +++ b/tests/extensions/test_smorest.py @@ -0,0 +1,24 @@ +"""Test smorest extension""" + + +class TestSmorest: + def test_get_token(self, app, users): + user_1 = users["Active"]["user"] + + client = app.test_client() + payload = {"email": user_1.email, "password": "@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 201 + assert "token" in resp.json + + # Wrong password + client = app.test_client() + payload = {"email": user_1.email, "password": "dummy"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 401 + + # Wrong email + client = app.test_client() + payload = {"email": "dummy@dummy.com", "password": "dummy"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 401 From 604f77b2514a477929767a863d752679f23029c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 23 Apr 2024 15:01:49 +0200 Subject: [PATCH 07/16] Document Bearer auth in OpenAPI spec --- src/bemserver_api/extensions/smorest.py | 22 ++++++++++++++++++---- tests/extensions/test_authentication.py | 19 ++++++++++++------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index 3fe5eea..b1ba82a 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -30,6 +30,18 @@ def resolver(schema): return name +SECURITY_SCHEMES = { + "Bearer": ( + "BearerAuthentication", + {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}, + ), + "Basic": ( + "BasicAuthentication", + {"type": "http", "scheme": "basic"}, + ), +} + + class Api(flask_smorest.Api): """Api class""" @@ -44,9 +56,8 @@ def init_app(self, app, *, spec_kwargs=None): super().init_app(app, spec_kwargs=spec_kwargs) self.register_field(Timezone, "string", "iana-tz") self.register_blueprint(auth_blp) - self.spec.components.security_scheme( - "BasicAuthentication", {"type": "http", "scheme": "basic"} - ) + for scheme in app.config["AUTH_METHODS"]: + self.spec.components.security_scheme(*SECURITY_SCHEMES[scheme]) class Blueprint(flask_smorest.Blueprint): @@ -71,7 +82,10 @@ def decorator(function): def _prepare_auth_doc(doc, doc_info, *, app, **kwargs): if doc_info.get("auth", False): doc.setdefault("responses", {})["401"] = http.HTTPStatus(401).name - doc["security"] = [{"BasicAuthentication": []}] + doc["security"] = [ + {SECURITY_SCHEMES[scheme][0]: []} + for scheme in app.config["AUTH_METHODS"] + ] return doc @staticmethod diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index b9589cf..356a5ef 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -131,9 +131,11 @@ def no_auth(): # Check OpenAPI spec spec = api.spec.to_dict() - assert spec["components"]["securitySchemes"]["BasicAuthentication"] == { - "type": "http", - "scheme": "basic", + assert spec["components"]["securitySchemes"] == { + "BasicAuthentication": { + "type": "http", + "scheme": "basic", + } } auth_spec = spec["paths"]["/auth_test/auth"] assert auth_spec["get"]["responses"]["401"] == { @@ -219,15 +221,18 @@ def no_auth(): # Check OpenAPI spec spec = api.spec.to_dict() - assert spec["components"]["securitySchemes"]["BasicAuthentication"] == { - "type": "http", - "scheme": "basic", + assert spec["components"]["securitySchemes"] == { + "BearerAuthentication": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + } } auth_spec = spec["paths"]["/auth_test/auth"] assert auth_spec["get"]["responses"]["401"] == { "$ref": "#/components/responses/UNAUTHORIZED" } - assert auth_spec["get"]["security"] == [{"BasicAuthentication": []}] + assert auth_spec["get"]["security"] == [{"BearerAuthentication": []}] no_auth_spec = spec["paths"]["/auth_test/no_auth"] assert "401" not in no_auth_spec["get"]["responses"] assert "security" not in no_auth_spec["get"] From 5cb6eb700f5d53d60bc8fed4b18bcf2399a42700 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 23 Apr 2024 15:31:09 +0200 Subject: [PATCH 08/16] Factorize security in OpenAPI spec --- src/bemserver_api/extensions/smorest.py | 15 ++++++++++----- tests/extensions/test_authentication.py | 10 ++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index b1ba82a..a1a0ad3 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -53,6 +53,10 @@ def __init__(self, app=None, *, spec_kwargs=None): super().__init__(app=app, spec_kwargs=spec_kwargs) def init_app(self, app, *, spec_kwargs=None): + spec_kwargs = spec_kwargs or {} + spec_kwargs["security"] = [ + {SECURITY_SCHEMES[scheme][0]: []} for scheme in app.config["AUTH_METHODS"] + ] super().init_app(app, spec_kwargs=spec_kwargs) self.register_field(Timezone, "string", "iana-tz") self.register_blueprint(auth_blp) @@ -82,10 +86,8 @@ def decorator(function): def _prepare_auth_doc(doc, doc_info, *, app, **kwargs): if doc_info.get("auth", False): doc.setdefault("responses", {})["401"] = http.HTTPStatus(401).name - doc["security"] = [ - {SECURITY_SCHEMES[scheme][0]: []} - for scheme in app.config["AUTH_METHODS"] - ] + else: + doc["security"] = [] return doc @staticmethod @@ -153,7 +155,10 @@ def item_count(self): auth_blp = Blueprint( - "Auth", __name__, url_prefix="/auth", description="Authentication operations" + "Authentication", + __name__, + url_prefix="/auth", + description="Authentication operations", ) diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 356a5ef..f1aafbe 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -131,6 +131,7 @@ def no_auth(): # Check OpenAPI spec spec = api.spec.to_dict() + assert spec["security"] == [{"BasicAuthentication": []}] assert spec["components"]["securitySchemes"] == { "BasicAuthentication": { "type": "http", @@ -141,10 +142,10 @@ def no_auth(): assert auth_spec["get"]["responses"]["401"] == { "$ref": "#/components/responses/UNAUTHORIZED" } - assert auth_spec["get"]["security"] == [{"BasicAuthentication": []}] + assert "security" not in auth_spec["get"] no_auth_spec = spec["paths"]["/auth_test/no_auth"] assert "401" not in no_auth_spec["get"]["responses"] - assert "security" not in no_auth_spec["get"] + assert no_auth_spec["get"]["security"] == [] @pytest.mark.parametrize("app", (JWTTestConfig,), indirect=True) def test_auth_login_required_jwt(self, app, users): @@ -221,6 +222,7 @@ def no_auth(): # Check OpenAPI spec spec = api.spec.to_dict() + assert spec["security"] == [{"BearerAuthentication": []}] assert spec["components"]["securitySchemes"] == { "BearerAuthentication": { "type": "http", @@ -232,7 +234,7 @@ def no_auth(): assert auth_spec["get"]["responses"]["401"] == { "$ref": "#/components/responses/UNAUTHORIZED" } - assert auth_spec["get"]["security"] == [{"BearerAuthentication": []}] + assert "security" not in auth_spec["get"] no_auth_spec = spec["paths"]["/auth_test/no_auth"] assert "401" not in no_auth_spec["get"]["responses"] - assert "security" not in no_auth_spec["get"] + assert no_auth_spec["get"]["security"] == [] From adb5d3964c50a4a778a4ce3b65b3458d640aadd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 23 Apr 2024 22:57:03 +0200 Subject: [PATCH 09/16] Rework auth error handling --- src/bemserver_api/exceptions.py | 3 + .../extensions/authentication.py | 58 ++++++++++++------- src/bemserver_api/extensions/smorest.py | 38 ++++++++---- tests/extensions/test_authentication.py | 55 +++++++++++++++--- tests/extensions/test_smorest.py | 4 +- 5 files changed, 115 insertions(+), 43 deletions(-) diff --git a/src/bemserver_api/exceptions.py b/src/bemserver_api/exceptions.py index f19e5b1..8b4f6eb 100644 --- a/src/bemserver_api/exceptions.py +++ b/src/bemserver_api/exceptions.py @@ -7,3 +7,6 @@ class BEMServerAPIError(Exception): class BEMServerAPIAuthenticationError(BEMServerAPIError): """AuthenticationError error""" + + def __init__(self, code): + self.code = code diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index 2719c34..f070db0 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -11,7 +11,7 @@ from flask_smorest import abort from joserfc import jwt -from joserfc.errors import JoseError +from joserfc.errors import ExpiredTokenError, JoseError from joserfc.jwk import OctKey from bemserver_core.authorization import BEMServerAuthorizationError, CurrentUser @@ -34,10 +34,13 @@ class Auth: def __init__(self, app=None): self.key = None + self.app = None + self.get_user_funcs = None if app is not None: self.init_app(app) def init_app(self, app): + self.app = app self.key = OctKey.import_key(app.config["SECRET_KEY"]) self.get_user_funcs = { k: getattr(self, v) @@ -60,18 +63,27 @@ def validate_token(self, token): claims_requests = jwt.JWTClaimsRegistry(email={"essential": True}) claims_requests.validate(token.claims) + @staticmethod + def get_user_by_email(user_email): + return db.session.execute( + sqla.select(User).where(User.email == user_email) + ).scalar() + def get_user_jwt(self, creds): try: token = self.decode(creds) except (ValueError, JoseError) as exc: - raise BEMServerAPIAuthenticationError from exc + raise BEMServerAPIAuthenticationError(code="malformed_token") from exc try: self.validate_token(token) + except ExpiredTokenError as exc: + raise BEMServerAPIAuthenticationError(code="expired_token") from exc except JoseError as exc: - raise BEMServerAPIAuthenticationError from exc - + raise BEMServerAPIAuthenticationError(code="invalid_token") from exc user_email = token.claims["email"] - return self.get_user_by_email(user_email) + if (user := self.get_user_by_email(user_email)) is None: + raise BEMServerAPIAuthenticationError(code="invalid_token") + return user def get_user_http_basic_auth(self, creds): """Check password and return User instance""" @@ -80,31 +92,24 @@ def get_user_http_basic_auth(self, creds): user_email = enc_email.decode() password = enc_password.decode() except (ValueError, TypeError) as exc: - raise BEMServerAPIAuthenticationError from exc - user = self.get_user_by_email(user_email) + raise BEMServerAPIAuthenticationError(code="malformed_credentials") from exc + if (user := self.get_user_by_email(user_email)) is None: + raise BEMServerAPIAuthenticationError(code="invalid_credentials") if not user.check_password(password): - raise BEMServerAPIAuthenticationError - return user - - def get_user_by_email(self, user_email): - user = db.session.execute( - sqla.select(User).where(User.email == user_email) - ).scalar() - if user is None: - raise BEMServerAPIAuthenticationError + raise BEMServerAPIAuthenticationError(code="invalid_credentials") return user def get_user(self): if (auth_header := flask.request.headers.get("Authorization")) is None: - raise BEMServerAPIAuthenticationError + raise BEMServerAPIAuthenticationError(code="missing_authentication") try: scheme, creds = auth_header.split(" ", maxsplit=1) - except ValueError as exc: - raise BEMServerAPIAuthenticationError from exc + except ValueError: + abort(400) try: func = self.get_user_funcs[scheme] except KeyError as exc: - raise BEMServerAPIAuthenticationError from exc + raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc return func(creds.encode("utf-8")) def login_required(self, f=None, **kwargs): @@ -120,8 +125,17 @@ def decorator(func): def wrapper(*args, **func_kwargs): try: user = self.get_user() - except BEMServerAPIAuthenticationError: - abort(401, "Authentication error") + except BEMServerAPIAuthenticationError as exc: + abort( + 401, + "Authentication error", + errors={"authentication": exc.code}, + headers={ + "WWW-Authenticate": ", ".join( + self.app.config["AUTH_METHODS"] + ) + }, + ) with CurrentUser(user): try: resp = func(*args, **func_kwargs) diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index a1a0ad3..2b70bee 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -4,14 +4,13 @@ from copy import deepcopy from functools import wraps +import flask + import flask_smorest import marshmallow as ma import marshmallow_sqlalchemy as msa from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow.common import resolve_schema_cls -from flask_smorest import abort - -from bemserver_api.exceptions import BEMServerAPIAuthenticationError from . import integrity_error from .authentication import auth @@ -171,15 +170,34 @@ class GetJWTRespSchema(Schema): token = ma.fields.String() +class GetJWTErrorSchema(Schema): + error = ma.fields.String() + + @auth_blp.route("/token", methods=["POST"]) @auth_blp.arguments(GetJWTArgsSchema) -@auth_blp.response(201, GetJWTRespSchema) +@auth_blp.response( + 201, + GetJWTRespSchema, + example={ + "token": ( + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.u" + "JKHM4XyWv1bC_-rpkjK19GUy0Fgrkm_pGHi8XghjWM" + ) + }, + description="Token created", +) +@auth_blp.alt_response( + # No 401, here. See https://stackoverflow.com/a/67359937 + 200, + schema=GetJWTErrorSchema, + description="Wrong credentials", + example={"error": "Wrong username or password"}, + success=True, +) def get_token(creds): """Get an authentication token""" - try: - user = auth.get_user_by_email(creds["email"]) - except BEMServerAPIAuthenticationError: - abort(401, "Authentication error") - if not user.check_password(creds["password"]): - abort(401, "Authentication error") + user = auth.get_user_by_email(creds["email"]) + if user is None or not user.check_password(creds["password"]): + return flask.jsonify({"error": "Wrong username or password"}) return {"token": auth.encode(user)} diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index f1aafbe..46f6b32 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -35,7 +35,7 @@ def test_auth_encode_decode(self, app, users): token = auth.decode(text) assert token.header == {"typ": "JWT", "alg": "HS256"} - assert token.claims["email"] == "active@test.com" + assert token.claims["email"] == user_1.email assert "exp" in token.claims auth.validate_token(token) @@ -85,6 +85,8 @@ def no_auth(): headers = {} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "missing_authentication" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 @@ -92,12 +94,16 @@ def no_auth(): headers = {"Authorization": "Basic Dummy"} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "malformed_credentials" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 hba_creds = base64.b64encode(b"Dummy").decode() headers = {"Authorization": "Basic " + hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "malformed_credentials" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 @@ -105,6 +111,8 @@ def no_auth(): headers = {"Authorization": active_user_jwt} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "invalid_scheme" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 @@ -115,10 +123,12 @@ def no_auth(): resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 - # Active user with invalid creds (bad password) - headers = {"Authorization": active_user_invalid_hba_creds} + # Active user with invalid creds (wrong password) + headers = {"Authorization": "Basic " + active_user_invalid_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "invalid_credentials" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 @@ -149,6 +159,7 @@ def no_auth(): @pytest.mark.parametrize("app", (JWTTestConfig,), indirect=True) def test_auth_login_required_jwt(self, app, users): + user_1 = users["Active"]["user"] active_user_jwt_creds = users["Active"]["creds"] active_user_invalid_jwt_creds = jwt.encode( auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy" @@ -176,19 +187,27 @@ def no_auth(): headers = {} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "missing_authentication" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 - # Broken auth headers + # Malformed token headers = {"Authorization": "Bearer Dummy"} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "malformed_token" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 + + # Missing claims creds = jwt.encode(auth.HEADER, {}, app.config["SECRET_KEY"]) headers = {"Authorization": "Bearer " + creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_token" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 @@ -196,20 +215,38 @@ def no_auth(): headers = {"Authorization": active_user_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_scheme" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 - # Inactive user - headers = {"Authorization": inactive_user_jwt_creds} + # Bad signature + headers = {"Authorization": "Bearer " + active_user_invalid_jwt_creds} resp = client.get("/auth_test/auth", headers=headers) - assert resp.status_code == 403 + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "malformed_token" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 - # Active user with invalid jwt (bad password) - headers = {"Authorization": active_user_invalid_jwt_creds} + # Expired token + headers = { + "Authorization": "Bearer " + + jwt.encode( + auth.HEADER, {"email": user_1.email, "exp": 0}, app.config["SECRET_KEY"] + ) + } resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "expired_token" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Inactive user + headers = {"Authorization": inactive_user_jwt_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 403 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 diff --git a/tests/extensions/test_smorest.py b/tests/extensions/test_smorest.py index 84df560..8236988 100644 --- a/tests/extensions/test_smorest.py +++ b/tests/extensions/test_smorest.py @@ -15,10 +15,10 @@ def test_get_token(self, app, users): client = app.test_client() payload = {"email": user_1.email, "password": "dummy"} resp = client.post("/auth/token", json=payload) - assert resp.status_code == 401 + assert resp.status_code == 200 # Wrong email client = app.test_client() payload = {"email": "dummy@dummy.com", "password": "dummy"} resp = client.post("/auth/token", json=payload) - assert resp.status_code == 401 + assert resp.status_code == 200 From 848372a74ccc328e1e75b71e6c11713a60dc44fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 23 Apr 2024 23:27:40 +0200 Subject: [PATCH 10/16] Fix test_blp_integrity_error: use conftest app --- tests/extensions/test_integrity_error.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/extensions/test_integrity_error.py b/tests/extensions/test_integrity_error.py index bc01472..f589dcf 100644 --- a/tests/extensions/test_integrity_error.py +++ b/tests/extensions/test_integrity_error.py @@ -5,9 +5,7 @@ import psycopg.errors as ppe import sqlalchemy as sqla -import flask - -from bemserver_api import Api, Blueprint +from bemserver_api import Blueprint class TestIntegrityError: @@ -28,11 +26,8 @@ class TestIntegrityError: ), ), ) - def test_blp_integrity_error(self, error, message): - app = flask.Flask("Test") - api = Api( - app, spec_kwargs={"title": "Test", "version": "1", "openapi_version": "3"} - ) + def test_blp_integrity_error(self, app, error, message): + api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] blp = Blueprint("Test", __name__, url_prefix="/test") From 270d472998c4449fcd7da5598acba9a2229092a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 25 Apr 2024 17:39:08 +0200 Subject: [PATCH 11/16] get_token: always return 200 status code --- src/bemserver_api/extensions/smorest.py | 40 ++++++++++++------------- tests/extensions/test_smorest.py | 5 +++- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index 2b70bee..a106de8 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -167,37 +167,35 @@ class GetJWTArgsSchema(Schema): class GetJWTRespSchema(Schema): + status = ma.fields.String(validate=ma.validate.OneOf(("success", "failure"))) token = ma.fields.String() -class GetJWTErrorSchema(Schema): - error = ma.fields.String() - - @auth_blp.route("/token", methods=["POST"]) @auth_blp.arguments(GetJWTArgsSchema) @auth_blp.response( - 201, + 200, GetJWTRespSchema, - example={ - "token": ( - "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.u" - "JKHM4XyWv1bC_-rpkjK19GUy0Fgrkm_pGHi8XghjWM" - ) + examples={ + "success": { + "value": { + "status": "success", + "token": ( + "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.u" + "JKHM4XyWv1bC_-rpkjK19GUy0Fgrkm_pGHi8XghjWM" + ), + }, + }, + "failure": { + "value": { + "status": "failure", + }, + }, }, - description="Token created", -) -@auth_blp.alt_response( - # No 401, here. See https://stackoverflow.com/a/67359937 - 200, - schema=GetJWTErrorSchema, - description="Wrong credentials", - example={"error": "Wrong username or password"}, - success=True, ) def get_token(creds): """Get an authentication token""" user = auth.get_user_by_email(creds["email"]) if user is None or not user.check_password(creds["password"]): - return flask.jsonify({"error": "Wrong username or password"}) - return {"token": auth.encode(user)} + return flask.jsonify({"status": "failure"}) + return {"status": "success", "token": auth.encode(user)} diff --git a/tests/extensions/test_smorest.py b/tests/extensions/test_smorest.py index 8236988..1878f43 100644 --- a/tests/extensions/test_smorest.py +++ b/tests/extensions/test_smorest.py @@ -8,7 +8,8 @@ def test_get_token(self, app, users): client = app.test_client() payload = {"email": user_1.email, "password": "@ctive"} resp = client.post("/auth/token", json=payload) - assert resp.status_code == 201 + assert resp.status_code == 200 + assert resp.json["status"] == "success" assert "token" in resp.json # Wrong password @@ -16,9 +17,11 @@ def test_get_token(self, app, users): payload = {"email": user_1.email, "password": "dummy"} resp = client.post("/auth/token", json=payload) assert resp.status_code == 200 + assert resp.json == {"status": "failure"} # Wrong email client = app.test_client() payload = {"email": "dummy@dummy.com", "password": "dummy"} resp = client.post("/auth/token", json=payload) assert resp.status_code == 200 + assert resp.json == {"status": "failure"} From bbe735021425e7cac7306bafcac964cafbfe9468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Fri, 17 May 2024 14:28:32 +0200 Subject: [PATCH 12/16] Replace joserfc with authlib.jose --- pyproject.toml | 2 +- requirements/install.txt | 6 +-- .../extensions/authentication.py | 25 +++++------ src/bemserver_api/extensions/smorest.py | 2 +- tests/conftest.py | 11 ++--- tests/extensions/test_authentication.py | 43 +++++++++++-------- 6 files changed, 46 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 67e2d59..8a508f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "marshmallow-sqlalchemy>=0.29.0,<0.30", "flask_smorest>=0.43.0,<0.44", "apispec>=6.1.0,<7.0", - "joserfc>=0.9.0,<0.10", + "authlib>=1.3.0,<2.0", "bemserver-core>=0.17.1,<0.18", ] diff --git a/requirements/install.txt b/requirements/install.txt index c54498f..19be03f 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -18,6 +18,8 @@ argon2-cffi-bindings==21.2.0 # via argon2-cffi async-timeout==4.0.3 # via redis +authlib==1.3.0 + # via bemserver-api (pyproject.toml) bemserver-core==0.17.1 # via bemserver-api (pyproject.toml) billiard==4.2.0 @@ -50,7 +52,7 @@ click-plugins==1.1.1 click-repl==0.3.0 # via celery cryptography==42.0.5 - # via joserfc + # via authlib flask==3.0.2 # via # bemserver-api (pyproject.toml) @@ -65,8 +67,6 @@ itsdangerous==2.1.2 # via flask jinja2==3.1.4 # via flask -joserfc==0.9.0 - # via bemserver-api (pyproject.toml) kombu==5.3.5 # via celery mako==1.3.2 diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index f070db0..ffd4adb 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -10,9 +10,8 @@ from flask_smorest import abort -from joserfc import jwt -from joserfc.errors import ExpiredTokenError, JoseError -from joserfc.jwk import OctKey +from authlib.jose import JsonWebToken +from authlib.jose.errors import ExpiredTokenError, JoseError from bemserver_core.authorization import BEMServerAuthorizationError, CurrentUser from bemserver_core.model.users import User @@ -20,6 +19,9 @@ from bemserver_api.database import db from bemserver_api.exceptions import BEMServerAPIAuthenticationError +# https://docs.authlib.org/en/latest/jose/jwt.html#jwt-with-limited-algorithms +jwt = JsonWebToken(["HS256"]) + class Auth: """Authentication and authorization management""" @@ -41,7 +43,7 @@ def __init__(self, app=None): def init_app(self, app): self.app = app - self.key = OctKey.import_key(app.config["SECRET_KEY"]) + self.key = app.config["SECRET_KEY"] self.get_user_funcs = { k: getattr(self, v) for k, v in self.GET_USER_FUNCS.items() @@ -57,11 +59,7 @@ def encode(self, user): return jwt.encode(self.HEADER, claims, self.key) def decode(self, text): - return jwt.decode(text, self.key) - - def validate_token(self, token): - claims_requests = jwt.JWTClaimsRegistry(email={"essential": True}) - claims_requests.validate(token.claims) + return jwt.decode(text, self.key, claims_options={"email": {"essential": True}}) @staticmethod def get_user_by_email(user_email): @@ -71,16 +69,13 @@ def get_user_by_email(user_email): def get_user_jwt(self, creds): try: - token = self.decode(creds) - except (ValueError, JoseError) as exc: - raise BEMServerAPIAuthenticationError(code="malformed_token") from exc - try: - self.validate_token(token) + claims = self.decode(creds) + claims.validate() except ExpiredTokenError as exc: raise BEMServerAPIAuthenticationError(code="expired_token") from exc except JoseError as exc: raise BEMServerAPIAuthenticationError(code="invalid_token") from exc - user_email = token.claims["email"] + user_email = claims["email"] if (user := self.get_user_by_email(user_email)) is None: raise BEMServerAPIAuthenticationError(code="invalid_token") return user diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index a106de8..5e7bfcc 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -198,4 +198,4 @@ def get_token(creds): user = auth.get_user_by_email(creds["email"]) if user is None or not user.check_password(creds["password"]): return flask.jsonify({"status": "failure"}) - return {"status": "success", "token": auth.encode(user)} + return {"status": "success", "token": auth.encode(user).decode("utf-8")} diff --git a/tests/conftest.py b/tests/conftest.py index 8674831..7c5d140 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,15 +9,13 @@ import flask.testing -from joserfc import jwt - from bemserver_core import common, model, scheduled_tasks from bemserver_core.authorization import OpenBar from bemserver_core.commands import setup_db from bemserver_core.database import db import bemserver_api -from bemserver_api.extensions.authentication import auth +from bemserver_api.extensions.authentication import auth, jwt from tests.common import AUTH_HEADER, TestConfig @@ -104,8 +102,11 @@ def app(request, bsc_config, monkeypatch): "Basic " + base64.b64encode(f'{user["email"]}:{user["password"]}'.encode()).decode() ) - user["creds"] = "Bearer " + jwt.encode( - auth.HEADER, {"email": user["email"]}, TestConfig.SECRET_KEY + user["creds"] = ( + "Bearer " + + jwt.encode( + auth.HEADER, {"email": user["email"]}, TestConfig.SECRET_KEY + ).decode() ) diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 46f6b32..7658a45 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -5,14 +5,18 @@ import pytest -from joserfc import jwt -from joserfc.errors import ExpiredTokenError, MissingClaimError +from authlib.jose.errors import ( + BadSignatureError, + DecodeError, + ExpiredTokenError, + MissingClaimError, +) from tests.common import TestConfig from bemserver_core.authorization import get_current_user from bemserver_api import Blueprint -from bemserver_api.extensions.authentication import auth +from bemserver_api.extensions.authentication import auth, jwt class HBATestConfig(TestConfig): @@ -32,29 +36,32 @@ def test_auth_encode_decode(self, app, users): user_1 = users["Active"]["user"] text = auth.encode(user_1) - token = auth.decode(text) + claims = auth.decode(text) - assert token.header == {"typ": "JWT", "alg": "HS256"} - assert token.claims["email"] == user_1.email - assert "exp" in token.claims - auth.validate_token(token) + assert claims["email"] == user_1.email + assert "exp" in claims + claims.validate() def test_auth_decode_error(self, app): - with pytest.raises(ValueError): + with pytest.raises(DecodeError): auth.decode("dummy") + text = jwt.encode(auth.HEADER, {"email": "test@test.com"}, "Dummy") + with pytest.raises(BadSignatureError): + auth.decode(text) + def test_auth_validation_error(self, app): text = jwt.encode(auth.HEADER, {"email": "test@test.com", "exp": 0}, auth.key) - token = auth.decode(text) + claims = auth.decode(text) with pytest.raises(ExpiredTokenError): - auth.validate_token(token) + claims.validate() text = jwt.encode( auth.HEADER, {"exp": dt.datetime.now(tz=dt.timezone.utc)}, auth.key ) - token = auth.decode(text) + claims = auth.decode(text) with pytest.raises(MissingClaimError): - auth.validate_token(token) + claims.validate() @pytest.mark.parametrize("app", (HBATestConfig,), indirect=True) def test_auth_login_required_http_basic_auth(self, app, users): @@ -163,7 +170,7 @@ def test_auth_login_required_jwt(self, app, users): active_user_jwt_creds = users["Active"]["creds"] active_user_invalid_jwt_creds = jwt.encode( auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy" - ) + ).decode() inactive_user_jwt_creds = users["Inactive"]["creds"] active_user_hba_creds = users["Active"]["hba_creds"] api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] @@ -197,12 +204,12 @@ def no_auth(): resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 assert resp.headers["WWW-Authenticate"] == "Bearer" - assert resp.json["errors"]["authentication"] == "malformed_token" + assert resp.json["errors"]["authentication"] == "invalid_token" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Missing claims - creds = jwt.encode(auth.HEADER, {}, app.config["SECRET_KEY"]) + creds = jwt.encode(auth.HEADER, {}, app.config["SECRET_KEY"]).decode() headers = {"Authorization": "Bearer " + creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 @@ -225,7 +232,7 @@ def no_auth(): resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 assert resp.headers["WWW-Authenticate"] == "Bearer" - assert resp.json["errors"]["authentication"] == "malformed_token" + assert resp.json["errors"]["authentication"] == "invalid_token" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 @@ -234,7 +241,7 @@ def no_auth(): "Authorization": "Bearer " + jwt.encode( auth.HEADER, {"email": user_1.email, "exp": 0}, app.config["SECRET_KEY"] - ) + ).decode() } resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 From 2d8cbc00f585bd69a1bf5e13ed2585ae2a7ce156 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Tue, 21 May 2024 17:11:30 +0200 Subject: [PATCH 13/16] Fix JWT: don't provide token to inactive user --- src/bemserver_api/extensions/smorest.py | 2 +- tests/extensions/test_smorest.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index 5e7bfcc..54d78ae 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -196,6 +196,6 @@ class GetJWTRespSchema(Schema): def get_token(creds): """Get an authentication token""" user = auth.get_user_by_email(creds["email"]) - if user is None or not user.check_password(creds["password"]): + if user is None or not user.check_password(creds["password"]) or not user.is_active: return flask.jsonify({"status": "failure"}) return {"status": "success", "token": auth.encode(user).decode("utf-8")} diff --git a/tests/extensions/test_smorest.py b/tests/extensions/test_smorest.py index 1878f43..993a373 100644 --- a/tests/extensions/test_smorest.py +++ b/tests/extensions/test_smorest.py @@ -4,6 +4,7 @@ class TestSmorest: def test_get_token(self, app, users): user_1 = users["Active"]["user"] + user_2 = users["Inactive"]["user"] client = app.test_client() payload = {"email": user_1.email, "password": "@ctive"} @@ -12,6 +13,13 @@ def test_get_token(self, app, users): assert resp.json["status"] == "success" assert "token" in resp.json + # Inactive user + client = app.test_client() + payload = {"email": user_2.email, "password": "in@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json == {"status": "failure"} + # Wrong password client = app.test_client() payload = {"email": user_1.email, "password": "dummy"} From 5281c6eeb845014defb9b9d503f2b2e62028a606 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 22 May 2024 11:17:00 +0200 Subject: [PATCH 14/16] Add refresh token feature Also don't provide token routes when JWT auth is not used --- .../extensions/authentication.py | 43 ++++++++--- src/bemserver_api/extensions/smorest.py | 65 ++++++++++++++-- tests/common.py | 10 +++ tests/conftest.py | 10 +-- tests/extensions/test_authentication.py | 69 +++++++++++++++-- tests/extensions/test_smorest.py | 77 ++++++++++++++++++- 6 files changed, 240 insertions(+), 34 deletions(-) diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index ffd4adb..299581c 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -2,6 +2,7 @@ import base64 import datetime as dt +from datetime import datetime from functools import wraps import sqlalchemy as sqla @@ -27,7 +28,8 @@ class Auth: """Authentication and authorization management""" HEADER = {"alg": "HS256"} - TOKEN_LIFETIME = 900 + ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes + REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months GET_USER_FUNCS = { "Bearer": "get_user_jwt", @@ -50,16 +52,31 @@ def init_app(self, app): if k in app.config["AUTH_METHODS"] } - def encode(self, user): + def encode(self, user, token_type="access"): + token_lifetime = ( + self.ACCESS_TOKEN_LIFETIME + if token_type == "access" + else self.REFRESH_TOKEN_LIFETIME + ) claims = { "email": user.email, - "exp": dt.datetime.now(tz=dt.timezone.utc) - + dt.timedelta(seconds=self.TOKEN_LIFETIME), + # datetime is imported in module namespace to allow test mock + # kinda sucks, but oh well... + "exp": datetime.now(tz=dt.timezone.utc) + + dt.timedelta(seconds=token_lifetime), + "type": token_type, } - return jwt.encode(self.HEADER, claims, self.key) + return jwt.encode(self.HEADER.copy(), claims, self.key) def decode(self, text): - return jwt.decode(text, self.key, claims_options={"email": {"essential": True}}) + return jwt.decode( + text, + self.key, + claims_options={ + "email": {"essential": True}, + "type": {"essential": True}, + }, + ) @staticmethod def get_user_by_email(user_email): @@ -67,7 +84,7 @@ def get_user_by_email(user_email): sqla.select(User).where(User.email == user_email) ).scalar() - def get_user_jwt(self, creds): + def get_user_jwt(self, creds, refresh=False): try: claims = self.decode(creds) claims.validate() @@ -75,12 +92,14 @@ def get_user_jwt(self, creds): raise BEMServerAPIAuthenticationError(code="expired_token") from exc except JoseError as exc: raise BEMServerAPIAuthenticationError(code="invalid_token") from exc + if refresh is not (claims["type"] == "refresh"): + raise BEMServerAPIAuthenticationError(code="invalid_token") user_email = claims["email"] if (user := self.get_user_by_email(user_email)) is None: raise BEMServerAPIAuthenticationError(code="invalid_token") return user - def get_user_http_basic_auth(self, creds): + def get_user_http_basic_auth(self, creds, **_kwargs): """Check password and return User instance""" try: enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1) @@ -94,7 +113,7 @@ def get_user_http_basic_auth(self, creds): raise BEMServerAPIAuthenticationError(code="invalid_credentials") return user - def get_user(self): + def get_user(self, refresh=False): if (auth_header := flask.request.headers.get("Authorization")) is None: raise BEMServerAPIAuthenticationError(code="missing_authentication") try: @@ -105,9 +124,9 @@ def get_user(self): func = self.get_user_funcs[scheme] except KeyError as exc: raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc - return func(creds.encode("utf-8")) + return func(creds.encode("utf-8"), refresh=refresh) - def login_required(self, f=None, **kwargs): + def login_required(self, f=None, refresh=False): """Decorator providing authentication and authorization Uses JWT or HTTPBasicAuth.login_required to authenticate user @@ -119,7 +138,7 @@ def decorator(func): @wraps(func) def wrapper(*args, **func_kwargs): try: - user = self.get_user() + user = self.get_user(refresh=refresh) except BEMServerAPIAuthenticationError as exc: abort( 401, diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index 54d78ae..ee03242 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -12,6 +12,8 @@ from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow.common import resolve_schema_cls +from bemserver_core.authorization import get_current_user + from . import integrity_error from .authentication import auth from .ma_fields import Timezone @@ -58,7 +60,8 @@ def init_app(self, app, *, spec_kwargs=None): ] super().init_app(app, spec_kwargs=spec_kwargs) self.register_field(Timezone, "string", "iana-tz") - self.register_blueprint(auth_blp) + if "Bearer" in app.config["AUTH_METHODS"]: + self.register_blueprint(auth_blp) for scheme in app.config["AUTH_METHODS"]: self.spec.components.security_scheme(*SECURITY_SCHEMES[scheme]) @@ -168,7 +171,8 @@ class GetJWTArgsSchema(Schema): class GetJWTRespSchema(Schema): status = ma.fields.String(validate=ma.validate.OneOf(("success", "failure"))) - token = ma.fields.String() + access_token = ma.fields.String() + refresh_token = ma.fields.String() @auth_blp.route("/token", methods=["POST"]) @@ -180,9 +184,17 @@ class GetJWTRespSchema(Schema): "success": { "value": { "status": "success", - "token": ( - "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.u" - "JKHM4XyWv1bC_-rpkjK19GUy0Fgrkm_pGHi8XghjWM" + "access_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50" + "7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4" + ), + "refresh_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc" + "SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ" ), }, }, @@ -194,8 +206,47 @@ class GetJWTRespSchema(Schema): }, ) def get_token(creds): - """Get an authentication token""" + """Get access and refresh tokens""" user = auth.get_user_by_email(creds["email"]) if user is None or not user.check_password(creds["password"]) or not user.is_active: return flask.jsonify({"status": "failure"}) - return {"status": "success", "token": auth.encode(user).decode("utf-8")} + return { + "status": "success", + "access_token": auth.encode(user).decode("utf-8"), + "refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"), + } + + +@auth_blp.route("/token/refresh", methods=["POST"]) +@auth_blp.login_required(refresh=True) +@auth_blp.response( + 200, + GetJWTRespSchema, + examples={ + "success": { + "value": { + "status": "success", + "access_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50" + "7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4" + ), + "refresh_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc" + "SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ" + ), + }, + }, + }, +) +def refresh_token(): + """Refresh access and refresh tokens""" + user = get_current_user() + return { + "status": "success", + "access_token": auth.encode(user).decode("utf-8"), + "refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"), + } diff --git a/tests/common.py b/tests/common.py index 0a664e4..241dc30 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,6 +1,7 @@ from contextlib import AbstractContextManager from contextvars import ContextVar +from bemserver_api.extensions.authentication import auth, jwt from bemserver_api.settings import Config @@ -25,3 +26,12 @@ def __enter__(self): def __exit__(self, *args, **kwargs): AUTH_HEADER.reset(self.token) + + +def make_token(user_email, token_type): + # Make an access token with no expiration + return jwt.encode( + auth.HEADER.copy(), + {"email": user_email, "type": token_type}, + TestConfig.SECRET_KEY, + ).decode() diff --git a/tests/conftest.py b/tests/conftest.py index 7c5d140..42dd8de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,8 +15,7 @@ from bemserver_core.database import db import bemserver_api -from bemserver_api.extensions.authentication import auth, jwt -from tests.common import AUTH_HEADER, TestConfig +from tests.common import AUTH_HEADER, TestConfig, make_token @pytest.fixture(scope="session", autouse=True) @@ -102,12 +101,7 @@ def app(request, bsc_config, monkeypatch): "Basic " + base64.b64encode(f'{user["email"]}:{user["password"]}'.encode()).decode() ) - user["creds"] = ( - "Bearer " - + jwt.encode( - auth.HEADER, {"email": user["email"]}, TestConfig.SECRET_KEY - ).decode() - ) + user["creds"] = "Bearer " + make_token(user["email"], "access") @pytest.fixture(params=(USERS,)) diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 7658a45..459c559 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -2,6 +2,7 @@ import base64 import datetime as dt +from unittest import mock import pytest @@ -32,32 +33,86 @@ class JWTTestConfig(TestConfig): class TestAuthentication: - def test_auth_encode_decode(self, app, users): + @mock.patch("bemserver_api.extensions.authentication.datetime") + @mock.patch("bemserver_api.extensions.authentication.jwt.encode") + def test_auth_encode(self, mock_encode, mock_dt, app, users): + dt_now = dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc) + mock_dt.now.return_value = dt_now + + user_1 = users["Active"]["user"] + + auth.encode(user_1) + mock_encode.assert_called() + call_1 = mock_encode.call_args[0] + assert call_1[0] == {"alg": "HS256"} + assert call_1[1] == { + "email": "active@test.com", + "exp": dt_now + dt.timedelta(seconds=60 * 15), + "type": "access", + } + assert call_1[2] == "Test secret" + + auth.encode(user_1, token_type="refresh") + mock_encode.assert_called() + call_1 = mock_encode.call_args[0] + assert call_1[0] == {"alg": "HS256"} + assert call_1[1] == { + "email": "active@test.com", + "exp": dt_now + dt.timedelta(seconds=60 * 60 * 24 * 60), + "type": "refresh", + } + assert call_1[2] == "Test secret" + + def test_auth_decode(self, app, users): user_1 = users["Active"]["user"] text = auth.encode(user_1) claims = auth.decode(text) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + text = auth.encode(user_1, token_type="refresh") + claims = auth.decode(text) assert claims["email"] == user_1.email assert "exp" in claims + assert claims["type"] == "refresh" claims.validate() def test_auth_decode_error(self, app): with pytest.raises(DecodeError): auth.decode("dummy") - text = jwt.encode(auth.HEADER, {"email": "test@test.com"}, "Dummy") + text = jwt.encode( + auth.HEADER, {"email": "test@test.com", "type": "access"}, "Dummy" + ) with pytest.raises(BadSignatureError): auth.decode(text) def test_auth_validation_error(self, app): - text = jwt.encode(auth.HEADER, {"email": "test@test.com", "exp": 0}, auth.key) + text = jwt.encode( + auth.HEADER, + {"email": "test@test.com", "type": "access", "exp": 0}, + auth.key, + ) claims = auth.decode(text) with pytest.raises(ExpiredTokenError): claims.validate() text = jwt.encode( - auth.HEADER, {"exp": dt.datetime.now(tz=dt.timezone.utc)}, auth.key + auth.HEADER, + {"exp": dt.datetime.now(tz=dt.timezone.utc), "type": "access"}, + auth.key, + ) + claims = auth.decode(text) + with pytest.raises(MissingClaimError): + claims.validate() + + text = jwt.encode( + auth.HEADER, + {"exp": dt.datetime.now(tz=dt.timezone.utc), "email": "test@test.com"}, + auth.key, ) claims = auth.decode(text) with pytest.raises(MissingClaimError): @@ -169,7 +224,7 @@ def test_auth_login_required_jwt(self, app, users): user_1 = users["Active"]["user"] active_user_jwt_creds = users["Active"]["creds"] active_user_invalid_jwt_creds = jwt.encode( - auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy" + auth.HEADER, {"email": "dummy@dummy.com", "type": "access"}, "Dummy" ).decode() inactive_user_jwt_creds = users["Inactive"]["creds"] active_user_hba_creds = users["Active"]["hba_creds"] @@ -240,7 +295,9 @@ def no_auth(): headers = { "Authorization": "Bearer " + jwt.encode( - auth.HEADER, {"email": user_1.email, "exp": 0}, app.config["SECRET_KEY"] + auth.HEADER, + {"email": user_1.email, "exp": 0, "type": "access"}, + app.config["SECRET_KEY"], ).decode() } resp = client.get("/auth_test/auth", headers=headers) diff --git a/tests/extensions/test_smorest.py b/tests/extensions/test_smorest.py index 993a373..f36552a 100644 --- a/tests/extensions/test_smorest.py +++ b/tests/extensions/test_smorest.py @@ -1,5 +1,17 @@ """Test smorest extension""" +import pytest + +from tests.common import AuthHeader, TestConfig, make_token + +from bemserver_api.extensions.authentication import auth + + +class HBATestConfig(TestConfig): + AUTH_METHODS = [ + "Basic", + ] + class TestSmorest: def test_get_token(self, app, users): @@ -11,7 +23,23 @@ def test_get_token(self, app, users): resp = client.post("/auth/token", json=payload) assert resp.status_code == 200 assert resp.json["status"] == "success" - assert "token" in resp.json + claims = auth.decode(resp.json["access_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + claims = auth.decode(resp.json["refresh_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "refresh" + claims.validate() + + # Inactive user + client = app.test_client() + payload = {"email": user_2.email, "password": "in@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json == {"status": "failure"} # Inactive user client = app.test_client() @@ -33,3 +61,50 @@ def test_get_token(self, app, users): resp = client.post("/auth/token", json=payload) assert resp.status_code == 200 assert resp.json == {"status": "failure"} + + def test_refresh_token(self, app, users): + user_1 = users["Active"]["user"] + + access_token = "Bearer " + make_token(user_1.email, "access") + refresh_token = "Bearer " + make_token(user_1.email, "refresh") + + client = app.test_client() + + # No token + resp = client.post("/auth/token/refresh") + assert resp.status_code == 401 + + # Acccess token + with AuthHeader(access_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 401 + + # Refresh token + with AuthHeader(refresh_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 200 + assert resp.json["status"] == "success" + claims = auth.decode(resp.json["access_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + claims = auth.decode(resp.json["refresh_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "refresh" + claims.validate() + + @pytest.mark.parametrize("app", (HBATestConfig,), indirect=True) + def test_token_routes_jwt_disabled(self, app, users): + user_1 = users["Active"]["user"] + + client = app.test_client() + payload = {"email": user_1.email, "password": "@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 404 + + access_token = "Bearer " + make_token(user_1.email, "access") + with AuthHeader(access_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 404 From 15a85ad2695dad01c0caaea4e480f0201eebd120 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 5 Jun 2024 16:40:33 +0200 Subject: [PATCH 15/16] Improve JWT auth documentation --- src/bemserver_api/extensions/smorest.py | 29 ++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index ee03242..7707a64 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -3,6 +3,7 @@ import http from copy import deepcopy from functools import wraps +from textwrap import dedent import flask @@ -156,11 +157,21 @@ def item_count(self): return self.collection.count() +AUTH_BLP_DESC = dedent("""Authentication operations + +The following resources are used to get and refresh tokens. When authenticating, first +get a couple of access (short-lived) and refresh (long-lived) tokens using login and +password. When or before access token expires, refresh tokens to get a new pair of +tokens with new expiration dates. If refresh token is expired, get a new pair of tokens +using login and password again. +""") + + auth_blp = Blueprint( "Authentication", __name__, url_prefix="/auth", - description="Authentication operations", + description=AUTH_BLP_DESC, ) @@ -206,7 +217,12 @@ class GetJWTRespSchema(Schema): }, ) def get_token(creds): - """Get access and refresh tokens""" + """Get access and refresh tokens + + Use login and password to get a pair of access and refresh tokens. + + No authentication header needed. Credentials must be passed in request payload. + """ user = auth.get_user_by_email(creds["email"]) if user is None or not user.check_password(creds["password"]) or not user.is_active: return flask.jsonify({"status": "failure"}) @@ -243,7 +259,14 @@ def get_token(creds): }, ) def refresh_token(): - """Refresh access and refresh tokens""" + """Refresh access and refresh tokens + + When access token is expired, call this resource using the refresh token to get a + new pair of tokens. + + As opposed to all other resources, this resource must be accessed using the refresh + token, not the access token. + """ user = get_current_user() return { "status": "success", From 9891cf1ee1da0143053c399d280e2aa0776b40db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 5 Jun 2024 17:27:09 +0200 Subject: [PATCH 16/16] Improve JWT test coverage --- tests/extensions/test_authentication.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 459c559..f9d73e8 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -273,6 +273,27 @@ def no_auth(): resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 + # Unknown user + creds = jwt.encode( + auth.HEADER, + {"email": "dummy@dummy.com", "type": "access"}, + app.config["SECRET_KEY"], + ).decode() + headers = {"Authorization": "Bearer " + creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_token" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Broken header + headers = {"Authorization": "dummy"} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 400 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + # Wrong scheme headers = {"Authorization": active_user_hba_creds} resp = client.get("/auth_test/auth", headers=headers)