Skip to content

Commit

Permalink
feat: implemented refresh token and updated dependency lock file
Browse files Browse the repository at this point in the history
  • Loading branch information
Sujanadh committed Jun 25, 2024
1 parent 50a65ab commit 1d090e5
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 56 deletions.
81 changes: 40 additions & 41 deletions src/backend/app/auth/auth_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,19 @@
from sqlalchemy import text
from sqlalchemy.orm import Session

from app.auth.osm import AuthUser, create_access_token, init_osm_auth, login_required
from app.auth.osm import (
AuthUser,
create_tokens,
extract_refresh_token_from_cookie,
init_osm_auth,
login_required,
refresh_access_token,
set_cookies,
verify_token,
)
from app.config import settings
from app.db import database
from app.models.enums import HTTPStatus, UserRole
from app.models.enums import UserRole

router = APIRouter(
prefix="/auth",
Expand Down Expand Up @@ -82,7 +91,6 @@ async def callback(request: Request, osm_auth=Depends(init_osm_auth)):
# Get access token
access_token = osm_auth.callback(callback_url).get("access_token")
log.debug(f"Access token returned of length {len(access_token)}")
response = Response(status_code=HTTPStatus.OK)

osm_user = osm_auth.deserialize_access_token(access_token)
user_data = {
Expand All @@ -96,23 +104,8 @@ async def callback(request: Request, osm_auth=Depends(init_osm_auth)):
"img_url": osm_user.get("img_url"),
"role": UserRole.MAPPER,
}
jwt_token = create_access_token(user_data)

# Set cookie
cookie_name = settings.FMTM_DOMAIN.replace(".", "_")
response = Response(status_code=200)
response.set_cookie(
key=cookie_name,
value=jwt_token,
max_age=31536000,
expires=31536000,
path="/",
domain=settings.FMTM_DOMAIN,
secure=False if settings.DEBUG else True,
httponly=True,
samesite="lax",
)
return response
access_token, refresh_token = create_tokens(user_data)
return set_cookies(access_token, refresh_token)


@router.get("/logout/")
Expand Down Expand Up @@ -231,15 +224,35 @@ async def my_data(
return await get_or_create_user(db, user_data)


@router.get("/introspect", response_model=AuthUser)
async def check_login(
user_data: AuthUser = Depends(login_required),
@router.get("/refresh")
async def refresh_token(
request: Request,
):
"""Verifies the validity of login cookies.
Returns True if authenticated, False otherwise.
"""
return user_data
refresh_token = extract_refresh_token_from_cookie(request)
if not refresh_token:
raise HTTPException(status_code=401, detail="No tokens provided")

token_data = verify_token(refresh_token)
access_token = refresh_access_token(token_data)

response = Response(status_code=200)
cookie_name = settings.FMTM_DOMAIN.replace(".", "_")
response.set_cookie(
key=cookie_name,
value=access_token,
max_age=86400,
expires=86400,
path="/",
domain=settings.FMTM_DOMAIN,
secure=False if settings.DEBUG else True,
httponly=True,
samesite="lax",
)
return response


@router.get("/temp-login")
Expand All @@ -264,24 +277,10 @@ async def temp_login(
"sub": f"fmtm|{username}",
"aud": settings.FMTM_DOMAIN,
"iat": int(time.time()),
"exp": int(time.time()) + 86400, # expiry set to 1 day
"exp": int(time.time()) + 86400 * 7, # expiry set to 7 days
"username": username,
"img_url": None,
"role": UserRole.MAPPER,
}
jwt_token = create_access_token(user_data)

response = Response(status_code=200)
cookie_name = settings.FMTM_DOMAIN.replace(".", "_")
response.set_cookie(
key=cookie_name,
value=jwt_token,
max_age=86400,
expires=86400,
path="/",
domain=settings.FMTM_DOMAIN,
secure=False if settings.DEBUG else True,
httponly=True,
samesite="lax",
)
return response
access_token, refresh_token = create_tokens(user_data)
return set_cookies(access_token, refresh_token)
97 changes: 83 additions & 14 deletions src/backend/app/auth/osm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
"""Auth methods related to OSM OAuth2."""

import os
import time
from typing import Optional

import jwt
from fastapi import Header, HTTPException, Request
from fastapi import Header, HTTPException, Request, Response
from loguru import logger as log
from osm_login_python.core import Auth
from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -71,15 +72,13 @@ async def login_required(

# Attempt extract from cookie if access token not passed
if not access_token:
cookie_name = settings.FMTM_DOMAIN.replace(".", "_")
log.debug(f"Extracting token from cookie {cookie_name}")
access_token = request.cookies.get(cookie_name)
access_token = extract_token_from_cookie(request)

if not access_token:
raise HTTPException(status_code=401, detail="No access token provided")

try:
token_data = verify_access_token(access_token)
token_data = verify_token(access_token)
except ValueError as e:
log.error(e)
log.error("Failed to deserialise access token")
Expand All @@ -88,20 +87,54 @@ async def login_required(
return AuthUser(**token_data)


def create_access_token(payload: dict) -> str:
"""Generates an access token for the specified user.
def extract_token_from_cookie(request: Request) -> str:
"""Extract access token from cookies."""
cookie_name = settings.FMTM_DOMAIN.replace(".", "_")
log.debug(f"Extracting token from cookie {cookie_name}")
return request.cookies.get(cookie_name)


def extract_refresh_token_from_cookie(request: Request) -> str:
"""Extract refresh token from cookies."""
cookie_name = settings.FMTM_DOMAIN.replace(".", "_")
return request.cookies.get(f"{cookie_name}_refresh")


def create_tokens(payload: dict) -> tuple[str, str]:
"""Generates tokens for the specified user.
Args:
payload (dict): user data for which the access token is being generated.
Returns:
str: The generated access token.
Tuple: The generated access tokens.
"""
access_token_payload = payload
access_token_payload["exp"] = (
int(time.time()) + 86400
) # set access token expiry to 1 day
private_key = settings.AUTH_PRIVATE_KEY
access_token = jwt.encode(
access_token_payload, str(private_key), algorithm=settings.ALGORITHM
)
refresh_token = jwt.encode(payload, str(private_key), algorithm=settings.ALGORITHM)
return access_token, refresh_token


def refresh_access_token(payload: dict) -> str:
"""Generate a new access token."""
access_token_payload = payload
access_token_payload["exp"] = (
int(time.time()) + 60
) # Access token valid for 15 minutes

private_key = settings.AUTH_PRIVATE_KEY
return jwt.encode(payload, str(private_key), algorithm=settings.ALGORITHM)
return jwt.encode(
access_token_payload, str(private_key), algorithm=settings.ALGORITHM
)


def verify_access_token(token: str):
def verify_token(token: str):
"""Verifies the access token and returns the payload if valid.
Args:
Expand All @@ -113,18 +146,54 @@ def verify_access_token(token: str):
Raises:
HTTPException: If the token has expired or credentials could not be validated.
"""
public_key = settings.AUTH_PUBLIC_KEY
try:
public_key = settings.AUTH_PUBLIC_KEY
return jwt.decode(
token,
str(public_key),
algorithms=[settings.ALGORITHM],
audience=settings.FMTM_DOMAIN,
)
except jwt.ExpiredSignatureError as e:
raise HTTPException(status_code=401, detail="Token has expired") from e
raise HTTPException(status_code=401, detail="Refresh token has expired") from e
except Exception as e:
print(e)
raise HTTPException(
status_code=401, detail="Could not validate credentials"
status_code=401, detail="Could not validate refresh token"
) from e


def set_cookies(access_token: str, refresh_token: str):
"""Sets cookies for the access token and refresh token.
Args:
access_token (str): The access token to be stored in the cookie.
refresh_token (str): The refresh token to be stored in the cookie.
Returns:
Response: A response object with the cookies set.
"""
response = Response(status_code=200)
cookie_name = settings.FMTM_DOMAIN.replace(".", "_")
response.set_cookie(
key=cookie_name,
value=access_token,
max_age=86400,
expires=86400, # expiry set for 1 day
path="/",
domain=settings.FMTM_DOMAIN,
secure=False if settings.DEBUG else True,
httponly=True,
samesite="lax",
)
response.set_cookie(
key=f"{cookie_name}_refresh",
value=refresh_token,
max_age=86400 * 7,
expires=86400 * 7, # expiry set for 7 days
path="/",
domain=settings.FMTM_DOMAIN,
secure=False if settings.DEBUG else True,
httponly=True,
samesite="lax",
)
return response
2 changes: 1 addition & 1 deletion src/backend/pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1d090e5

Please sign in to comment.