Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bearer (JWT) authentication #233

Merged
merged 16 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"marshmallow-sqlalchemy>=0.29.0,<0.30",
"flask_smorest>=0.43.0,<0.44",
"apispec>=6.1.0,<7.0",
"flask-httpauth>=4.7.0,<5.0",
"authlib>=1.3.0,<2.0",
"bemserver-core>=0.17.1,<0.18",
]

Expand Down Expand Up @@ -77,7 +77,7 @@ section-order = ["future", "standard-library", "testing", "db", "pallets", "mars
[tool.ruff.lint.isort.sections]
testing = ["pytest", "pytest_postgresql"]
db = ["psycopg", "sqlalchemy", "alembic"]
pallets = ["werkzeug", "flask", "flask_httpauth"]
pallets = ["werkzeug", "flask"]
marshmallow = ["marshmallow", "marshmallow_sqlalchemy", "webargs", "apispec", "flask_smorest"]
science = ["numpy", "pandas"]
core = ["bemserver_core"]
Expand Down
8 changes: 5 additions & 3 deletions requirements/install.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ argon2-cffi-bindings==21.2.0
# via argon2-cffi
async-timeout==4.0.3
# via redis
authlib==1.3.0
# via bemserver-api (pyproject.toml)
bemserver-core==0.17.1
# via bemserver-api (pyproject.toml)
billiard==4.2.0
Expand All @@ -31,6 +33,7 @@ certifi==2024.2.2
cffi==1.16.0
# via
# argon2-cffi-bindings
# cryptography
# oso
charset-normalizer==3.3.2
# via requests
Expand All @@ -48,13 +51,12 @@ click-plugins==1.1.1
# via celery
click-repl==0.3.0
# via celery
cryptography==42.0.5
# via authlib
flask==3.0.2
# via
# bemserver-api (pyproject.toml)
# flask-httpauth
# flask-smorest
flask-httpauth==4.8.0
# via bemserver-api (pyproject.toml)
flask-smorest==0.43.0
# via bemserver-api (pyproject.toml)
greenlet==3.0.3
Expand Down
1 change: 1 addition & 0 deletions src/bemserver_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def create_app():
}
)
api.init_app(app)
authentication.auth.init_app(app)
register_blueprints(api)

BEMServerCore()
Expand Down
12 changes: 12 additions & 0 deletions src/bemserver_api/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Exceptions"""


class BEMServerAPIError(Exception):
"""Base BEMServer API exception"""


class BEMServerAPIAuthenticationError(BEMServerAPIError):
"""AuthenticationError error"""

def __init__(self, code):
self.code = code
152 changes: 128 additions & 24 deletions src/bemserver_api/extensions/authentication.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,167 @@
"""Authentication"""

import base64
import datetime as dt
from datetime import datetime
from functools import wraps

import sqlalchemy as sqla

from flask_httpauth import HTTPBasicAuth
import flask

from flask_smorest import abort

from authlib.jose import JsonWebToken
from authlib.jose.errors import ExpiredTokenError, JoseError

from bemserver_core.authorization import BEMServerAuthorizationError, CurrentUser
from bemserver_core.model.users import User

from bemserver_api.database import db
from bemserver_api.exceptions import BEMServerAPIAuthenticationError

# https://docs.authlib.org/en/latest/jose/jwt.html#jwt-with-limited-algorithms
jwt = JsonWebToken(["HS256"])


class Auth(HTTPBasicAuth):
class Auth:
"""Authentication and authorization management"""

def login_required(self, f=None, **kwargs):
HEADER = {"alg": "HS256"}
ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes
REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months

GET_USER_FUNCS = {
"Bearer": "get_user_jwt",
"Basic": "get_user_http_basic_auth",
}

def __init__(self, app=None):
self.key = None
self.app = None
self.get_user_funcs = None
if app is not None:
self.init_app(app)

Check warning on line 44 in src/bemserver_api/extensions/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/bemserver_api/extensions/authentication.py#L44

Added line #L44 was not covered by tests

def init_app(self, app):
self.app = app
self.key = app.config["SECRET_KEY"]
self.get_user_funcs = {
k: getattr(self, v)
for k, v in self.GET_USER_FUNCS.items()
if k in app.config["AUTH_METHODS"]
}

def encode(self, user, token_type="access"):
token_lifetime = (
self.ACCESS_TOKEN_LIFETIME
if token_type == "access"
else self.REFRESH_TOKEN_LIFETIME
)
claims = {
"email": user.email,
# datetime is imported in module namespace to allow test mock
# kinda sucks, but oh well...
"exp": datetime.now(tz=dt.timezone.utc)
+ dt.timedelta(seconds=token_lifetime),
"type": token_type,
}
return jwt.encode(self.HEADER.copy(), claims, self.key)

def decode(self, text):
return jwt.decode(
text,
self.key,
claims_options={
"email": {"essential": True},
"type": {"essential": True},
},
)

@staticmethod
def get_user_by_email(user_email):
return db.session.execute(
sqla.select(User).where(User.email == user_email)
).scalar()

def get_user_jwt(self, creds, refresh=False):
try:
claims = self.decode(creds)
claims.validate()
except ExpiredTokenError as exc:
raise BEMServerAPIAuthenticationError(code="expired_token") from exc
except JoseError as exc:
raise BEMServerAPIAuthenticationError(code="invalid_token") from exc
if refresh is not (claims["type"] == "refresh"):
raise BEMServerAPIAuthenticationError(code="invalid_token")
user_email = claims["email"]
if (user := self.get_user_by_email(user_email)) is None:
raise BEMServerAPIAuthenticationError(code="invalid_token")
return user

def get_user_http_basic_auth(self, creds, **_kwargs):
"""Check password and return User instance"""
try:
enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1)
user_email = enc_email.decode()
password = enc_password.decode()
except (ValueError, TypeError) as exc:
raise BEMServerAPIAuthenticationError(code="malformed_credentials") from exc
if (user := self.get_user_by_email(user_email)) is None:
raise BEMServerAPIAuthenticationError(code="invalid_credentials")

Check warning on line 111 in src/bemserver_api/extensions/authentication.py

View check run for this annotation

Codecov / codecov/patch

src/bemserver_api/extensions/authentication.py#L111

Added line #L111 was not covered by tests
if not user.check_password(password):
raise BEMServerAPIAuthenticationError(code="invalid_credentials")
return user

def get_user(self, refresh=False):
if (auth_header := flask.request.headers.get("Authorization")) is None:
raise BEMServerAPIAuthenticationError(code="missing_authentication")
try:
scheme, creds = auth_header.split(" ", maxsplit=1)
except ValueError:
abort(400)
try:
func = self.get_user_funcs[scheme]
except KeyError as exc:
raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc
return func(creds.encode("utf-8"), refresh=refresh)

def login_required(self, f=None, refresh=False):
"""Decorator providing authentication and authorization

Uses HTTPBasicAuth.login_required authenticate user
Uses JWT or HTTPBasicAuth.login_required to authenticate user
Sets CurrentUser context variable to authenticated user for the request
Catches Authorization error and aborts accordingly
"""

def decorator(func):
@wraps(func)
def wrapper(*args, **func_kwargs):
with CurrentUser(self.current_user()):
try:
user = self.get_user(refresh=refresh)
except BEMServerAPIAuthenticationError as exc:
abort(
401,
"Authentication error",
errors={"authentication": exc.code},
headers={
"WWW-Authenticate": ", ".join(
self.app.config["AUTH_METHODS"]
)
},
)
with CurrentUser(user):
try:
resp = func(*args, **func_kwargs)
except BEMServerAuthorizationError:
abort(403, message="Authorization error")
return resp

# Wrap this inside HTTPAuth.login_required
# to get authenticated user
return super(Auth, self).login_required(**kwargs)(wrapper)
return wrapper

if f:
return decorator(f)
return decorator


auth = Auth()


@auth.verify_password
def verify_password(username, password):
"""Check password and return User instance"""
user = db.session.execute(sqla.select(User).where(User.email == username)).scalar()
if user is not None and user.check_password(password):
return user
return None


@auth.error_handler
def auth_error(status):
"""Authentication error handler"""
# Call abort to trigger error handler and get consistent JSON output
abort(status, message="Authentication error")
Loading