Skip to content

Commit 5281c6e

Browse files
committed
Add refresh token feature
Also don't provide token routes when JWT auth is not used
1 parent 2d8cbc0 commit 5281c6e

File tree

6 files changed

+240
-34
lines changed

6 files changed

+240
-34
lines changed

src/bemserver_api/extensions/authentication.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import datetime as dt
5+
from datetime import datetime
56
from functools import wraps
67

78
import sqlalchemy as sqla
@@ -27,7 +28,8 @@ class Auth:
2728
"""Authentication and authorization management"""
2829

2930
HEADER = {"alg": "HS256"}
30-
TOKEN_LIFETIME = 900
31+
ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes
32+
REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months
3133

3234
GET_USER_FUNCS = {
3335
"Bearer": "get_user_jwt",
@@ -50,37 +52,54 @@ def init_app(self, app):
5052
if k in app.config["AUTH_METHODS"]
5153
}
5254

53-
def encode(self, user):
55+
def encode(self, user, token_type="access"):
56+
token_lifetime = (
57+
self.ACCESS_TOKEN_LIFETIME
58+
if token_type == "access"
59+
else self.REFRESH_TOKEN_LIFETIME
60+
)
5461
claims = {
5562
"email": user.email,
56-
"exp": dt.datetime.now(tz=dt.timezone.utc)
57-
+ dt.timedelta(seconds=self.TOKEN_LIFETIME),
63+
# datetime is imported in module namespace to allow test mock
64+
# kinda sucks, but oh well...
65+
"exp": datetime.now(tz=dt.timezone.utc)
66+
+ dt.timedelta(seconds=token_lifetime),
67+
"type": token_type,
5868
}
59-
return jwt.encode(self.HEADER, claims, self.key)
69+
return jwt.encode(self.HEADER.copy(), claims, self.key)
6070

6171
def decode(self, text):
62-
return jwt.decode(text, self.key, claims_options={"email": {"essential": True}})
72+
return jwt.decode(
73+
text,
74+
self.key,
75+
claims_options={
76+
"email": {"essential": True},
77+
"type": {"essential": True},
78+
},
79+
)
6380

6481
@staticmethod
6582
def get_user_by_email(user_email):
6683
return db.session.execute(
6784
sqla.select(User).where(User.email == user_email)
6885
).scalar()
6986

70-
def get_user_jwt(self, creds):
87+
def get_user_jwt(self, creds, refresh=False):
7188
try:
7289
claims = self.decode(creds)
7390
claims.validate()
7491
except ExpiredTokenError as exc:
7592
raise BEMServerAPIAuthenticationError(code="expired_token") from exc
7693
except JoseError as exc:
7794
raise BEMServerAPIAuthenticationError(code="invalid_token") from exc
95+
if refresh is not (claims["type"] == "refresh"):
96+
raise BEMServerAPIAuthenticationError(code="invalid_token")
7897
user_email = claims["email"]
7998
if (user := self.get_user_by_email(user_email)) is None:
8099
raise BEMServerAPIAuthenticationError(code="invalid_token")
81100
return user
82101

83-
def get_user_http_basic_auth(self, creds):
102+
def get_user_http_basic_auth(self, creds, **_kwargs):
84103
"""Check password and return User instance"""
85104
try:
86105
enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1)
@@ -94,7 +113,7 @@ def get_user_http_basic_auth(self, creds):
94113
raise BEMServerAPIAuthenticationError(code="invalid_credentials")
95114
return user
96115

97-
def get_user(self):
116+
def get_user(self, refresh=False):
98117
if (auth_header := flask.request.headers.get("Authorization")) is None:
99118
raise BEMServerAPIAuthenticationError(code="missing_authentication")
100119
try:
@@ -105,9 +124,9 @@ def get_user(self):
105124
func = self.get_user_funcs[scheme]
106125
except KeyError as exc:
107126
raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc
108-
return func(creds.encode("utf-8"))
127+
return func(creds.encode("utf-8"), refresh=refresh)
109128

110-
def login_required(self, f=None, **kwargs):
129+
def login_required(self, f=None, refresh=False):
111130
"""Decorator providing authentication and authorization
112131
113132
Uses JWT or HTTPBasicAuth.login_required to authenticate user
@@ -119,7 +138,7 @@ def decorator(func):
119138
@wraps(func)
120139
def wrapper(*args, **func_kwargs):
121140
try:
122-
user = self.get_user()
141+
user = self.get_user(refresh=refresh)
123142
except BEMServerAPIAuthenticationError as exc:
124143
abort(
125144
401,

src/bemserver_api/extensions/smorest.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from apispec.ext.marshmallow import MarshmallowPlugin
1313
from apispec.ext.marshmallow.common import resolve_schema_cls
1414

15+
from bemserver_core.authorization import get_current_user
16+
1517
from . import integrity_error
1618
from .authentication import auth
1719
from .ma_fields import Timezone
@@ -58,7 +60,8 @@ def init_app(self, app, *, spec_kwargs=None):
5860
]
5961
super().init_app(app, spec_kwargs=spec_kwargs)
6062
self.register_field(Timezone, "string", "iana-tz")
61-
self.register_blueprint(auth_blp)
63+
if "Bearer" in app.config["AUTH_METHODS"]:
64+
self.register_blueprint(auth_blp)
6265
for scheme in app.config["AUTH_METHODS"]:
6366
self.spec.components.security_scheme(*SECURITY_SCHEMES[scheme])
6467

@@ -168,7 +171,8 @@ class GetJWTArgsSchema(Schema):
168171

169172
class GetJWTRespSchema(Schema):
170173
status = ma.fields.String(validate=ma.validate.OneOf(("success", "failure")))
171-
token = ma.fields.String()
174+
access_token = ma.fields.String()
175+
refresh_token = ma.fields.String()
172176

173177

174178
@auth_blp.route("/token", methods=["POST"])
@@ -180,9 +184,17 @@ class GetJWTRespSchema(Schema):
180184
"success": {
181185
"value": {
182186
"status": "success",
183-
"token": (
184-
"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.u"
185-
"JKHM4XyWv1bC_-rpkjK19GUy0Fgrkm_pGHi8XghjWM"
187+
"access_token": (
188+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
189+
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
190+
"TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50"
191+
"7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4"
192+
),
193+
"refresh_token": (
194+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
195+
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
196+
"TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc"
197+
"SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ"
186198
),
187199
},
188200
},
@@ -194,8 +206,47 @@ class GetJWTRespSchema(Schema):
194206
},
195207
)
196208
def get_token(creds):
197-
"""Get an authentication token"""
209+
"""Get access and refresh tokens"""
198210
user = auth.get_user_by_email(creds["email"])
199211
if user is None or not user.check_password(creds["password"]) or not user.is_active:
200212
return flask.jsonify({"status": "failure"})
201-
return {"status": "success", "token": auth.encode(user).decode("utf-8")}
213+
return {
214+
"status": "success",
215+
"access_token": auth.encode(user).decode("utf-8"),
216+
"refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"),
217+
}
218+
219+
220+
@auth_blp.route("/token/refresh", methods=["POST"])
221+
@auth_blp.login_required(refresh=True)
222+
@auth_blp.response(
223+
200,
224+
GetJWTRespSchema,
225+
examples={
226+
"success": {
227+
"value": {
228+
"status": "success",
229+
"access_token": (
230+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
231+
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
232+
"TcxNjM2OTg4OCwidHlwZSI6ImFjY2VzcyJ9.YT-50"
233+
"7Qo9oncWKKRJhRXBbpLrOCYoJOMxbk1IaAQef4"
234+
),
235+
"refresh_token": (
236+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJl"
237+
"bWFpbCI6ImFjdGl2ZUB0ZXN0LmNvbSIsImV4cCI6M"
238+
"TcyMTU1MzEzNSwidHlwZSI6InJlZnJlc2gifQ._kc"
239+
"SHTzcngWIt-LRX6yBx8ftpekT_Dqo8qbPyfgFjSQ"
240+
),
241+
},
242+
},
243+
},
244+
)
245+
def refresh_token():
246+
"""Refresh access and refresh tokens"""
247+
user = get_current_user()
248+
return {
249+
"status": "success",
250+
"access_token": auth.encode(user).decode("utf-8"),
251+
"refresh_token": auth.encode(user, token_type="refresh").decode("utf-8"),
252+
}

tests/common.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from contextlib import AbstractContextManager
22
from contextvars import ContextVar
33

4+
from bemserver_api.extensions.authentication import auth, jwt
45
from bemserver_api.settings import Config
56

67

@@ -25,3 +26,12 @@ def __enter__(self):
2526

2627
def __exit__(self, *args, **kwargs):
2728
AUTH_HEADER.reset(self.token)
29+
30+
31+
def make_token(user_email, token_type):
32+
# Make an access token with no expiration
33+
return jwt.encode(
34+
auth.HEADER.copy(),
35+
{"email": user_email, "type": token_type},
36+
TestConfig.SECRET_KEY,
37+
).decode()

tests/conftest.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
from bemserver_core.database import db
1616

1717
import bemserver_api
18-
from bemserver_api.extensions.authentication import auth, jwt
19-
from tests.common import AUTH_HEADER, TestConfig
18+
from tests.common import AUTH_HEADER, TestConfig, make_token
2019

2120

2221
@pytest.fixture(scope="session", autouse=True)
@@ -102,12 +101,7 @@ def app(request, bsc_config, monkeypatch):
102101
"Basic "
103102
+ base64.b64encode(f'{user["email"]}:{user["password"]}'.encode()).decode()
104103
)
105-
user["creds"] = (
106-
"Bearer "
107-
+ jwt.encode(
108-
auth.HEADER, {"email": user["email"]}, TestConfig.SECRET_KEY
109-
).decode()
110-
)
104+
user["creds"] = "Bearer " + make_token(user["email"], "access")
111105

112106

113107
@pytest.fixture(params=(USERS,))

tests/extensions/test_authentication.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import datetime as dt
5+
from unittest import mock
56

67
import pytest
78

@@ -32,32 +33,86 @@ class JWTTestConfig(TestConfig):
3233

3334

3435
class TestAuthentication:
35-
def test_auth_encode_decode(self, app, users):
36+
@mock.patch("bemserver_api.extensions.authentication.datetime")
37+
@mock.patch("bemserver_api.extensions.authentication.jwt.encode")
38+
def test_auth_encode(self, mock_encode, mock_dt, app, users):
39+
dt_now = dt.datetime(2020, 1, 1, tzinfo=dt.timezone.utc)
40+
mock_dt.now.return_value = dt_now
41+
42+
user_1 = users["Active"]["user"]
43+
44+
auth.encode(user_1)
45+
mock_encode.assert_called()
46+
call_1 = mock_encode.call_args[0]
47+
assert call_1[0] == {"alg": "HS256"}
48+
assert call_1[1] == {
49+
"email": "active@test.com",
50+
"exp": dt_now + dt.timedelta(seconds=60 * 15),
51+
"type": "access",
52+
}
53+
assert call_1[2] == "Test secret"
54+
55+
auth.encode(user_1, token_type="refresh")
56+
mock_encode.assert_called()
57+
call_1 = mock_encode.call_args[0]
58+
assert call_1[0] == {"alg": "HS256"}
59+
assert call_1[1] == {
60+
"email": "active@test.com",
61+
"exp": dt_now + dt.timedelta(seconds=60 * 60 * 24 * 60),
62+
"type": "refresh",
63+
}
64+
assert call_1[2] == "Test secret"
65+
66+
def test_auth_decode(self, app, users):
3667
user_1 = users["Active"]["user"]
3768

3869
text = auth.encode(user_1)
3970
claims = auth.decode(text)
71+
assert claims["email"] == user_1.email
72+
assert "exp" in claims
73+
assert claims["type"] == "access"
74+
claims.validate()
4075

76+
text = auth.encode(user_1, token_type="refresh")
77+
claims = auth.decode(text)
4178
assert claims["email"] == user_1.email
4279
assert "exp" in claims
80+
assert claims["type"] == "refresh"
4381
claims.validate()
4482

4583
def test_auth_decode_error(self, app):
4684
with pytest.raises(DecodeError):
4785
auth.decode("dummy")
4886

49-
text = jwt.encode(auth.HEADER, {"email": "test@test.com"}, "Dummy")
87+
text = jwt.encode(
88+
auth.HEADER, {"email": "test@test.com", "type": "access"}, "Dummy"
89+
)
5090
with pytest.raises(BadSignatureError):
5191
auth.decode(text)
5292

5393
def test_auth_validation_error(self, app):
54-
text = jwt.encode(auth.HEADER, {"email": "test@test.com", "exp": 0}, auth.key)
94+
text = jwt.encode(
95+
auth.HEADER,
96+
{"email": "test@test.com", "type": "access", "exp": 0},
97+
auth.key,
98+
)
5599
claims = auth.decode(text)
56100
with pytest.raises(ExpiredTokenError):
57101
claims.validate()
58102

59103
text = jwt.encode(
60-
auth.HEADER, {"exp": dt.datetime.now(tz=dt.timezone.utc)}, auth.key
104+
auth.HEADER,
105+
{"exp": dt.datetime.now(tz=dt.timezone.utc), "type": "access"},
106+
auth.key,
107+
)
108+
claims = auth.decode(text)
109+
with pytest.raises(MissingClaimError):
110+
claims.validate()
111+
112+
text = jwt.encode(
113+
auth.HEADER,
114+
{"exp": dt.datetime.now(tz=dt.timezone.utc), "email": "test@test.com"},
115+
auth.key,
61116
)
62117
claims = auth.decode(text)
63118
with pytest.raises(MissingClaimError):
@@ -169,7 +224,7 @@ def test_auth_login_required_jwt(self, app, users):
169224
user_1 = users["Active"]["user"]
170225
active_user_jwt_creds = users["Active"]["creds"]
171226
active_user_invalid_jwt_creds = jwt.encode(
172-
auth.HEADER, {"email": "dummy@dummy.com"}, "Dummy"
227+
auth.HEADER, {"email": "dummy@dummy.com", "type": "access"}, "Dummy"
173228
).decode()
174229
inactive_user_jwt_creds = users["Inactive"]["creds"]
175230
active_user_hba_creds = users["Active"]["hba_creds"]
@@ -240,7 +295,9 @@ def no_auth():
240295
headers = {
241296
"Authorization": "Bearer "
242297
+ jwt.encode(
243-
auth.HEADER, {"email": user_1.email, "exp": 0}, app.config["SECRET_KEY"]
298+
auth.HEADER,
299+
{"email": user_1.email, "exp": 0, "type": "access"},
300+
app.config["SECRET_KEY"],
244301
).decode()
245302
}
246303
resp = client.get("/auth_test/auth", headers=headers)

0 commit comments

Comments
 (0)