Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add auth #7

Merged
merged 7 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions macrostrat_db_insertion/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@

from sqlalchemy import create_engine, MetaData, Engine
from sqlalchemy.orm import sessionmaker, declarative_base, Session

engine: Engine | None = None
base: declarative_base = None
session: Session | None = None


def get_engine() -> Engine:
return engine


def get_base() -> declarative_base:
return base


def connect_engine(uri: str, schema: str):
global engine
global session
global base

engine = create_engine(uri)
session = session

base = declarative_base()
base.metadata.reflect(get_engine())
base.metadata.reflect(get_engine(), schema=schema, views=True)


def dispose_engine():
global engine
engine.dispose()


def get_session_maker() -> sessionmaker:
return sessionmaker(autocommit=False, autoflush=False, bind=get_engine())


def get_session() -> Session:
with get_session_maker()() as s:
yield s
2 changes: 2 additions & 0 deletions macrostrat_db_insertion/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,6 @@ dependencies:
- urllib3==2.2.1
- werkzeug==3.0.2
- zipp==3.18.1
- fuzzysearch
- uvicorn
prefix: /conda/envs/db_insert_env
114 changes: 114 additions & 0 deletions macrostrat_db_insertion/security-v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
from datetime import datetime
from typing import Annotated, Optional

import bcrypt
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.security import (
HTTPAuthorizationCredentials,
HTTPBearer,
OAuth2AuthorizationCodeBearer,
)
from fastapi.security.utils import get_authorization_scheme_param
from jose import JWTError, jwt
from pydantic import BaseModel
from sqlalchemy import select, update
from starlette.status import HTTP_401_UNAUTHORIZED

from macrostrat_db_insertion.security.db import get_access_token
from macrostrat_db_insertion.security.model import TokenData

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours
GROUP_TOKEN_LENGTH = 32
GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent


class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer):
"""Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header"""

async def __call__(self, request: Request) -> Optional[str]:
authorization = request.cookies.get("Authorization") # authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={
"WWW-Authenticate": "Bearer"
},
)
else:
return None # pragma: nocover
return param


oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie(
authorizationUrl='/security/login',
tokenUrl="/security/callback",
auto_error=False
)

http_bearer = HTTPBearer(auto_error=False)


def get_groups_from_header_token(
header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None:
"""Get the groups from the bearer token in the header"""

if header_token is None:
return None

token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT)
token_hash_string = token_hash.decode('utf-8')

token = get_access_token(token=token_hash_string)

if token is None:
return None

return token.group


def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2_scheme)]):
"""Get the current user from the JWT token in the cookies"""

# If there wasn't a token include in the request
if token is None:
return None

try:
payload = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']])
sub: str = payload.get("sub")
groups = payload.get("groups", [])
token_data = TokenData(sub=sub, groups=groups)
except JWTError as e:
return None

return token_data


def get_groups(
user_token_data: TokenData | None = Depends(get_user_token_from_cookie),
header_token: int | None = Depends(get_groups_from_header_token)
) -> list[int]:
"""Get the groups from both the cookies and header"""

groups = []
if user_token_data is not None:
groups = user_token_data.groups

if header_token is not None:
groups.append(header_token)

return groups


async def has_access(groups: list[int] = Depends(get_groups)) -> bool:
"""Check if the user has access to the group"""

if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 'development':
return True

return 1 in groups
6 changes: 6 additions & 0 deletions macrostrat_db_insertion/security/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from macrostrat_db_insertion.security.main import (
get_groups_from_header_token,
get_user_token_from_cookie,
get_groups,
has_access
)
30 changes: 30 additions & 0 deletions macrostrat_db_insertion/security/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import datetime

from sqlalchemy import select, update

from macrostrat_db_insertion.database import get_session_maker, get_engine
from macrostrat_db_insertion.security.schema import Token


def get_access_token(token: str):
"""The sole database call """

session_maker = get_session_maker()
with session_maker() as session:

select_stmt = select(Token).where(Token.token == token)

# Check that the token exists
result = (session.scalars(select_stmt)).first()

# Check if it has expired
if result.expires_on < datetime.datetime.now(datetime.timezone.utc):
return None

# Update the used_on column
if result is not None:
stmt = update(Token).where(Token.token == token).values(used_on=datetime.datetime.utcnow())
session.execute(stmt)
session.commit()

return (session.scalars(select_stmt)).first()
114 changes: 114 additions & 0 deletions macrostrat_db_insertion/security/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
from datetime import datetime
from typing import Annotated, Optional

import bcrypt
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.security import (
HTTPAuthorizationCredentials,
HTTPBearer,
OAuth2AuthorizationCodeBearer,
)
from fastapi.security.utils import get_authorization_scheme_param
from jose import JWTError, jwt
from pydantic import BaseModel
from sqlalchemy import select, update
from starlette.status import HTTP_401_UNAUTHORIZED

from macrostrat_db_insertion.security.db import get_access_token
from macrostrat_db_insertion.security.model import TokenData

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours
GROUP_TOKEN_LENGTH = 32
GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent


class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer):
"""Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header"""

async def __call__(self, request: Request) -> Optional[str]:
authorization = request.cookies.get("Authorization") # authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={
"WWW-Authenticate": "Bearer"
},
)
else:
return None # pragma: nocover
return param


oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie(
authorizationUrl='/security/login',
tokenUrl="/security/callback",
auto_error=False
)

http_bearer = HTTPBearer(auto_error=False)


def get_groups_from_header_token(
header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None:
"""Get the groups from the bearer token in the header"""

if header_token is None:
return None

token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT)
token_hash_string = token_hash.decode('utf-8')

token = get_access_token(token=token_hash_string)

if token is None:
return None

return token.group


def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2_scheme)]):
"""Get the current user from the JWT token in the cookies"""

# If there wasn't a token include in the request
if token is None:
return None

try:
payload = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']])
sub: str = payload.get("sub")
groups = payload.get("groups", [])
token_data = TokenData(sub=sub, groups=groups)
except JWTError as e:
return None

return token_data


def get_groups(
user_token_data: TokenData | None = Depends(get_user_token_from_cookie),
header_token: int | None = Depends(get_groups_from_header_token)
) -> list[int]:
"""Get the groups from both the cookies and header"""

groups = []
if user_token_data is not None:
groups = user_token_data.groups

if header_token is not None:
groups.append(header_token)

return groups


def has_access(groups: list[int] = Depends(get_groups)) -> bool:
"""Check if the user has access to the group"""

if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 'development':
return True

return 1 in groups
23 changes: 23 additions & 0 deletions macrostrat_db_insertion/security/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pydantic import BaseModel


class TokenData(BaseModel):
sub: str
groups: list[int] = []


class User(BaseModel):
username: str
email: str | None = None
full_name: str | None = None
disabled: bool | None = None


class AccessToken(BaseModel):
group: int
token: str


class GroupTokenRequest(BaseModel):
expiration: int
group_id: int
Loading
Loading