|
1 | 1 | """Authentication"""
|
2 | 2 |
|
| 3 | +import base64 |
| 4 | +import datetime as dt |
| 5 | +from datetime import datetime |
3 | 6 | from functools import wraps
|
4 | 7 |
|
5 | 8 | import sqlalchemy as sqla
|
6 | 9 |
|
7 |
| -from flask_httpauth import HTTPBasicAuth |
| 10 | +import flask |
8 | 11 |
|
9 | 12 | from flask_smorest import abort
|
10 | 13 |
|
| 14 | +from authlib.jose import JsonWebToken |
| 15 | +from authlib.jose.errors import ExpiredTokenError, JoseError |
| 16 | + |
11 | 17 | from bemserver_core.authorization import BEMServerAuthorizationError, CurrentUser
|
12 | 18 | from bemserver_core.model.users import User
|
13 | 19 |
|
14 | 20 | from bemserver_api.database import db
|
| 21 | +from bemserver_api.exceptions import BEMServerAPIAuthenticationError |
| 22 | + |
| 23 | +# https://docs.authlib.org/en/latest/jose/jwt.html#jwt-with-limited-algorithms |
| 24 | +jwt = JsonWebToken(["HS256"]) |
15 | 25 |
|
16 | 26 |
|
17 |
| -class Auth(HTTPBasicAuth): |
| 27 | +class Auth: |
18 | 28 | """Authentication and authorization management"""
|
19 | 29 |
|
20 |
| - def login_required(self, f=None, **kwargs): |
| 30 | + HEADER = {"alg": "HS256"} |
| 31 | + ACCESS_TOKEN_LIFETIME = 60 * 15 # 15 minutes |
| 32 | + REFRESH_TOKEN_LIFETIME = 60 * 60 * 24 * 60 # 2 months |
| 33 | + |
| 34 | + GET_USER_FUNCS = { |
| 35 | + "Bearer": "get_user_jwt", |
| 36 | + "Basic": "get_user_http_basic_auth", |
| 37 | + } |
| 38 | + |
| 39 | + def __init__(self, app=None): |
| 40 | + self.key = None |
| 41 | + self.app = None |
| 42 | + self.get_user_funcs = None |
| 43 | + if app is not None: |
| 44 | + self.init_app(app) |
| 45 | + |
| 46 | + def init_app(self, app): |
| 47 | + self.app = app |
| 48 | + self.key = app.config["SECRET_KEY"] |
| 49 | + self.get_user_funcs = { |
| 50 | + k: getattr(self, v) |
| 51 | + for k, v in self.GET_USER_FUNCS.items() |
| 52 | + if k in app.config["AUTH_METHODS"] |
| 53 | + } |
| 54 | + |
| 55 | + def encode(self, user, token_type="access"): |
| 56 | + token_lifetime = ( |
| 57 | + self.ACCESS_TOKEN_LIFETIME |
| 58 | + if token_type == "access" |
| 59 | + else self.REFRESH_TOKEN_LIFETIME |
| 60 | + ) |
| 61 | + claims = { |
| 62 | + "email": user.email, |
| 63 | + # datetime is imported in module namespace to allow test mock |
| 64 | + # kinda sucks, but oh well... |
| 65 | + "exp": datetime.now(tz=dt.timezone.utc) |
| 66 | + + dt.timedelta(seconds=token_lifetime), |
| 67 | + "type": token_type, |
| 68 | + } |
| 69 | + return jwt.encode(self.HEADER.copy(), claims, self.key) |
| 70 | + |
| 71 | + def decode(self, text): |
| 72 | + return jwt.decode( |
| 73 | + text, |
| 74 | + self.key, |
| 75 | + claims_options={ |
| 76 | + "email": {"essential": True}, |
| 77 | + "type": {"essential": True}, |
| 78 | + }, |
| 79 | + ) |
| 80 | + |
| 81 | + @staticmethod |
| 82 | + def get_user_by_email(user_email): |
| 83 | + return db.session.execute( |
| 84 | + sqla.select(User).where(User.email == user_email) |
| 85 | + ).scalar() |
| 86 | + |
| 87 | + def get_user_jwt(self, creds, refresh=False): |
| 88 | + try: |
| 89 | + claims = self.decode(creds) |
| 90 | + claims.validate() |
| 91 | + except ExpiredTokenError as exc: |
| 92 | + raise BEMServerAPIAuthenticationError(code="expired_token") from exc |
| 93 | + except JoseError as exc: |
| 94 | + raise BEMServerAPIAuthenticationError(code="invalid_token") from exc |
| 95 | + if refresh is not (claims["type"] == "refresh"): |
| 96 | + raise BEMServerAPIAuthenticationError(code="invalid_token") |
| 97 | + user_email = claims["email"] |
| 98 | + if (user := self.get_user_by_email(user_email)) is None: |
| 99 | + raise BEMServerAPIAuthenticationError(code="invalid_token") |
| 100 | + return user |
| 101 | + |
| 102 | + def get_user_http_basic_auth(self, creds, **_kwargs): |
| 103 | + """Check password and return User instance""" |
| 104 | + try: |
| 105 | + enc_email, enc_password = base64.b64decode(creds).split(b":", maxsplit=1) |
| 106 | + user_email = enc_email.decode() |
| 107 | + password = enc_password.decode() |
| 108 | + except (ValueError, TypeError) as exc: |
| 109 | + raise BEMServerAPIAuthenticationError(code="malformed_credentials") from exc |
| 110 | + if (user := self.get_user_by_email(user_email)) is None: |
| 111 | + raise BEMServerAPIAuthenticationError(code="invalid_credentials") |
| 112 | + if not user.check_password(password): |
| 113 | + raise BEMServerAPIAuthenticationError(code="invalid_credentials") |
| 114 | + return user |
| 115 | + |
| 116 | + def get_user(self, refresh=False): |
| 117 | + if (auth_header := flask.request.headers.get("Authorization")) is None: |
| 118 | + raise BEMServerAPIAuthenticationError(code="missing_authentication") |
| 119 | + try: |
| 120 | + scheme, creds = auth_header.split(" ", maxsplit=1) |
| 121 | + except ValueError: |
| 122 | + abort(400) |
| 123 | + try: |
| 124 | + func = self.get_user_funcs[scheme] |
| 125 | + except KeyError as exc: |
| 126 | + raise BEMServerAPIAuthenticationError(code="invalid_scheme") from exc |
| 127 | + return func(creds.encode("utf-8"), refresh=refresh) |
| 128 | + |
| 129 | + def login_required(self, f=None, refresh=False): |
21 | 130 | """Decorator providing authentication and authorization
|
22 | 131 |
|
23 |
| - Uses HTTPBasicAuth.login_required authenticate user |
| 132 | + Uses JWT or HTTPBasicAuth.login_required to authenticate user |
24 | 133 | Sets CurrentUser context variable to authenticated user for the request
|
25 | 134 | Catches Authorization error and aborts accordingly
|
26 | 135 | """
|
27 | 136 |
|
28 | 137 | def decorator(func):
|
29 | 138 | @wraps(func)
|
30 | 139 | def wrapper(*args, **func_kwargs):
|
31 |
| - with CurrentUser(self.current_user()): |
| 140 | + try: |
| 141 | + user = self.get_user(refresh=refresh) |
| 142 | + except BEMServerAPIAuthenticationError as exc: |
| 143 | + abort( |
| 144 | + 401, |
| 145 | + "Authentication error", |
| 146 | + errors={"authentication": exc.code}, |
| 147 | + headers={ |
| 148 | + "WWW-Authenticate": ", ".join( |
| 149 | + self.app.config["AUTH_METHODS"] |
| 150 | + ) |
| 151 | + }, |
| 152 | + ) |
| 153 | + with CurrentUser(user): |
32 | 154 | try:
|
33 | 155 | resp = func(*args, **func_kwargs)
|
34 | 156 | except BEMServerAuthorizationError:
|
35 | 157 | abort(403, message="Authorization error")
|
36 | 158 | return resp
|
37 | 159 |
|
38 |
| - # Wrap this inside HTTPAuth.login_required |
39 |
| - # to get authenticated user |
40 |
| - return super(Auth, self).login_required(**kwargs)(wrapper) |
| 160 | + return wrapper |
41 | 161 |
|
42 | 162 | if f:
|
43 | 163 | return decorator(f)
|
44 | 164 | return decorator
|
45 | 165 |
|
46 | 166 |
|
47 | 167 | auth = Auth()
|
48 |
| - |
49 |
| - |
50 |
| -@auth.verify_password |
51 |
| -def verify_password(username, password): |
52 |
| - """Check password and return User instance""" |
53 |
| - user = db.session.execute(sqla.select(User).where(User.email == username)).scalar() |
54 |
| - if user is not None and user.check_password(password): |
55 |
| - return user |
56 |
| - return None |
57 |
| - |
58 |
| - |
59 |
| -@auth.error_handler |
60 |
| -def auth_error(status): |
61 |
| - """Authentication error handler""" |
62 |
| - # Call abort to trigger error handler and get consistent JSON output |
63 |
| - abort(status, message="Authentication error") |
|
0 commit comments