diff --git a/pyproject.toml b/pyproject.toml index 3a2937b..8a508f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "marshmallow-sqlalchemy>=0.29.0,<0.30", "flask_smorest>=0.43.0,<0.44", "apispec>=6.1.0,<7.0", - "flask-httpauth>=4.7.0,<5.0", + "authlib>=1.3.0,<2.0", "bemserver-core>=0.17.1,<0.18", ] @@ -77,7 +77,7 @@ section-order = ["future", "standard-library", "testing", "db", "pallets", "mars [tool.ruff.lint.isort.sections] testing = ["pytest", "pytest_postgresql"] db = ["psycopg", "sqlalchemy", "alembic"] -pallets = ["werkzeug", "flask", "flask_httpauth"] +pallets = ["werkzeug", "flask"] marshmallow = ["marshmallow", "marshmallow_sqlalchemy", "webargs", "apispec", "flask_smorest"] science = ["numpy", "pandas"] core = ["bemserver_core"] diff --git a/requirements/install.txt b/requirements/install.txt index 86dca56..19be03f 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -18,6 +18,8 @@ argon2-cffi-bindings==21.2.0 # via argon2-cffi async-timeout==4.0.3 # via redis +authlib==1.3.0 + # via bemserver-api (pyproject.toml) bemserver-core==0.17.1 # via bemserver-api (pyproject.toml) billiard==4.2.0 @@ -31,6 +33,7 @@ certifi==2024.2.2 cffi==1.16.0 # via # argon2-cffi-bindings + # cryptography # oso charset-normalizer==3.3.2 # via requests @@ -48,13 +51,12 @@ click-plugins==1.1.1 # via celery click-repl==0.3.0 # via celery +cryptography==42.0.5 + # via authlib flask==3.0.2 # via # bemserver-api (pyproject.toml) - # flask-httpauth # flask-smorest -flask-httpauth==4.8.0 - # via bemserver-api (pyproject.toml) flask-smorest==0.43.0 # via bemserver-api (pyproject.toml) greenlet==3.0.3 diff --git a/src/bemserver_api/__init__.py b/src/bemserver_api/__init__.py index a3a904b..75ed76b 100644 --- a/src/bemserver_api/__init__.py +++ b/src/bemserver_api/__init__.py @@ -36,6 +36,7 @@ def create_app(): } ) api.init_app(app) + authentication.auth.init_app(app) register_blueprints(api) BEMServerCore() diff --git a/src/bemserver_api/exceptions.py b/src/bemserver_api/exceptions.py new file mode 100644 index 0000000..8b4f6eb --- /dev/null +++ b/src/bemserver_api/exceptions.py @@ -0,0 +1,12 @@ +"""Exceptions""" + + +class BEMServerAPIError(Exception): + """Base BEMServer API exception""" + + +class BEMServerAPIAuthenticationError(BEMServerAPIError): + """AuthenticationError error""" + + def __init__(self, code): + self.code = code diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index bd59eea..299581c 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -1,26 +1,135 @@ """Authentication""" +import base64 +import datetime as dt +from datetime import datetime from functools import wraps import sqlalchemy as sqla -from flask_httpauth import HTTPBasicAuth +import flask from flask_smorest import abort +from authlib.jose import JsonWebToken +from authlib.jose.errors import ExpiredTokenError, JoseError + from bemserver_core.authorization import BEMServerAuthorizationError, CurrentUser from bemserver_core.model.users import User from bemserver_api.database import db +from bemserver_api.exceptions import BEMServerAPIAuthenticationError + +# https://docs.authlib.org/en/latest/jose/jwt.html#jwt-with-limited-algorithms +jwt = JsonWebToken(["HS256"]) -class Auth(HTTPBasicAuth): +class Auth: """Authentication and authorization management""" - def login_required(self, f=None, **kwargs): + HEADER = {"alg": "HS256"} + ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes + REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months + + GET_USER_FUNCS = { + "Bearer": "get_user_jwt", + "Basic": "get_user_http_basic_auth", + } + + def __init__(self, app=None): + self.key = None + self.app = None + self.get_user_funcs = None + if app is not None: + self.init_app(app) + + def init_app(self, app): + self.app = app + self.key = app.config["SECRET_KEY"] + self.get_user_funcs = { + k: getattr(self, v) + for k, v in self.GET_USER_FUNCS.items() + if k in app.config["AUTH_METHODS"] + } + + def encode(self, user, token_type="access"): + token_lifetime = ( + self.ACCESS_TOKEN_LIFETIME + if token_type == "access" + else self.REFRESH_TOKEN_LIFETIME + ) + claims = { + "email": user.email, + # datetime is imported in module namespace to allow test mock + # kinda sucks, but oh well... + "exp": datetime.now(tz=dt.timezone.utc) + + dt.timedelta(seconds=token_lifetime), + "type": token_type, + } + return jwt.encode(self.HEADER.copy(), claims, self.key) + + def decode(self, text): + return jwt.decode( + text, + self.key, + claims_options={ + "email": {"essential": True}, + "type": {"essential": True}, + }, + ) + + @staticmethod + def get_user_by_email(user_email): + return db.session.execute( + sqla.select(User).where(User.email == user_email) + ).scalar() + + def get_user_jwt(self, creds, refresh=False): + try: + claims = self.decode(creds) + claims.validate() + except ExpiredTokenError as exc: + raise BEMServerAPIAuthenticationError(code="expired_token") from exc + except JoseError as exc: + raise BEMServerAPIAuthenticationError(code="invalid_token") from exc + if refresh is not (claims["type"] == "refresh"): + raise BEMServerAPIAuthenticationError(code="invalid_token") + user_email = claims["email"] + if (user := self.get_user_by_email(user_email)) is None: + raise BEMServerAPIAuthenticationError(code="invalid_token") + return user + + def get_user_http_basic_auth(self, creds, **_kwargs): + """Check password and return User instance""" + try: + enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1) + user_email = enc_email.decode() + password = enc_password.decode() + except (ValueError, TypeError) as exc: + raise BEMServerAPIAuthenticationError(code="malformed_credentials") from exc + if (user := self.get_user_by_email(user_email)) is None: + raise BEMServerAPIAuthenticationError(code="invalid_credentials") + if not user.check_password(password): + raise BEMServerAPIAuthenticationError(code="invalid_credentials") + return user + + def get_user(self, refresh=False): + if (auth_header := flask.request.headers.get("Authorization")) is None: + raise BEMServerAPIAuthenticationError(code="missing_authentication") + try: + scheme, creds = auth_header.split(" ", maxsplit=1) + except ValueError: + abort(400) + try: + func = self.get_user_funcs[scheme] + except KeyError as exc: + raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc + return func(creds.encode("utf-8"), refresh=refresh) + + def login_required(self, f=None, refresh=False): """Decorator providing authentication and authorization - Uses HTTPBasicAuth.login_required authenticate user + Uses JWT or HTTPBasicAuth.login_required to authenticate user Sets CurrentUser context variable to authenticated user for the request Catches Authorization error and aborts accordingly """ @@ -28,16 +137,27 @@ def login_required(self, f=None, **kwargs): def decorator(func): @wraps(func) def wrapper(*args, **func_kwargs): - with CurrentUser(self.current_user()): + try: + user = self.get_user(refresh=refresh) + except BEMServerAPIAuthenticationError as exc: + abort( + 401, + "Authentication error", + errors={"authentication": exc.code}, + headers={ + "WWW-Authenticate": ", ".join( + self.app.config["AUTH_METHODS"] + ) + }, + ) + with CurrentUser(user): try: resp = func(*args, **func_kwargs) except BEMServerAuthorizationError: abort(403, message="Authorization error") return resp - # Wrap this inside HTTPAuth.login_required - # to get authenticated user - return super(Auth, self).login_required(**kwargs)(wrapper) + return wrapper if f: return decorator(f) @@ -45,19 +165,3 @@ def wrapper(*args, **func_kwargs): auth = Auth() - - -@auth.verify_password -def verify_password(username, password): - """Check password and return User instance""" - user = db.session.execute(sqla.select(User).where(User.email == username)).scalar() - if user is not None and user.check_password(password): - return user - return None - - -@auth.error_handler -def auth_error(status): - """Authentication error handler""" - # Call abort to trigger error handler and get consistent JSON output - abort(status, message="Authentication error") diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index d21e563..7707a64 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -3,6 +3,9 @@ import http from copy import deepcopy from functools import wraps +from textwrap import dedent + +import flask import flask_smorest import marshmallow as ma @@ -10,6 +13,8 @@ from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow.common import resolve_schema_cls +from bemserver_core.authorization import get_current_user + from . import integrity_error from .authentication import auth from .ma_fields import Timezone @@ -27,6 +32,18 @@ def resolver(schema): return name +SECURITY_SCHEMES = { + "Bearer": ( + "BearerAuthentication", + {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}, + ), + "Basic": ( + "BasicAuthentication", + {"type": "http", "scheme": "basic"}, + ), +} + + class Api(flask_smorest.Api): """Api class""" @@ -38,11 +55,16 @@ def __init__(self, app=None, *, spec_kwargs=None): super().__init__(app=app, spec_kwargs=spec_kwargs) def init_app(self, app, *, spec_kwargs=None): + spec_kwargs = spec_kwargs or {} + spec_kwargs["security"] = [ + {SECURITY_SCHEMES[scheme][0]: []} for scheme in app.config["AUTH_METHODS"] + ] super().init_app(app, spec_kwargs=spec_kwargs) self.register_field(Timezone, "string", "iana-tz") - self.spec.components.security_scheme( - "BasicAuthentication", {"type": "http", "scheme": "basic"} - ) + if "Bearer" in app.config["AUTH_METHODS"]: + self.register_blueprint(auth_blp) + for scheme in app.config["AUTH_METHODS"]: + self.spec.components.security_scheme(*SECURITY_SCHEMES[scheme]) class Blueprint(flask_smorest.Blueprint): @@ -67,13 +89,10 @@ def decorator(function): def _prepare_auth_doc(doc, doc_info, *, app, **kwargs): if doc_info.get("auth", False): doc.setdefault("responses", {})["401"] = http.HTTPStatus(401).name - doc["security"] = [{"BasicAuthentication": []}] + else: + doc["security"] = [] return doc - @staticmethod - def current_user(): - return auth.current_user() - @staticmethod def catch_integrity_error(func=None): """Catch DB integrity errors""" @@ -136,3 +155,121 @@ class SQLCursorPage(flask_smorest.Page): @property def item_count(self): return self.collection.count() + + +AUTH_BLP_DESC = dedent("""Authentication operations + +The following resources are used to get and refresh tokens. When authenticating, first +get a couple of access (short-lived) and refresh (long-lived) tokens using login and +password. When or before access token expires, refresh tokens to get a new pair of +tokens with new expiration dates. If refresh token is expired, get a new pair of tokens +using login and password again. +""") + + +auth_blp = Blueprint( + "Authentication", + __name__, + url_prefix="/auth", + description=AUTH_BLP_DESC, +) + + +class GetJWTArgsSchema(Schema): + email = ma.fields.Email(required=True) + password = ma.fields.String(validate=ma.validate.Length(1, 80), required=True) + + +class GetJWTRespSchema(Schema): + status = ma.fields.String(validate=ma.validate.OneOf(("success", "failure"))) + access_token = ma.fields.String() + refresh_token = ma.fields.String() + + +@auth_blp.route("/token", methods=["POST"]) +@auth_blp.arguments(GetJWTArgsSchema) +@auth_blp.response( + 200, + GetJWTRespSchema, + examples={ + "success": { + "value": { + "status": "success", + "access_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50" + "7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4" + ), + "refresh_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc" + "SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ" + ), + }, + }, + "failure": { + "value": { + "status": "failure", + }, + }, + }, +) +def get_token(creds): + """Get access and refresh tokens + + Use login and password to get a pair of access and refresh tokens. + + No authentication header needed. Credentials must be passed in request payload. + """ + user = auth.get_user_by_email(creds["email"]) + if user is None or not user.check_password(creds["password"]) or not user.is_active: + return flask.jsonify({"status": "failure"}) + return { + "status": "success", + "access_token": auth.encode(user).decode("utf-8"), + "refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"), + } + + +@auth_blp.route("/token/refresh", methods=["POST"]) +@auth_blp.login_required(refresh=True) +@auth_blp.response( + 200, + GetJWTRespSchema, + examples={ + "success": { + "value": { + "status": "success", + "access_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50" + "7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4" + ), + "refresh_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc" + "SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ" + ), + }, + }, + }, +) +def refresh_token(): + """Refresh access and refresh tokens + + When access token is expired, call this resource using the refresh token to get a + new pair of tokens. + + As opposed to all other resources, this resource must be accessed using the refresh + token, not the access token. + """ + user = get_current_user() + return { + "status": "success", + "access_token": auth.encode(user).decode("utf-8"), + "refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"), + } diff --git a/src/bemserver_api/settings.py b/src/bemserver_api/settings.py index 96a9d3f..1ff5a1d 100644 --- a/src/bemserver_api/settings.py +++ b/src/bemserver_api/settings.py @@ -4,6 +4,12 @@ class Config: """Default configuration""" + # Authentication + SECRET_KEY = "" + AUTH_METHODS = [ + "Bearer", + ] + # API parameters API_TITLE = "BEMServer API" OPENAPI_JSON_PATH = "api-spec.json" diff --git a/tests/common.py b/tests/common.py index 07c83ce..241dc30 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,11 +1,17 @@ from contextlib import AbstractContextManager from contextvars import ContextVar +from bemserver_api.extensions.authentication import auth, jwt from bemserver_api.settings import Config class TestConfig(Config): TESTING = True + SECRET_KEY = "Test secret" + AUTH_METHODS = [ + "Bearer", + "Basic", + ] AUTH_HEADER = ContextVar("auth_header", default=None) @@ -16,7 +22,16 @@ def __init__(self, creds): self.creds = creds def __enter__(self): - self.token = AUTH_HEADER.set("Basic " + self.creds) + self.token = AUTH_HEADER.set(self.creds) def __exit__(self, *args, **kwargs): AUTH_HEADER.reset(self.token) + + +def make_token(user_email, token_type): + # Make an access token with no expiration + return jwt.encode( + auth.HEADER.copy(), + {"email": user_email, "type": token_type}, + TestConfig.SECRET_KEY, + ).decode() diff --git a/tests/conftest.py b/tests/conftest.py index 930d464..42dd8de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,8 +14,8 @@ from bemserver_core.commands import setup_db from bemserver_core.database import db -from bemserver_api import create_app -from tests.common import AUTH_HEADER, TestConfig +import bemserver_api +from tests.common import AUTH_HEADER, TestConfig, make_token @pytest.fixture(scope="session", autouse=True) @@ -65,35 +65,62 @@ def open(self, *args, **kwargs): @pytest.fixture(params=(TestConfig,)) -def app(request, bsc_config): - application = create_app() - application.config.from_object(TestConfig) +def app(request, bsc_config, monkeypatch): + with monkeypatch.context() as mp_ctx: + mp_ctx.setattr(bemserver_api.settings, "Config", request.param) + application = bemserver_api.create_app() application.test_client_class = TestClient setup_db() yield application db.session.remove() -USERS = ( - ("Chuck", "N0rris", "chuck@test.com", True, True), - ("Active", "@ctive", "active@test.com", False, True), - ("Inactive", "in@ctive", "inactive@test.com", False, False), -) +USERS = { + "Chuck": { + "email": "chuck@test.com", + "password": "N0rris", + "is_admin": True, + "is_active": True, + }, + "Active": { + "email": "active@test.com", + "password": "@ctive", + "is_admin": False, + "is_active": True, + }, + "Inactive": { + "email": "inactive@test.com", + "password": "in@ctive", + "is_admin": False, + "is_active": False, + }, +} + +for user in USERS.values(): + user["hba_creds"] = ( + "Basic " + + base64.b64encode(f'{user["email"]}:{user["password"]}'.encode()).decode() + ) + user["creds"] = "Bearer " + make_token(user["email"], "access") @pytest.fixture(params=(USERS,)) def users(app, request): with OpenBar(): ret = {} - for user in request.param: - name, password, email, is_admin, is_active = user + for name, elems in request.param.items(): user = model.User.new( - name=name, email=email, is_admin=is_admin, is_active=is_active + name=name, + email=elems["email"], + is_admin=elems["is_admin"], + is_active=elems["is_active"], ) - user.set_password(password) - creds = base64.b64encode(f"{email}:{password}".encode()).decode() - invalid_creds = base64.b64encode(f"{email}:bad_pwd".encode()).decode() - ret[name] = {"user": user, "creds": creds, "invalid_creds": invalid_creds} + user.set_password(elems["password"]) + ret[name] = { + "user": user, + "hba_creds": elems["hba_creds"], + "creds": elems["creds"], + } db.session.commit() # Set id after commit for user in ret.values(): diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 2c2700c..f9d73e8 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -1,24 +1,138 @@ """Test authentication extension""" -from flask import jsonify +import base64 +import datetime as dt +from unittest import mock + +import pytest + +from authlib.jose.errors import ( + BadSignatureError, + DecodeError, + ExpiredTokenError, + MissingClaimError, +) +from tests.common import TestConfig from bemserver_core.authorization import get_current_user from bemserver_api import Blueprint +from bemserver_api.extensions.authentication import auth, jwt + + +class HBATestConfig(TestConfig): + AUTH_METHODS = [ + "Basic", + ] + + +class JWTTestConfig(TestConfig): + AUTH_METHODS = [ + "Bearer", + ] class TestAuthentication: - def test_auth_login_required(self, app, users): - active_user_creds = users["Active"]["creds"] - active_user_invalid_creds = users["Active"]["invalid_creds"] - inactive_user_creds = users["Inactive"]["creds"] + @mock.patch("bemserver_api.extensions.authentication.datetime") + @mock.patch("bemserver_api.extensions.authentication.jwt.encode") + def test_auth_encode(self, mock_encode, mock_dt, app, users): + dt_now = dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc) + mock_dt.now.return_value = dt_now + + user_1 = users["Active"]["user"] + + auth.encode(user_1) + mock_encode.assert_called() + call_1 = mock_encode.call_args[0] + assert call_1[0] == {"alg": "HS256"} + assert call_1[1] == { + "email": "active@test.com", + "exp": dt_now + dt.timedelta(seconds=60 * 15), + "type": "access", + } + assert call_1[2] == "Test secret" + + auth.encode(user_1, token_type="refresh") + mock_encode.assert_called() + call_1 = mock_encode.call_args[0] + assert call_1[0] == {"alg": "HS256"} + assert call_1[1] == { + "email": "active@test.com", + "exp": dt_now + dt.timedelta(seconds=60 * 60 * 24 * 60), + "type": "refresh", + } + assert call_1[2] == "Test secret" + + def test_auth_decode(self, app, users): + user_1 = users["Active"]["user"] + + text = auth.encode(user_1) + claims = auth.decode(text) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + + text = auth.encode(user_1, token_type="refresh") + claims = auth.decode(text) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "refresh" + claims.validate() + + def test_auth_decode_error(self, app): + with pytest.raises(DecodeError): + auth.decode("dummy") + + text = jwt.encode( + auth.HEADER, {"email": "test@test.com", "type": "access"}, "Dummy" + ) + with pytest.raises(BadSignatureError): + auth.decode(text) + + def test_auth_validation_error(self, app): + text = jwt.encode( + auth.HEADER, + {"email": "test@test.com", "type": "access", "exp": 0}, + auth.key, + ) + claims = auth.decode(text) + with pytest.raises(ExpiredTokenError): + claims.validate() + + text = jwt.encode( + auth.HEADER, + {"exp": dt.datetime.now(tz=dt.timezone.utc), "type": "access"}, + auth.key, + ) + claims = auth.decode(text) + with pytest.raises(MissingClaimError): + claims.validate() + + text = jwt.encode( + auth.HEADER, + {"exp": dt.datetime.now(tz=dt.timezone.utc), "email": "test@test.com"}, + auth.key, + ) + claims = auth.decode(text) + with pytest.raises(MissingClaimError): + claims.validate() + + @pytest.mark.parametrize("app", (HBATestConfig,), indirect=True) + def test_auth_login_required_http_basic_auth(self, app, users): + active_user_hba_creds = users["Active"]["hba_creds"] + active_user_invalid_hba_creds = base64.b64encode( + f'{users["Active"]["user"].email}:bad_pwd'.encode() + ).decode() + inactive_user_hba_creds = users["Inactive"]["hba_creds"] + active_user_jwt = users["Active"]["creds"] api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] blp = Blueprint("AuthTest", __name__, url_prefix="/auth_test") @blp.route("/auth") @blp.login_required @blp.response(200) - def auth(): + def auth_func(): return get_current_user().name @blp.route("/no_auth") @@ -33,25 +147,55 @@ def no_auth(): headers = {} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "missing_authentication" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Broken auth headers + headers = {"Authorization": "Basic Dummy"} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "malformed_credentials" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + hba_creds = base64.b64encode(b"Dummy").decode() + headers = {"Authorization": "Basic " + hba_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "malformed_credentials" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Wrong scheme + headers = {"Authorization": active_user_jwt} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "invalid_scheme" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Inactive user - headers = {"Authorization": "Basic " + inactive_user_creds} + headers = {"Authorization": inactive_user_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 403 resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 - # Active user with invalid creds (bad password) - headers = {"Authorization": "Basic " + active_user_invalid_creds} + # Active user with invalid creds (wrong password) + headers = {"Authorization": "Basic " + active_user_invalid_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Basic" + assert resp.json["errors"]["authentication"] == "invalid_credentials" resp = client.get("/auth_test/no_auth", headers=headers) assert resp.status_code == 204 # Active user - headers = {"Authorization": "Basic " + active_user_creds} + headers = {"Authorization": active_user_hba_creds} resp = client.get("/auth_test/auth", headers=headers) assert resp.status_code == 200 resp = client.get("/auth_test/no_auth", headers=headers) @@ -59,35 +203,160 @@ def no_auth(): # Check OpenAPI spec spec = api.spec.to_dict() - assert spec["components"]["securitySchemes"]["BasicAuthentication"] == { - "type": "http", - "scheme": "basic", + assert spec["security"] == [{"BasicAuthentication": []}] + assert spec["components"]["securitySchemes"] == { + "BasicAuthentication": { + "type": "http", + "scheme": "basic", + } } auth_spec = spec["paths"]["/auth_test/auth"] assert auth_spec["get"]["responses"]["401"] == { "$ref": "#/components/responses/UNAUTHORIZED" } - assert auth_spec["get"]["security"] == [{"BasicAuthentication": []}] + assert "security" not in auth_spec["get"] no_auth_spec = spec["paths"]["/auth_test/no_auth"] assert "401" not in no_auth_spec["get"]["responses"] - assert "security" not in no_auth_spec["get"] + assert no_auth_spec["get"]["security"] == [] - def test_auth_current_user(self, app, users): - active_user_creds = users["Active"]["creds"] + @pytest.mark.parametrize("app", (JWTTestConfig,), indirect=True) + def test_auth_login_required_jwt(self, app, users): + user_1 = users["Active"]["user"] + active_user_jwt_creds = users["Active"]["creds"] + active_user_invalid_jwt_creds = jwt.encode( + auth.HEADER, {"email": "dummy@dummy.com", "type": "access"}, "Dummy" + ).decode() + inactive_user_jwt_creds = users["Inactive"]["creds"] + active_user_hba_creds = users["Active"]["hba_creds"] api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] - blp = Blueprint("AuthTest", __name__, url_prefix="/auth_test") - @blp.route("/user") + @blp.route("/auth") @blp.login_required @blp.response(200) - def user(): - return jsonify(blp.current_user().name) + def auth_func(): + return get_current_user().name + + @blp.route("/no_auth") + @blp.response(204) + def no_auth(): + return None api.register_blueprint(blp) client = app.test_client() + # Anonymous user + headers = {} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "missing_authentication" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Malformed token + headers = {"Authorization": "Bearer Dummy"} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_token" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Missing claims + creds = jwt.encode(auth.HEADER, {}, app.config["SECRET_KEY"]).decode() + headers = {"Authorization": "Bearer " + creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_token" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Unknown user + creds = jwt.encode( + auth.HEADER, + {"email": "dummy@dummy.com", "type": "access"}, + app.config["SECRET_KEY"], + ).decode() + headers = {"Authorization": "Bearer " + creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_token" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Broken header + headers = {"Authorization": "dummy"} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 400 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Wrong scheme + headers = {"Authorization": active_user_hba_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_scheme" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Bad signature + headers = {"Authorization": "Bearer " + active_user_invalid_jwt_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "invalid_token" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Expired token + headers = { + "Authorization": "Bearer " + + jwt.encode( + auth.HEADER, + {"email": user_1.email, "exp": 0, "type": "access"}, + app.config["SECRET_KEY"], + ).decode() + } + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 401 + assert resp.headers["WWW-Authenticate"] == "Bearer" + assert resp.json["errors"]["authentication"] == "expired_token" + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Inactive user + headers = {"Authorization": inactive_user_jwt_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 403 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + # Active user - headers = {"Authorization": "Basic " + active_user_creds} - resp = client.get("/auth_test/user", headers=headers) - assert resp.json == "Active" + headers = {"Authorization": active_user_jwt_creds} + resp = client.get("/auth_test/auth", headers=headers) + assert resp.status_code == 200 + resp = client.get("/auth_test/no_auth", headers=headers) + assert resp.status_code == 204 + + # Check OpenAPI spec + spec = api.spec.to_dict() + assert spec["security"] == [{"BearerAuthentication": []}] + assert spec["components"]["securitySchemes"] == { + "BearerAuthentication": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + } + } + auth_spec = spec["paths"]["/auth_test/auth"] + assert auth_spec["get"]["responses"]["401"] == { + "$ref": "#/components/responses/UNAUTHORIZED" + } + assert "security" not in auth_spec["get"] + no_auth_spec = spec["paths"]["/auth_test/no_auth"] + assert "401" not in no_auth_spec["get"]["responses"] + assert no_auth_spec["get"]["security"] == [] diff --git a/tests/extensions/test_integrity_error.py b/tests/extensions/test_integrity_error.py index bc01472..f589dcf 100644 --- a/tests/extensions/test_integrity_error.py +++ b/tests/extensions/test_integrity_error.py @@ -5,9 +5,7 @@ import psycopg.errors as ppe import sqlalchemy as sqla -import flask - -from bemserver_api import Api, Blueprint +from bemserver_api import Blueprint class TestIntegrityError: @@ -28,11 +26,8 @@ class TestIntegrityError: ), ), ) - def test_blp_integrity_error(self, error, message): - app = flask.Flask("Test") - api = Api( - app, spec_kwargs={"title": "Test", "version": "1", "openapi_version": "3"} - ) + def test_blp_integrity_error(self, app, error, message): + api = app.extensions["flask-smorest"]["apis"][""]["ext_obj"] blp = Blueprint("Test", __name__, url_prefix="/test") diff --git a/tests/extensions/test_smorest.py b/tests/extensions/test_smorest.py new file mode 100644 index 0000000..f36552a --- /dev/null +++ b/tests/extensions/test_smorest.py @@ -0,0 +1,110 @@ +"""Test smorest extension""" + +import pytest + +from tests.common import AuthHeader, TestConfig, make_token + +from bemserver_api.extensions.authentication import auth + + +class HBATestConfig(TestConfig): + AUTH_METHODS = [ + "Basic", + ] + + +class TestSmorest: + def test_get_token(self, app, users): + user_1 = users["Active"]["user"] + user_2 = users["Inactive"]["user"] + + client = app.test_client() + payload = {"email": user_1.email, "password": "@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json["status"] == "success" + claims = auth.decode(resp.json["access_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + claims = auth.decode(resp.json["refresh_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "refresh" + claims.validate() + + # Inactive user + client = app.test_client() + payload = {"email": user_2.email, "password": "in@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json == {"status": "failure"} + + # Inactive user + client = app.test_client() + payload = {"email": user_2.email, "password": "in@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json == {"status": "failure"} + + # Wrong password + client = app.test_client() + payload = {"email": user_1.email, "password": "dummy"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json == {"status": "failure"} + + # Wrong email + client = app.test_client() + payload = {"email": "dummy@dummy.com", "password": "dummy"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json == {"status": "failure"} + + def test_refresh_token(self, app, users): + user_1 = users["Active"]["user"] + + access_token = "Bearer " + make_token(user_1.email, "access") + refresh_token = "Bearer " + make_token(user_1.email, "refresh") + + client = app.test_client() + + # No token + resp = client.post("/auth/token/refresh") + assert resp.status_code == 401 + + # Acccess token + with AuthHeader(access_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 401 + + # Refresh token + with AuthHeader(refresh_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 200 + assert resp.json["status"] == "success" + claims = auth.decode(resp.json["access_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + claims = auth.decode(resp.json["refresh_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "refresh" + claims.validate() + + @pytest.mark.parametrize("app", (HBATestConfig,), indirect=True) + def test_token_routes_jwt_disabled(self, app, users): + user_1 = users["Active"]["user"] + + client = app.test_client() + payload = {"email": user_1.email, "password": "@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 404 + + access_token = "Bearer " + make_token(user_1.email, "access") + with AuthHeader(access_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 404