Skip to content

Commit

Permalink
Merge pull request #5 from hotosm/feat-authentication
Browse files Browse the repository at this point in the history
Feat authentication
  • Loading branch information
nrjadkry authored Jun 13, 2024
2 parents ac17311 + ae74551 commit 0a16cf4
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def assemble_db_connection(cls, v: Optional[str], info: ValidationInfo) -> Any:
S3_BUCKET_NAME: str = "dtm-data"
S3_DOWNLOAD_ROOT: Optional[str] = None

ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 1 # 1 day
REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 day


@lru_cache
Expand Down
21 changes: 16 additions & 5 deletions src/backend/app/users/user_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,26 @@

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


ALGORITHM = "HS256"


def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
def create_access_token(
subject: str | Any, expires_delta: timedelta, refresh_token_expiry: timedelta
):
expire = datetime.utcnow() + expires_delta
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
refresh_expire = datetime.utcnow() + refresh_token_expiry

to_encode_access_token = {"exp": expire, "sub": str(subject)}
to_encode_refresh_token = {"exp": refresh_expire, "sub": str(subject)}

access_token = jwt.encode(
to_encode_access_token, settings.SECRET_KEY, algorithm=ALGORITHM
)
refresh_token = jwt.encode(
to_encode_refresh_token, settings.SECRET_KEY, algorithm=ALGORITHM
)

return access_token, refresh_token


def verify_password(plain_password: str, hashed_password: str) -> bool:
Expand Down
55 changes: 55 additions & 0 deletions src/backend/app/users/user_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import jwt
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from pydantic import ValidationError
from sqlalchemy.orm import Session
from app.config import settings
from app.db import database
from app.users import user_crud, user_schemas
from app.db.db_models import DbUser


reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/users/login")


SessionDep = Annotated[
Session,
Depends(database.get_db),
]
TokenDep = Annotated[str, Depends(reusable_oauth2)]


def get_current_user(session: SessionDep, token: TokenDep):
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[user_crud.ALGORITHM]
)
token_data = user_schemas.TokenPayload(**payload)

except (InvalidTokenError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)

user = session.get(DbUser, token_data.sub)

if not user:
raise HTTPException(status_code=404, detail="User not found")
if not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return user


CurrentUser = Annotated[DbUser, Depends(get_current_user)]


def get_current_active_superuser(current_user: CurrentUser):
if not current_user.is_superuser:
raise HTTPException(
status_code=403, detail="The user doesn't have enough privileges"
)
return current_user
34 changes: 30 additions & 4 deletions src/backend/app/users/user_routes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any
from datetime import timedelta
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.orm import Session
from typing import Annotated
from fastapi.security import OAuth2PasswordRequestForm
from app.users.user_schemas import Token, UserPublic, UserRegister
from app.users.user_deps import CurrentUser
from app.config import settings
from app.users import user_crud
from app.db import database
Expand Down Expand Up @@ -31,11 +33,14 @@ def login_access_token(
elif not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return Token(
access_token=user_crud.create_access_token(
user.id, expires_delta=access_token_expires
)
refresh_token_expires = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)

access_token, refresh_token = user_crud.create_access_token(
user.id,
expires_delta=access_token_expires,
refresh_token_expiry=refresh_token_expires,
)
return Token(access_token=access_token, refresh_token=refresh_token)


@router.post("/signup", response_model=UserPublic)
Expand All @@ -61,3 +66,24 @@ def register_user(

user = user_crud.create_user(db, user_in)
return user


@router.get("/refresh_token")
def update_token(current_user: CurrentUser):
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_token_expires = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)

access_token, refresh_token = user_crud.create_access_token(
current_user.id,
expires_delta=access_token_expires,
refresh_token_expiry=refresh_token_expires,
)
return Token(access_token=access_token, refresh_token=refresh_token)


@router.get("/me", response_model=UserPublic)
def read_user_me(current_user: CurrentUser) -> Any:
"""
Get current user.
"""
return current_user
6 changes: 6 additions & 0 deletions src/backend/app/users/user_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ class User(BaseModel):
name: str


# Contents of JWT token
class TokenPayload(BaseModel):
sub: int | None = None


class Token(BaseModel):
access_token: str
refresh_token: str
token_type: str = "bearer"


Expand Down

0 comments on commit 0a16cf4

Please sign in to comment.