Skip to content

Commit 22a18c1

Browse files
authored
Merge pull request #233 from BEMServer/jwt
Add Bearer (JWT) authentication
2 parents 5d1a45e + 9891cf1 commit 22a18c1

File tree

12 files changed

+765
-87
lines changed

12 files changed

+765
-87
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"marshmallow-sqlalchemy>=0.29.0,<0.30",
3333
"flask_smorest>=0.43.0,<0.44",
3434
"apispec>=6.1.0,<7.0",
35-
"flask-httpauth>=4.7.0,<5.0",
35+
"authlib>=1.3.0,<2.0",
3636
"bemserver-core>=0.17.1,<0.18",
3737
]
3838

@@ -77,7 +77,7 @@ section-order = ["future", "standard-library", "testing", "db", "pallets", "mars
7777
[tool.ruff.lint.isort.sections]
7878
testing = ["pytest", "pytest_postgresql"]
7979
db = ["psycopg", "sqlalchemy", "alembic"]
80-
pallets = ["werkzeug", "flask", "flask_httpauth"]
80+
pallets = ["werkzeug", "flask"]
8181
marshmallow = ["marshmallow", "marshmallow_sqlalchemy", "webargs", "apispec", "flask_smorest"]
8282
science = ["numpy", "pandas"]
8383
core = ["bemserver_core"]

requirements/install.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ argon2-cffi-bindings==21.2.0
1818
# via argon2-cffi
1919
async-timeout==4.0.3
2020
# via redis
21+
authlib==1.3.0
22+
# via bemserver-api (pyproject.toml)
2123
bemserver-core==0.17.1
2224
# via bemserver-api (pyproject.toml)
2325
billiard==4.2.0
@@ -31,6 +33,7 @@ certifi==2024.2.2
3133
cffi==1.16.0
3234
# via
3335
# argon2-cffi-bindings
36+
# cryptography
3437
# oso
3538
charset-normalizer==3.3.2
3639
# via requests
@@ -48,13 +51,12 @@ click-plugins==1.1.1
4851
# via celery
4952
click-repl==0.3.0
5053
# via celery
54+
cryptography==42.0.5
55+
# via authlib
5156
flask==3.0.2
5257
# via
5358
# bemserver-api (pyproject.toml)
54-
# flask-httpauth
5559
# flask-smorest
56-
flask-httpauth==4.8.0
57-
# via bemserver-api (pyproject.toml)
5860
flask-smorest==0.43.0
5961
# via bemserver-api (pyproject.toml)
6062
greenlet==3.0.3

src/bemserver_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def create_app():
3636
}
3737
)
3838
api.init_app(app)
39+
authentication.auth.init_app(app)
3940
register_blueprints(api)
4041

4142
BEMServerCore()

src/bemserver_api/exceptions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Exceptions"""
2+
3+
4+
class BEMServerAPIError(Exception):
5+
"""Base BEMServer API exception"""
6+
7+
8+
class BEMServerAPIAuthenticationError(BEMServerAPIError):
9+
"""AuthenticationError error"""
10+
11+
def __init__(self, code):
12+
self.code = code

src/bemserver_api/extensions/authentication.py

Lines changed: 128 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,167 @@
11
"""Authentication"""
22

3+
import base64
4+
import datetime as dt
5+
from datetime import datetime
36
from functools import wraps
47

58
import sqlalchemy as sqla
69

7-
from flask_httpauth import HTTPBasicAuth
10+
import flask
811

912
from flask_smorest import abort
1013

14+
from authlib.jose import JsonWebToken
15+
from authlib.jose.errors import ExpiredTokenError, JoseError
16+
1117
from bemserver_core.authorization import BEMServerAuthorizationError, CurrentUser
1218
from bemserver_core.model.users import User
1319

1420
from bemserver_api.database import db
21+
from bemserver_api.exceptions import BEMServerAPIAuthenticationError
22+
23+
# https://docs.authlib.org/en/latest/jose/jwt.html#jwt-with-limited-algorithms
24+
jwt = JsonWebToken(["HS256"])
1525

1626

17-
class Auth(HTTPBasicAuth):
27+
class Auth:
1828
"""Authentication and authorization management"""
1929

20-
def login_required(self, f=None, **kwargs):
30+
HEADER = {"alg": "HS256"}
31+
ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes
32+
REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months
33+
34+
GET_USER_FUNCS = {
35+
"Bearer": "get_user_jwt",
36+
"Basic": "get_user_http_basic_auth",
37+
}
38+
39+
def __init__(self, app=None):
40+
self.key = None
41+
self.app = None
42+
self.get_user_funcs = None
43+
if app is not None:
44+
self.init_app(app)
45+
46+
def init_app(self, app):
47+
self.app = app
48+
self.key = app.config["SECRET_KEY"]
49+
self.get_user_funcs = {
50+
k: getattr(self, v)
51+
for k, v in self.GET_USER_FUNCS.items()
52+
if k in app.config["AUTH_METHODS"]
53+
}
54+
55+
def encode(self, user, token_type="access"):
56+
token_lifetime = (
57+
self.ACCESS_TOKEN_LIFETIME
58+
if token_type == "access"
59+
else self.REFRESH_TOKEN_LIFETIME
60+
)
61+
claims = {
62+
"email": user.email,
63+
# datetime is imported in module namespace to allow test mock
64+
# kinda sucks, but oh well...
65+
"exp": datetime.now(tz=dt.timezone.utc)
66+
+ dt.timedelta(seconds=token_lifetime),
67+
"type": token_type,
68+
}
69+
return jwt.encode(self.HEADER.copy(), claims, self.key)
70+
71+
def decode(self, text):
72+
return jwt.decode(
73+
text,
74+
self.key,
75+
claims_options={
76+
"email": {"essential": True},
77+
"type": {"essential": True},
78+
},
79+
)
80+
81+
@staticmethod
82+
def get_user_by_email(user_email):
83+
return db.session.execute(
84+
sqla.select(User).where(User.email == user_email)
85+
).scalar()
86+
87+
def get_user_jwt(self, creds, refresh=False):
88+
try:
89+
claims = self.decode(creds)
90+
claims.validate()
91+
except ExpiredTokenError as exc:
92+
raise BEMServerAPIAuthenticationError(code="expired_token") from exc
93+
except JoseError as exc:
94+
raise BEMServerAPIAuthenticationError(code="invalid_token") from exc
95+
if refresh is not (claims["type"] == "refresh"):
96+
raise BEMServerAPIAuthenticationError(code="invalid_token")
97+
user_email = claims["email"]
98+
if (user := self.get_user_by_email(user_email)) is None:
99+
raise BEMServerAPIAuthenticationError(code="invalid_token")
100+
return user
101+
102+
def get_user_http_basic_auth(self, creds, **_kwargs):
103+
"""Check password and return User instance"""
104+
try:
105+
enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1)
106+
user_email = enc_email.decode()
107+
password = enc_password.decode()
108+
except (ValueError, TypeError) as exc:
109+
raise BEMServerAPIAuthenticationError(code="malformed_credentials") from exc
110+
if (user := self.get_user_by_email(user_email)) is None:
111+
raise BEMServerAPIAuthenticationError(code="invalid_credentials")
112+
if not user.check_password(password):
113+
raise BEMServerAPIAuthenticationError(code="invalid_credentials")
114+
return user
115+
116+
def get_user(self, refresh=False):
117+
if (auth_header := flask.request.headers.get("Authorization")) is None:
118+
raise BEMServerAPIAuthenticationError(code="missing_authentication")
119+
try:
120+
scheme, creds = auth_header.split(" ", maxsplit=1)
121+
except ValueError:
122+
abort(400)
123+
try:
124+
func = self.get_user_funcs[scheme]
125+
except KeyError as exc:
126+
raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc
127+
return func(creds.encode("utf-8"), refresh=refresh)
128+
129+
def login_required(self, f=None, refresh=False):
21130
"""Decorator providing authentication and authorization
22131
23-
Uses HTTPBasicAuth.login_required authenticate user
132+
Uses JWT or HTTPBasicAuth.login_required to authenticate user
24133
Sets CurrentUser context variable to authenticated user for the request
25134
Catches Authorization error and aborts accordingly
26135
"""
27136

28137
def decorator(func):
29138
@wraps(func)
30139
def wrapper(*args, **func_kwargs):
31-
with CurrentUser(self.current_user()):
140+
try:
141+
user = self.get_user(refresh=refresh)
142+
except BEMServerAPIAuthenticationError as exc:
143+
abort(
144+
401,
145+
"Authentication error",
146+
errors={"authentication": exc.code},
147+
headers={
148+
"WWW-Authenticate": ", ".join(
149+
self.app.config["AUTH_METHODS"]
150+
)
151+
},
152+
)
153+
with CurrentUser(user):
32154
try:
33155
resp = func(*args, **func_kwargs)
34156
except BEMServerAuthorizationError:
35157
abort(403, message="Authorization error")
36158
return resp
37159

38-
# Wrap this inside HTTPAuth.login_required
39-
# to get authenticated user
40-
return super(Auth, self).login_required(**kwargs)(wrapper)
160+
return wrapper
41161

42162
if f:
43163
return decorator(f)
44164
return decorator
45165

46166

47167
auth = Auth()
48-
49-
50-
@auth.verify_password
51-
def verify_password(username, password):
52-
"""Check password and return User instance"""
53-
user = db.session.execute(sqla.select(User).where(User.email == username)).scalar()
54-
if user is not None and user.check_password(password):
55-
return user
56-
return None
57-
58-
59-
@auth.error_handler
60-
def auth_error(status):
61-
"""Authentication error handler"""
62-
# Call abort to trigger error handler and get consistent JSON output
63-
abort(status, message="Authentication error")

0 commit comments

Comments
 (0)