Skip to content

Commit

Permalink
refactor: replace all refs to Database --> Connection (psycopg)
Browse files Browse the repository at this point in the history
  • Loading branch information
spwoodcock committed Aug 7, 2024
1 parent 8ebec65 commit 5aa3787
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 55 deletions.
25 changes: 13 additions & 12 deletions src/backend/app/drones/drone_crud.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from app.drones import drone_schemas
from app.models.enums import HTTPStatus
from databases import Database
from loguru import logger as log
from fastapi import HTTPException
from asyncpg import UniqueViolationError
from psycopg import Connection

# from asyncpg import UniqueViolationError
from typing import List
from app.drones.drone_schemas import DroneOut


async def read_all_drones(db: Database) -> List[DroneOut]:
async def read_all_drones(db: Connection) -> List[DroneOut]:
"""
Retrieves all drone records from the database.
Expand All @@ -32,7 +33,7 @@ async def read_all_drones(db: Database) -> List[DroneOut]:
) from e


async def delete_drone(db: Database, drone_id: int) -> bool:
async def delete_drone(db: Connection, drone_id: int) -> bool:
"""
Deletes a drone record from the database, along with associated drone flights.
Expand Down Expand Up @@ -63,7 +64,7 @@ async def delete_drone(db: Database, drone_id: int) -> bool:
) from e


async def get_drone(db: Database, drone_id: int):
async def get_drone(db: Connection, drone_id: int):
"""
Retrieves a drone record from the database.
Expand All @@ -89,7 +90,7 @@ async def get_drone(db: Database, drone_id: int):
) from e


async def create_drone(db: Database, drone_info: drone_schemas.DroneIn):
async def create_drone(db: Connection, drone_info: drone_schemas.DroneIn):
"""
Creates a new drone record in the database.
Expand All @@ -116,12 +117,12 @@ async def create_drone(db: Database, drone_info: drone_schemas.DroneIn):
result = await db.execute(insert_query, drone_info.__dict__)
return result

except UniqueViolationError as e:
log.exception("Unique constraint violation: %s", e)
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="A drone with this model already exists",
)
# except UniqueViolationError as e:
# log.exception("Unique constraint violation: %s", e)
# raise HTTPException(
# status_code=HTTPStatus.CONFLICT,
# detail="A drone with this model already exists",
# )

except Exception as e:
log.exception(e)
Expand Down
21 changes: 11 additions & 10 deletions src/backend/app/drones/drone_routes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Annotated
from app.users.user_deps import login_required
from app.users.user_schemas import AuthUser
from app.models.enums import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException
from app.db.database import get_db
from app.db import database
from app.config import settings
from app.drones import drone_schemas
from databases import Database
from psycopg import Connection
from app.drones import drone_crud
from typing import List

Expand All @@ -18,8 +19,8 @@

@router.get("/", tags=["Drones"], response_model=List[drone_schemas.DroneOut])
async def read_drones(
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Retrieves all drone records from the database.
Expand All @@ -38,8 +39,8 @@ async def read_drones(
@router.delete("/{drone_id}", tags=["Drones"])
async def delete_drone(
drone_id: int,
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Deletes a drone record from the database.
Expand All @@ -61,8 +62,8 @@ async def delete_drone(
@router.post("/create_drone", tags=["Drones"])
async def create_drone(
drone_info: drone_schemas.DroneIn,
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Creates a new drone record in the database.
Expand All @@ -86,8 +87,8 @@ async def create_drone(
@router.get("/{drone_id}", tags=["Drones"], response_model=drone_schemas.DroneOut)
async def read_drone(
drone_id: int,
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Retrieves a drone record from the database.
Expand Down
2 changes: 0 additions & 2 deletions src/backend/app/projects/project_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ class ProjectIn(BaseModel):
dem_url: Optional[str] = None
gsd_cm_px: float = None
is_terrain_follow: bool = False
# TODO change all references outline_geojson --> outline
# TODO also no_fly_zones
outline: Annotated[
FeatureCollection | Feature | Polygon, AfterValidator(validate_geojson)
]
Expand Down
22 changes: 13 additions & 9 deletions src/backend/app/tasks/task_crud.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid
from databases import Database
from app.models.enums import HTTPStatus, State
from fastapi import HTTPException
from loguru import logger as log
from psycopg import Connection


async def get_tasks_by_user(user_id: str, db: Database):
async def get_tasks_by_user(user_id: str, db: Connection):
try:
query = """WITH task_details AS (
SELECT
Expand Down Expand Up @@ -42,7 +42,7 @@ async def get_tasks_by_user(user_id: str, db: Database):
) from e


async def get_all_tasks(db: Database, project_id: uuid.UUID):
async def get_all_tasks(db: Connection, project_id: uuid.UUID):
query = """
SELECT id FROM tasks WHERE project_id = :project_id
"""
Expand All @@ -56,7 +56,7 @@ async def get_all_tasks(db: Database, project_id: uuid.UUID):
return task_ids


async def all_tasks_states(db: Database, project_id: uuid.UUID):
async def all_tasks_states(db: Connection, project_id: uuid.UUID):
query = """
SELECT DISTINCT ON (task_id) project_id, task_id, state
FROM task_events
Expand Down Expand Up @@ -95,7 +95,11 @@ async def all_tasks_states(db: Database, project_id: uuid.UUID):


async def request_mapping(
db: Database, project_id: uuid.UUID, task_id: uuid.UUID, user_id: str, comment: str
db: Connection,
project_id: uuid.UUID,
task_id: uuid.UUID,
user_id: str,
comment: str,
):
query = """
WITH last AS (
Expand Down Expand Up @@ -139,8 +143,8 @@ async def request_mapping(
return {"project_id": project_id, "task_id": task_id, "comment": comment}


async def update_or_create_task_state(
db: Database,
async def update_task_state(
db: Connection,
project_id: uuid.UUID,
task_id: uuid.UUID,
user_id: str,
Expand Down Expand Up @@ -195,7 +199,7 @@ async def update_or_create_task_state(


async def get_requested_user_id(
db: Database, project_id: uuid.UUID, task_id: uuid.UUID
db: Connection, project_id: uuid.UUID, task_id: uuid.UUID
):
query = """
SELECT user_id
Expand All @@ -216,7 +220,7 @@ async def get_requested_user_id(
return result["user_id"]


async def get_project_task_by_id(db: Database, user_id: str):
async def get_project_task_by_id(db: Connection, user_id: str):
"""Get a list of pending tasks created by a specific user (project creator)."""
_sql = """
SELECT id FROM projects WHERE author_id = :user_id
Expand Down
19 changes: 11 additions & 8 deletions src/backend/app/tasks/task_routes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid
from typing import Annotated
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from app.config import settings
from app.models.enums import EventType, State, UserRole
from app.tasks import task_schemas, task_crud
from app.users.user_deps import login_required
from app.users.user_schemas import AuthUser
from app.users.user_crud import get_user_by_id
from databases import Database
from psycopg import Connection
from app.db import database
from app.utils import send_notification_email, render_email_template
from app.projects.project_crud import get_project_by_id
Expand All @@ -21,8 +22,8 @@

@router.get("/", response_model=list[task_schemas.UserTasksStatsOut])
async def list_tasks(
db: Database = Depends(database.get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""Get all tasks for a drone user."""

Expand All @@ -31,20 +32,22 @@ async def list_tasks(


@router.get("/states/{project_id}")
async def task_states(project_id: uuid.UUID, db: Database = Depends(database.get_db)):
async def task_states(
db: Annotated[Connection, Depends(database.get_db)], project_id: uuid.UUID
):
"""Get all tasks states for a project."""

return await task_crud.all_tasks_states(db, project_id)


@router.post("/event/{project_id}/{task_id}")
async def new_event(
db: Annotated[Connection, Depends(database.get_db)],
background_tasks: BackgroundTasks,
project_id: uuid.UUID,
task_id: uuid.UUID,
detail: task_schemas.NewEvent,
user_data: AuthUser = Depends(login_required),
db: Database = Depends(database.get_db),
user_data: Annotated[AuthUser, Depends(login_required)],
):
user_id = user_data.id

Expand Down Expand Up @@ -223,8 +226,8 @@ async def new_event(

@router.get("/requested_tasks/pending")
async def get_pending_tasks(
user_data: AuthUser = Depends(login_required),
db: Database = Depends(database.get_db),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""Get a list of pending tasks for a project creator."""
user_id = user_data.id
Expand Down
14 changes: 7 additions & 7 deletions src/backend/app/users/user_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from passlib.context import CryptContext
from app.db import db_models
from app.users.user_schemas import AuthUser, ProfileUpdate
from databases import Database
from fastapi import HTTPException
from pydantic import EmailStr
from psycopg import Connection


pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
Expand Down Expand Up @@ -61,26 +61,26 @@ def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


async def get_user_by_id(db: Database, id: str):
async def get_user_by_id(db: Connection, id: str):
query = "SELECT * FROM users WHERE id = :id LIMIT 1;"
result = await db.fetch_one(query, {"id": id})
return result


async def get_userprofile_by_userid(db: Database, user_id: str):
async def get_userprofile_by_userid(db: Connection, user_id: str):
query = "SELECT * FROM user_profile WHERE user_id = :user_id LIMIT 1;"
result = await db.fetch_one(query, {"user_id": user_id})
return result


async def get_user_by_email(db: Database, email: str):
async def get_user_by_email(db: Connection, email: str):
query = "SELECT * FROM users WHERE email_address = :email LIMIT 1;"
result = await db.fetch_one(query, {"email": email})
return result


async def authenticate(
db: Database, email: EmailStr, password: str
db: Connection, email: EmailStr, password: str
) -> db_models.DbUser | None:
db_user = await get_user_by_email(db, email)
if not db_user:
Expand All @@ -91,7 +91,7 @@ async def authenticate(


async def get_or_create_user(
db: Database,
db: Connection,
user_data: AuthUser,
):
"""Get user from User table if exists, else create."""
Expand Down Expand Up @@ -132,7 +132,7 @@ async def get_or_create_user(


async def update_user_profile(
db: Database, user_id: int, profile_update: ProfileUpdate
db: Connection, user_id: int, profile_update: ProfileUpdate
):
"""
Update user profile in the database.
Expand Down
14 changes: 7 additions & 7 deletions src/backend/app/users/user_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from app.users import user_crud
from app.db import database
from app.models.enums import HTTPStatus
from databases import Database
from psycopg import Connection
from fastapi.responses import JSONResponse
from loguru import logger as log

Expand All @@ -31,7 +31,7 @@
@router.post("/login/")
async def login_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: Database = Depends(database.get_db),
db: Annotated[Connection, Depends(database.get_db)],
) -> Token:
"""
OAuth2 compatible token login, get an access token for future requests
Expand Down Expand Up @@ -60,8 +60,8 @@ async def login_access_token(
async def update_user_profile(
user_id: str,
profile_update: ProfileUpdate,
db: Database = Depends(database.get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Update user profile based on provided user_id and profile_update data.
Expand Down Expand Up @@ -124,7 +124,7 @@ async def callback(request: Request, google_auth=Depends(init_google_auth)):


@router.get("/refresh-token", response_model=Token)
async def update_token(user_data: AuthUser = Depends(login_required)):
async def update_token(user_data: Annotated[AuthUser, Depends(login_required)]):
"""Refresh access token"""

access_token, refresh_token = await user_crud.create_access_token(
Expand All @@ -135,8 +135,8 @@ async def update_token(user_data: AuthUser = Depends(login_required)):

@router.get("/my-info/")
async def my_data(
db: Database = Depends(database.get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""Read access token and get user details from Google"""

Expand Down

0 comments on commit 5aa3787

Please sign in to comment.