Skip to content

Commit

Permalink
Add refresh token feature
Browse files Browse the repository at this point in the history
Also don't provide token routes when JWT auth is not used
  • Loading branch information
lafrech committed Jun 3, 2024
1 parent 8f0072b commit 506f2df
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 34 deletions.
43 changes: 31 additions & 12 deletions src/bemserver_api/extensions/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import datetime as dt
from datetime import datetime
from functools import wraps

import sqlalchemy as sqla
Expand All @@ -27,7 +28,8 @@ class Auth:
"""Authentication and authorization management"""

HEADER = {"alg": "HS256"}
TOKEN_LIFETIME = 900
ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes
REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months

GET_USER_FUNCS = {
"Bearer": "get_user_jwt",
Expand All @@ -50,37 +52,54 @@ def init_app(self, app):
if k in app.config["AUTH_METHODS"]
}

def encode(self, user):
def encode(self, user, token_type="access"):
token_lifetime = (
self.ACCESS_TOKEN_LIFETIME
if token_type == "access"
else self.REFRESH_TOKEN_LIFETIME
)
claims = {
"email": user.email,
"exp": dt.datetime.now(tz=dt.timezone.utc)
+ dt.timedelta(seconds=self.TOKEN_LIFETIME),
# datetime is imported in module namespace to allow test mock
# kinda sucks, but oh well...
"exp": datetime.now(tz=dt.timezone.utc)
+ dt.timedelta(seconds=token_lifetime),
"type": token_type,
}
return jwt.encode(self.HEADER, claims, self.key)
return jwt.encode(self.HEADER.copy(), claims, self.key)

def decode(self, text):
return jwt.decode(text, self.key, claims_options={"email": {"essential": True}})
return jwt.decode(
text,
self.key,
claims_options={
"email": {"essential": True},
"type": {"essential": True},
},
)

@staticmethod
def get_user_by_email(user_email):
return db.session.execute(
sqla.select(User).where(User.email == user_email)
).scalar()

def get_user_jwt(self, creds):
def get_user_jwt(self, creds, refresh=False):
try:
claims = self.decode(creds)
claims.validate()
except ExpiredTokenError as exc:
raise BEMServerAPIAuthenticationError(code="expired_token") from exc
except JoseError as exc:
raise BEMServerAPIAuthenticationError(code="invalid_token") from exc
if refresh is not (claims["type"] == "refresh"):
raise BEMServerAPIAuthenticationError(code="invalid_token")
user_email = claims["email"]
if (user := self.get_user_by_email(user_email)) is None:
raise BEMServerAPIAuthenticationError(code="invalid_token")

Check warning on line 99 in src/bemserver_api/extensions/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/bemserver_api/extensions/authentication.py#L99

Added line #L99 was not covered by tests
return user

def get_user_http_basic_auth(self, creds):
def get_user_http_basic_auth(self, creds, **_kwargs):
"""Check password and return User instance"""
try:
enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1)
Expand All @@ -94,7 +113,7 @@ def get_user_http_basic_auth(self, creds):
raise BEMServerAPIAuthenticationError(code="invalid_credentials")
return user

def get_user(self):
def get_user(self, refresh=False):
if (auth_header := flask.request.headers.get("Authorization")) is None:
raise BEMServerAPIAuthenticationError(code="missing_authentication")
try:
Expand All @@ -105,9 +124,9 @@ def get_user(self):
func = self.get_user_funcs[scheme]
except KeyError as exc:
raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc
return func(creds.encode("utf-8"))
return func(creds.encode("utf-8"), refresh=refresh)

def login_required(self, f=None, **kwargs):
def login_required(self, f=None, refresh=False):
"""Decorator providing authentication and authorization
Uses JWT or HTTPBasicAuth.login_required to authenticate user
Expand All @@ -119,7 +138,7 @@ def decorator(func):
@wraps(func)
def wrapper(*args, **func_kwargs):
try:
user = self.get_user()
user = self.get_user(refresh=refresh)
except BEMServerAPIAuthenticationError as exc:
abort(
401,
Expand Down
65 changes: 58 additions & 7 deletions src/bemserver_api/extensions/smorest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from apispec.ext.marshmallow import MarshmallowPlugin
from apispec.ext.marshmallow.common import resolve_schema_cls

from bemserver_core.authorization import get_current_user

from . import integrity_error
from .authentication import auth
from .ma_fields import Timezone
Expand Down Expand Up @@ -58,7 +60,8 @@ def init_app(self, app, *, spec_kwargs=None):
]
super().init_app(app, spec_kwargs=spec_kwargs)
self.register_field(Timezone, "string", "iana-tz")
self.register_blueprint(auth_blp)
if "Bearer" in app.config["AUTH_METHODS"]:
self.register_blueprint(auth_blp)
for scheme in app.config["AUTH_METHODS"]:
self.spec.components.security_scheme(*SECURITY_SCHEMES[scheme])

Expand Down Expand Up @@ -168,7 +171,8 @@ class GetJWTArgsSchema(Schema):

class GetJWTRespSchema(Schema):
status = ma.fields.String(validate=ma.validate.OneOf(("success", "failure")))
token = ma.fields.String()
access_token = ma.fields.String()
refresh_token = ma.fields.String()


@auth_blp.route("/token", methods=["POST"])
Expand All @@ -180,9 +184,17 @@ class GetJWTRespSchema(Schema):
"success": {
"value": {
"status": "success",
"token": (
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.u"
"JKHM4XyWv1bC_-rpkjK19GUy0Fgrkm_pGHi8XghjWM"
"access_token": (
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
"TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50"
"7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4"
),
"refresh_token": (
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
"TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc"
"SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ"
),
},
},
Expand All @@ -194,8 +206,47 @@ class GetJWTRespSchema(Schema):
},
)
def get_token(creds):
"""Get an authentication token"""
"""Get access and refresh tokens"""
user = auth.get_user_by_email(creds["email"])
if user is None or not user.check_password(creds["password"]) or not user.is_active:
return flask.jsonify({"status": "failure"})
return {"status": "success", "token": auth.encode(user).decode("utf-8")}
return {
"status": "success",
"access_token": auth.encode(user).decode("utf-8"),
"refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"),
}


@auth_blp.route("/token/refresh", methods=["POST"])
@auth_blp.login_required(refresh=True)
@auth_blp.response(
200,
GetJWTRespSchema,
examples={
"success": {
"value": {
"status": "success",
"access_token": (
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
"TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50"
"7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4"
),
"refresh_token": (
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
"TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc"
"SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ"
),
},
},
},
)
def refresh_token():
"""Refresh access and refresh tokens"""
user = get_current_user()
return {
"status": "success",
"access_token": auth.encode(user).decode("utf-8"),
"refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"),
}
10 changes: 10 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import AbstractContextManager
from contextvars import ContextVar

from bemserver_api.extensions.authentication import auth, jwt
from bemserver_api.settings import Config


Expand All @@ -25,3 +26,12 @@ def __enter__(self):

def __exit__(self, *args, **kwargs):
AUTH_HEADER.reset(self.token)


def make_token(user_email, token_type):
# Make an access token with no expiration
return jwt.encode(
auth.HEADER.copy(),
{"email": user_email, "type": token_type},
TestConfig.SECRET_KEY,
).decode()
10 changes: 2 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from bemserver_core.database import db

import bemserver_api
from bemserver_api.extensions.authentication import auth, jwt
from tests.common import AUTH_HEADER, TestConfig
from tests.common import AUTH_HEADER, TestConfig, make_token


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -102,12 +101,7 @@ def app(request, bsc_config, monkeypatch):
"Basic "
+ base64.b64encode(f'{user["email"]}:{user["password"]}'.encode()).decode()
)
user["creds"] = (
"Bearer "
+ jwt.encode(
auth.HEADER, {"email": user["email"]}, TestConfig.SECRET_KEY
).decode()
)
user["creds"] = "Bearer " + make_token(user["email"], "access")


@pytest.fixture(params=(USERS,))
Expand Down
69 changes: 63 additions & 6 deletions tests/extensions/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import datetime as dt
from unittest import mock

import pytest

Expand Down Expand Up @@ -32,32 +33,86 @@ class JWTTestConfig(TestConfig):


class TestAuthentication:
def test_auth_encode_decode(self, app, users):
@mock.patch("bemserver_api.extensions.authentication.datetime")
@mock.patch("bemserver_api.extensions.authentication.jwt.encode")
def test_auth_encode(self, mock_encode, mock_dt, app, users):
dt_now = dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc)
mock_dt.now.return_value = dt_now

user_1 = users["Active"]["user"]

auth.encode(user_1)
mock_encode.assert_called()
call_1 = mock_encode.call_args[0]
assert call_1[0] == {"alg": "HS256"}
assert call_1[1] == {
"email": "active@test.com",
"exp": dt_now + dt.timedelta(seconds=60 * 15),
"type": "access",
}
assert call_1[2] == "Test secret"

auth.encode(user_1, token_type="refresh")
mock_encode.assert_called()
call_1 = mock_encode.call_args[0]
assert call_1[0] == {"alg": "HS256"}
assert call_1[1] == {
"email": "active@test.com",
"exp": dt_now + dt.timedelta(seconds=60 * 60 * 24 * 60),
"type": "refresh",
}
assert call_1[2] == "Test secret"

def test_auth_decode(self, app, users):
user_1 = users["Active"]["user"]

text = auth.encode(user_1)
claims = auth.decode(text)
assert claims["email"] == user_1.email
assert "exp" in claims
assert claims["type"] == "access"
claims.validate()

text = auth.encode(user_1, token_type="refresh")
claims = auth.decode(text)
assert claims["email"] == user_1.email
assert "exp" in claims
assert claims["type"] == "refresh"
claims.validate()

def test_auth_decode_error(self, app):
with pytest.raises(DecodeError):
auth.decode("dummy")

text = jwt.encode(auth.HEADER, {"email": "test@test.com"}, "Dummy")
text = jwt.encode(
auth.HEADER, {"email": "test@test.com", "type": "access"}, "Dummy"
)
with pytest.raises(BadSignatureError):
auth.decode(text)

def test_auth_validation_error(self, app):
text = jwt.encode(auth.HEADER, {"email": "test@test.com", "exp": 0}, auth.key)
text = jwt.encode(
auth.HEADER,
{"email": "test@test.com", "type": "access", "exp": 0},
auth.key,
)
claims = auth.decode(text)
with pytest.raises(ExpiredTokenError):
claims.validate()

text = jwt.encode(
auth.HEADER, {"exp": dt.datetime.now(tz=dt.timezone.utc)}, auth.key
auth.HEADER,
{"exp": dt.datetime.now(tz=dt.timezone.utc), "type": "access"},
auth.key,
)
claims = auth.decode(text)
with pytest.raises(MissingClaimError):
claims.validate()

text = jwt.encode(
auth.HEADER,
{"exp": dt.datetime.now(tz=dt.timezone.utc), "email": "test@test.com"},
auth.key,
)
claims = auth.decode(text)
with pytest.raises(MissingClaimError):
Expand Down Expand Up @@ -169,7 +224,7 @@ def test_auth_login_required_jwt(self, app, users):
user_1 = users["Active"]["user"]
active_user_jwt_creds = users["Active"]["creds"]
active_user_invalid_jwt_creds = jwt.encode(
auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy"
auth.HEADER, {"email": "dummy@dummy.com", "type": "access"}, "Dummy"
).decode()
inactive_user_jwt_creds = users["Inactive"]["creds"]
active_user_hba_creds = users["Active"]["hba_creds"]
Expand Down Expand Up @@ -240,7 +295,9 @@ def no_auth():
headers = {
"Authorization": "Bearer "
+ jwt.encode(
auth.HEADER, {"email": user_1.email, "exp": 0}, app.config["SECRET_KEY"]
auth.HEADER,
{"email": user_1.email, "exp": 0, "type": "access"},
app.config["SECRET_KEY"],
).decode()
}
resp = client.get("/auth_test/auth", headers=headers)
Expand Down
Loading

0 comments on commit 506f2df

Please sign in to comment.