From 506f2df41fa6035dfaabd44b9dbb3637762d814a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 22 May 2024 11:17:00 +0200 Subject: [PATCH] Add refresh token feature Also don't provide token routes when JWT auth is not used --- .../extensions/authentication.py | 43 ++++++++--- src/bemserver_api/extensions/smorest.py | 65 ++++++++++++++-- tests/common.py | 10 +++ tests/conftest.py | 10 +-- tests/extensions/test_authentication.py | 69 +++++++++++++++-- tests/extensions/test_smorest.py | 77 ++++++++++++++++++- 6 files changed, 240 insertions(+), 34 deletions(-) diff --git a/src/bemserver_api/extensions/authentication.py b/src/bemserver_api/extensions/authentication.py index ffd4adb..299581c 100644 --- a/src/bemserver_api/extensions/authentication.py +++ b/src/bemserver_api/extensions/authentication.py @@ -2,6 +2,7 @@ import base64 import datetime as dt +from datetime import datetime from functools import wraps import sqlalchemy as sqla @@ -27,7 +28,8 @@ class Auth: """Authentication and authorization management""" HEADER = {"alg": "HS256"} - TOKEN_LIFETIME = 900 + ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes + REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months GET_USER_FUNCS = { "Bearer": "get_user_jwt", @@ -50,16 +52,31 @@ def init_app(self, app): if k in app.config["AUTH_METHODS"] } - def encode(self, user): + def encode(self, user, token_type="access"): + token_lifetime = ( + self.ACCESS_TOKEN_LIFETIME + if token_type == "access" + else self.REFRESH_TOKEN_LIFETIME + ) claims = { "email": user.email, - "exp": dt.datetime.now(tz=dt.timezone.utc) - + dt.timedelta(seconds=self.TOKEN_LIFETIME), + # datetime is imported in module namespace to allow test mock + # kinda sucks, but oh well... + "exp": datetime.now(tz=dt.timezone.utc) + + dt.timedelta(seconds=token_lifetime), + "type": token_type, } - return jwt.encode(self.HEADER, claims, self.key) + return jwt.encode(self.HEADER.copy(), claims, self.key) def decode(self, text): - return jwt.decode(text, self.key, claims_options={"email": {"essential": True}}) + return jwt.decode( + text, + self.key, + claims_options={ + "email": {"essential": True}, + "type": {"essential": True}, + }, + ) @staticmethod def get_user_by_email(user_email): @@ -67,7 +84,7 @@ def get_user_by_email(user_email): sqla.select(User).where(User.email == user_email) ).scalar() - def get_user_jwt(self, creds): + def get_user_jwt(self, creds, refresh=False): try: claims = self.decode(creds) claims.validate() @@ -75,12 +92,14 @@ def get_user_jwt(self, creds): raise BEMServerAPIAuthenticationError(code="expired_token") from exc except JoseError as exc: raise BEMServerAPIAuthenticationError(code="invalid_token") from exc + if refresh is not (claims["type"] == "refresh"): + raise BEMServerAPIAuthenticationError(code="invalid_token") user_email = claims["email"] if (user := self.get_user_by_email(user_email)) is None: raise BEMServerAPIAuthenticationError(code="invalid_token") return user - def get_user_http_basic_auth(self, creds): + def get_user_http_basic_auth(self, creds, **_kwargs): """Check password and return User instance""" try: enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1) @@ -94,7 +113,7 @@ def get_user_http_basic_auth(self, creds): raise BEMServerAPIAuthenticationError(code="invalid_credentials") return user - def get_user(self): + def get_user(self, refresh=False): if (auth_header := flask.request.headers.get("Authorization")) is None: raise BEMServerAPIAuthenticationError(code="missing_authentication") try: @@ -105,9 +124,9 @@ def get_user(self): func = self.get_user_funcs[scheme] except KeyError as exc: raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc - return func(creds.encode("utf-8")) + return func(creds.encode("utf-8"), refresh=refresh) - def login_required(self, f=None, **kwargs): + def login_required(self, f=None, refresh=False): """Decorator providing authentication and authorization Uses JWT or HTTPBasicAuth.login_required to authenticate user @@ -119,7 +138,7 @@ def decorator(func): @wraps(func) def wrapper(*args, **func_kwargs): try: - user = self.get_user() + user = self.get_user(refresh=refresh) except BEMServerAPIAuthenticationError as exc: abort( 401, diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index 54d78ae..ee03242 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -12,6 +12,8 @@ from apispec.ext.marshmallow import MarshmallowPlugin from apispec.ext.marshmallow.common import resolve_schema_cls +from bemserver_core.authorization import get_current_user + from . import integrity_error from .authentication import auth from .ma_fields import Timezone @@ -58,7 +60,8 @@ def init_app(self, app, *, spec_kwargs=None): ] super().init_app(app, spec_kwargs=spec_kwargs) self.register_field(Timezone, "string", "iana-tz") - self.register_blueprint(auth_blp) + if "Bearer" in app.config["AUTH_METHODS"]: + self.register_blueprint(auth_blp) for scheme in app.config["AUTH_METHODS"]: self.spec.components.security_scheme(*SECURITY_SCHEMES[scheme]) @@ -168,7 +171,8 @@ class GetJWTArgsSchema(Schema): class GetJWTRespSchema(Schema): status = ma.fields.String(validate=ma.validate.OneOf(("success", "failure"))) - token = ma.fields.String() + access_token = ma.fields.String() + refresh_token = ma.fields.String() @auth_blp.route("/token", methods=["POST"]) @@ -180,9 +184,17 @@ class GetJWTRespSchema(Schema): "success": { "value": { "status": "success", - "token": ( - "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.u" - "JKHM4XyWv1bC_-rpkjK19GUy0Fgrkm_pGHi8XghjWM" + "access_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50" + "7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4" + ), + "refresh_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc" + "SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ" ), }, }, @@ -194,8 +206,47 @@ class GetJWTRespSchema(Schema): }, ) def get_token(creds): - """Get an authentication token""" + """Get access and refresh tokens""" user = auth.get_user_by_email(creds["email"]) if user is None or not user.check_password(creds["password"]) or not user.is_active: return flask.jsonify({"status": "failure"}) - return {"status": "success", "token": auth.encode(user).decode("utf-8")} + return { + "status": "success", + "access_token": auth.encode(user).decode("utf-8"), + "refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"), + } + + +@auth_blp.route("/token/refresh", methods=["POST"]) +@auth_blp.login_required(refresh=True) +@auth_blp.response( + 200, + GetJWTRespSchema, + examples={ + "success": { + "value": { + "status": "success", + "access_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50" + "7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4" + ), + "refresh_token": ( + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl" + "bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M" + "TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc" + "SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ" + ), + }, + }, + }, +) +def refresh_token(): + """Refresh access and refresh tokens""" + user = get_current_user() + return { + "status": "success", + "access_token": auth.encode(user).decode("utf-8"), + "refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"), + } diff --git a/tests/common.py b/tests/common.py index 0a664e4..241dc30 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,6 +1,7 @@ from contextlib import AbstractContextManager from contextvars import ContextVar +from bemserver_api.extensions.authentication import auth, jwt from bemserver_api.settings import Config @@ -25,3 +26,12 @@ def __enter__(self): def __exit__(self, *args, **kwargs): AUTH_HEADER.reset(self.token) + + +def make_token(user_email, token_type): + # Make an access token with no expiration + return jwt.encode( + auth.HEADER.copy(), + {"email": user_email, "type": token_type}, + TestConfig.SECRET_KEY, + ).decode() diff --git a/tests/conftest.py b/tests/conftest.py index 7c5d140..42dd8de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,8 +15,7 @@ from bemserver_core.database import db import bemserver_api -from bemserver_api.extensions.authentication import auth, jwt -from tests.common import AUTH_HEADER, TestConfig +from tests.common import AUTH_HEADER, TestConfig, make_token @pytest.fixture(scope="session", autouse=True) @@ -102,12 +101,7 @@ def app(request, bsc_config, monkeypatch): "Basic " + base64.b64encode(f'{user["email"]}:{user["password"]}'.encode()).decode() ) - user["creds"] = ( - "Bearer " - + jwt.encode( - auth.HEADER, {"email": user["email"]}, TestConfig.SECRET_KEY - ).decode() - ) + user["creds"] = "Bearer " + make_token(user["email"], "access") @pytest.fixture(params=(USERS,)) diff --git a/tests/extensions/test_authentication.py b/tests/extensions/test_authentication.py index 7658a45..459c559 100644 --- a/tests/extensions/test_authentication.py +++ b/tests/extensions/test_authentication.py @@ -2,6 +2,7 @@ import base64 import datetime as dt +from unittest import mock import pytest @@ -32,32 +33,86 @@ class JWTTestConfig(TestConfig): class TestAuthentication: - def test_auth_encode_decode(self, app, users): + @mock.patch("bemserver_api.extensions.authentication.datetime") + @mock.patch("bemserver_api.extensions.authentication.jwt.encode") + def test_auth_encode(self, mock_encode, mock_dt, app, users): + dt_now = dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc) + mock_dt.now.return_value = dt_now + + user_1 = users["Active"]["user"] + + auth.encode(user_1) + mock_encode.assert_called() + call_1 = mock_encode.call_args[0] + assert call_1[0] == {"alg": "HS256"} + assert call_1[1] == { + "email": "active@test.com", + "exp": dt_now + dt.timedelta(seconds=60 * 15), + "type": "access", + } + assert call_1[2] == "Test secret" + + auth.encode(user_1, token_type="refresh") + mock_encode.assert_called() + call_1 = mock_encode.call_args[0] + assert call_1[0] == {"alg": "HS256"} + assert call_1[1] == { + "email": "active@test.com", + "exp": dt_now + dt.timedelta(seconds=60 * 60 * 24 * 60), + "type": "refresh", + } + assert call_1[2] == "Test secret" + + def test_auth_decode(self, app, users): user_1 = users["Active"]["user"] text = auth.encode(user_1) claims = auth.decode(text) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + text = auth.encode(user_1, token_type="refresh") + claims = auth.decode(text) assert claims["email"] == user_1.email assert "exp" in claims + assert claims["type"] == "refresh" claims.validate() def test_auth_decode_error(self, app): with pytest.raises(DecodeError): auth.decode("dummy") - text = jwt.encode(auth.HEADER, {"email": "test@test.com"}, "Dummy") + text = jwt.encode( + auth.HEADER, {"email": "test@test.com", "type": "access"}, "Dummy" + ) with pytest.raises(BadSignatureError): auth.decode(text) def test_auth_validation_error(self, app): - text = jwt.encode(auth.HEADER, {"email": "test@test.com", "exp": 0}, auth.key) + text = jwt.encode( + auth.HEADER, + {"email": "test@test.com", "type": "access", "exp": 0}, + auth.key, + ) claims = auth.decode(text) with pytest.raises(ExpiredTokenError): claims.validate() text = jwt.encode( - auth.HEADER, {"exp": dt.datetime.now(tz=dt.timezone.utc)}, auth.key + auth.HEADER, + {"exp": dt.datetime.now(tz=dt.timezone.utc), "type": "access"}, + auth.key, + ) + claims = auth.decode(text) + with pytest.raises(MissingClaimError): + claims.validate() + + text = jwt.encode( + auth.HEADER, + {"exp": dt.datetime.now(tz=dt.timezone.utc), "email": "test@test.com"}, + auth.key, ) claims = auth.decode(text) with pytest.raises(MissingClaimError): @@ -169,7 +224,7 @@ def test_auth_login_required_jwt(self, app, users): user_1 = users["Active"]["user"] active_user_jwt_creds = users["Active"]["creds"] active_user_invalid_jwt_creds = jwt.encode( - auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy" + auth.HEADER, {"email": "dummy@dummy.com", "type": "access"}, "Dummy" ).decode() inactive_user_jwt_creds = users["Inactive"]["creds"] active_user_hba_creds = users["Active"]["hba_creds"] @@ -240,7 +295,9 @@ def no_auth(): headers = { "Authorization": "Bearer " + jwt.encode( - auth.HEADER, {"email": user_1.email, "exp": 0}, app.config["SECRET_KEY"] + auth.HEADER, + {"email": user_1.email, "exp": 0, "type": "access"}, + app.config["SECRET_KEY"], ).decode() } resp = client.get("/auth_test/auth", headers=headers) diff --git a/tests/extensions/test_smorest.py b/tests/extensions/test_smorest.py index 993a373..f36552a 100644 --- a/tests/extensions/test_smorest.py +++ b/tests/extensions/test_smorest.py @@ -1,5 +1,17 @@ """Test smorest extension""" +import pytest + +from tests.common import AuthHeader, TestConfig, make_token + +from bemserver_api.extensions.authentication import auth + + +class HBATestConfig(TestConfig): + AUTH_METHODS = [ + "Basic", + ] + class TestSmorest: def test_get_token(self, app, users): @@ -11,7 +23,23 @@ def test_get_token(self, app, users): resp = client.post("/auth/token", json=payload) assert resp.status_code == 200 assert resp.json["status"] == "success" - assert "token" in resp.json + claims = auth.decode(resp.json["access_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + claims = auth.decode(resp.json["refresh_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "refresh" + claims.validate() + + # Inactive user + client = app.test_client() + payload = {"email": user_2.email, "password": "in@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 200 + assert resp.json == {"status": "failure"} # Inactive user client = app.test_client() @@ -33,3 +61,50 @@ def test_get_token(self, app, users): resp = client.post("/auth/token", json=payload) assert resp.status_code == 200 assert resp.json == {"status": "failure"} + + def test_refresh_token(self, app, users): + user_1 = users["Active"]["user"] + + access_token = "Bearer " + make_token(user_1.email, "access") + refresh_token = "Bearer " + make_token(user_1.email, "refresh") + + client = app.test_client() + + # No token + resp = client.post("/auth/token/refresh") + assert resp.status_code == 401 + + # Acccess token + with AuthHeader(access_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 401 + + # Refresh token + with AuthHeader(refresh_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 200 + assert resp.json["status"] == "success" + claims = auth.decode(resp.json["access_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "access" + claims.validate() + claims = auth.decode(resp.json["refresh_token"]) + assert claims["email"] == user_1.email + assert "exp" in claims + assert claims["type"] == "refresh" + claims.validate() + + @pytest.mark.parametrize("app", (HBATestConfig,), indirect=True) + def test_token_routes_jwt_disabled(self, app, users): + user_1 = users["Active"]["user"] + + client = app.test_client() + payload = {"email": user_1.email, "password": "@ctive"} + resp = client.post("/auth/token", json=payload) + assert resp.status_code == 404 + + access_token = "Bearer " + make_token(user_1.email, "access") + with AuthHeader(access_token): + resp = client.post("/auth/token/refresh") + assert resp.status_code == 404