diff --git a/src/backend/app/config.py b/src/backend/app/config.py index 19d3964c..977076b4 100644 --- a/src/backend/app/config.py +++ b/src/backend/app/config.py @@ -78,6 +78,7 @@ def assemble_db_connection(cls, v: Optional[str], info: ValidationInfo) -> Any: S3_BUCKET_NAME: str = "dtm-data" S3_DOWNLOAD_ROOT: Optional[str] = None + ALGORITHM: str = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 1 # 1 day REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 day diff --git a/src/backend/app/projects/project_routes.py b/src/backend/app/projects/project_routes.py index 39bb6ce5..084ebf8b 100644 --- a/src/backend/app/projects/project_routes.py +++ b/src/backend/app/projects/project_routes.py @@ -22,11 +22,8 @@ ) -@router.delete('/{project_id}', tags=["Projects"]) -def delete_project_by_id( - project_id: int, - db: Session = Depends(database.get_db) -): +@router.delete("/{project_id}", tags=["Projects"]) +def delete_project_by_id(project_id: int, db: Session = Depends(database.get_db)): """ Delete a project by its ID, along with all associated tasks. @@ -41,20 +38,28 @@ def delete_project_by_id( HTTPException: If the project is not found. """ # Query for the project - project = db.query(db_models.DbProject).filter(db_models.DbProject.id == project_id).first() + project = ( + db.query(db_models.DbProject) + .filter(db_models.DbProject.id == project_id) + .first() + ) if not project: raise HTTPException(status_code=404, detail="Project not found.") # Query and delete associated tasks - tasks = db.query(db_models.DbTask).filter(db_models.DbTask.project_id == project_id).all() + tasks = ( + db.query(db_models.DbTask) + .filter(db_models.DbTask.project_id == project_id) + .all() + ) for task in tasks: db.delete(task) - + # Delete the project db.delete(project) db.commit() return {"message": f"Project ID: {project_id} is deleted successfully."} - + @router.post( "/create_project", tags=["Projects"], response_model=project_schemas.ProjectOut diff --git a/src/backend/app/users/oauth_routes.py b/src/backend/app/users/oauth_routes.py index 0952bf0b..733602a3 100644 --- a/src/backend/app/users/oauth_routes.py +++ b/src/backend/app/users/oauth_routes.py @@ -7,8 +7,10 @@ from app.users.user_routes import router from app.users.user_deps import init_google_auth, login_required from app.users.user_schemas import AuthUser +from app.users import user_crud from app.config import settings + if settings.DEBUG: os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" @@ -39,7 +41,11 @@ async def callback(request: Request, google_auth=Depends(init_google_auth)): callback_url = str(request.url) access_token = google_auth.callback(callback_url).get("access_token") - return access_token + + user_data = google_auth.deserialize_access_token(access_token) + access_token, refresh_token = user_crud.create_access_token(user_data) + + return {"access_token": access_token, "refresh_token": refresh_token} @router.get("/my-info/") diff --git a/src/backend/app/users/user_crud.py b/src/backend/app/users/user_crud.py index 0440a0fd..a4ab7831 100644 --- a/src/backend/app/users/user_crud.py +++ b/src/backend/app/users/user_crud.py @@ -1,6 +1,6 @@ +import time import jwt from app.config import settings -from datetime import datetime, timedelta from typing import Any from passlib.context import CryptContext from sqlalchemy.orm import Session @@ -12,28 +12,47 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -ALGORITHM = "HS256" +def create_access_token(subject: str | Any): + expire = int(time.time()) + settings.ACCESS_TOKEN_EXPIRE_MINUTES + refresh_expire = int(time.time()) + settings.REFRESH_TOKEN_EXPIRE_MINUTES -def create_access_token( - subject: str | Any, expires_delta: timedelta, refresh_token_expiry: timedelta -): - expire = datetime.utcnow() + expires_delta - refresh_expire = datetime.utcnow() + refresh_token_expiry - - to_encode_access_token = {"exp": expire, "sub": str(subject)} - to_encode_refresh_token = {"exp": refresh_expire, "sub": str(subject)} - + # access token + subject["exp"] = expire access_token = jwt.encode( - to_encode_access_token, settings.SECRET_KEY, algorithm=ALGORITHM + subject, settings.SECRET_KEY, algorithm=settings.ALGORITHM ) + + # refresh token + subject["exp"] = refresh_expire refresh_token = jwt.encode( - to_encode_refresh_token, settings.SECRET_KEY, algorithm=ALGORITHM + subject, settings.SECRET_KEY, algorithm=settings.ALGORITHM ) return access_token, refresh_token +def verify_token(token: str): + """Verifies the access token and returns the payload if valid. + + Args: + token (str): The access token to be verified. + + Returns: + dict: The payload of the access token if verification is successful. + + Raises: + HTTPException: If the token has expired or credentials could not be validated. + """ + secret_key = settings.SECRET_KEY + try: + return jwt.decode(token, str(secret_key), algorithms=[settings.ALGORITHM]) + except jwt.ExpiredSignatureError as e: + raise HTTPException(status_code=401, detail="Token has expired") from e + except Exception as e: + raise HTTPException(status_code=401, detail="Could not validate token") from e + + def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password) diff --git a/src/backend/app/users/user_deps.py b/src/backend/app/users/user_deps.py index eade966b..c56434d1 100644 --- a/src/backend/app/users/user_deps.py +++ b/src/backend/app/users/user_deps.py @@ -83,16 +83,17 @@ async def login_required( ) -> AuthUser: """Dependency to inject into endpoints requiring login.""" - google_auth = await init_google_auth() + if not access_token: + raise HTTPException(status_code=401, detail="No access token provided") if not access_token: raise HTTPException(status_code=401, detail="No access token provided") try: - google_user = google_auth.deserialize_access_token(access_token) - except ValueError as e: + user = user_crud.verify_token(access_token) + except HTTPException as e: log.error(e) - log.error("Failed to deserialise access token") + log.error("Failed to verify access token") raise HTTPException(status_code=401, detail="Access token not valid") from e - return AuthUser(**google_user) + return AuthUser(**user) diff --git a/src/backend/app/users/user_routes.py b/src/backend/app/users/user_routes.py index 8a553818..28abb54c 100644 --- a/src/backend/app/users/user_routes.py +++ b/src/backend/app/users/user_routes.py @@ -13,7 +13,6 @@ from app.users import user_schemas from app.config import settings - router = APIRouter( prefix=f"{settings.API_PREFIX}/users", tags=["users"], @@ -34,14 +33,9 @@ async def login_access_token( raise HTTPException(status_code=400, detail="Incorrect email or password") elif not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - refresh_token_expires = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES) - access_token, refresh_token = user_crud.create_access_token( - user.id, - expires_delta=access_token_expires, - refresh_token_expiry=refresh_token_expires, - ) + access_token, refresh_token = user_crud.create_access_token(user.id) + return Token(access_token=access_token, refresh_token=refresh_token) # @router.post("/login/")