diff --git a/src/backend/app/users/oauth_routes.py b/src/backend/app/users/oauth_routes.py index 730251d1..8984eea2 100644 --- a/src/backend/app/users/oauth_routes.py +++ b/src/backend/app/users/oauth_routes.py @@ -2,13 +2,13 @@ from loguru import logger as log from fastapi import Depends, Request from fastapi.responses import JSONResponse -from sqlalchemy.orm import Session from app.db import database from app.users.user_routes import router from app.users.user_deps import init_google_auth, login_required from app.users.user_schemas import AuthUser, Token from app.users import user_crud from app.config import settings +from databases import Database if settings.DEBUG: @@ -43,7 +43,7 @@ async def callback(request: Request, google_auth=Depends(init_google_auth)): access_token = google_auth.callback(callback_url).get("access_token") user_data = google_auth.deserialize_access_token(access_token) - access_token, refresh_token = user_crud.create_access_token(user_data) + access_token, refresh_token = await user_crud.create_access_token(user_data) return Token(access_token=access_token, refresh_token=refresh_token) @@ -58,9 +58,9 @@ def update_token(user_data: AuthUser = Depends(login_required)): @router.get("/my-info/") async def my_data( - db: Session = Depends(database.get_db), + db: Database = Depends(database.encode_db), user_data: AuthUser = Depends(login_required), ): """Read access token and get user details from Google""" - return user_data + return await user_crud.get_or_create_user(db, user_data) diff --git a/src/backend/app/users/user_crud.py b/src/backend/app/users/user_crud.py index a4ab7831..f43517cc 100644 --- a/src/backend/app/users/user_crud.py +++ b/src/backend/app/users/user_crud.py @@ -5,15 +5,16 @@ from passlib.context import CryptContext from sqlalchemy.orm import Session from app.db import db_models -from app.users.user_schemas import UserCreate +from app.users.user_schemas import UserCreate, AuthUser from sqlalchemy import text from databases import Database from fastapi import HTTPException +from app.models.enums import UserRole pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") -def create_access_token(subject: str | Any): +async def create_access_token(subject: str | Any): expire = int(time.time()) + settings.ACCESS_TOKEN_EXPIRE_MINUTES refresh_expire = int(time.time()) + settings.REFRESH_TOKEN_EXPIRE_MINUTES @@ -67,6 +68,7 @@ def get_user_by_email(db: Session, email: str): data = result.fetchone() return data + async def get_user_email(db: Database, email: str): query = f"SELECT * FROM users WHERE email_address = '{email}' LIMIT 1;" result = await db.fetch_one(query) @@ -78,20 +80,25 @@ async def get_user_username(db: Database, username: str): result = await db.fetch_one(query=query) return result + def get_user_by_username(db: Session, username: str): query = text(f"SELECT * FROM users WHERE username = '{username}' LIMIT 1;") result = db.execute(query) data = result.fetchone() return data -async def authenticate(db: Database, username: str, password: str) -> db_models.DbUser | None: + +async def authenticate( + db: Database, username: str, password: str +) -> db_models.DbUser | None: db_user = await get_user_username(db, username) if not db_user: return None - if not verify_password(password, db_user['password']): + if not verify_password(password, db_user["password"]): return None return db_user + # def authenticate(db: Session, username: str, password: str) -> db_models.DbUser | None: # db_user = get_user_by_username(db, username) # if not db_user: @@ -101,7 +108,7 @@ async def authenticate(db: Database, username: str, password: str) -> db_models. # return db_user -async def create_user(db: Database, user_create: UserCreate): +async def create_user(db: Database, user_create: UserCreate): query = f""" INSERT INTO users (username, password, is_active, name, email_address, is_superuser) VALUES ('{user_create.username}', '{get_password_hash(user_create.password)}', {True}, '{user_create.name}', '{user_create.email_address}', {False}) @@ -111,8 +118,47 @@ async def create_user(db: Database, user_create: UserCreate): raw_query = f"SELECT * from users WHERE id = {_id} LIMIT 1" db_obj = await db.fetch_one(query=raw_query) if not db_obj: - raise HTTPException( - status_code=500, - detail="User could not be created" - ) + raise HTTPException(status_code=500, detail="User could not be created") return db_obj + + +async def get_or_create_user( + db: Database, + user_data: AuthUser, +): + """Get user from User table if exists, else create.""" + try: + update_sql = """ + INSERT INTO users ( + id, username, email_address, profile_img, role + ) + VALUES ( + :user_id, :username, :email_address, :profile_img, :role + ) + ON CONFLICT (id) + DO UPDATE SET profile_img = :profile_img; + """ + + await db.execute( + update_sql, + { + "user_id": str(user_data.id), + "username": user_data.email, # FIXME: remove this + "email_address": user_data.email, + "profile_img": user_data.img_url, + "role": UserRole.DRONE_PILOT.name, + }, + ) + return user_data + + except Exception as e: + if ( + 'duplicate key value violates unique constraint "users_email_address_key"' + in str(e) + ): + raise HTTPException( + status_code=400, + detail=f"User with this email {user_data.email} already exists.", + ) from e + else: + raise HTTPException(status_code=400, detail=str(e)) from e diff --git a/src/backend/app/users/user_deps.py b/src/backend/app/users/user_deps.py index c56434d1..5f424be9 100644 --- a/src/backend/app/users/user_deps.py +++ b/src/backend/app/users/user_deps.py @@ -1,6 +1,5 @@ import jwt from typing import Annotated -from databases import Database from fastapi import Depends, HTTPException, Request, status, Header from fastapi.security import OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError