Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Drone Module using psycopg and Pydantic #137

Merged
merged 10 commits into from
Aug 9, 2024
42 changes: 0 additions & 42 deletions src/backend/app/drones/drone_crud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from app.drones import drone_schemas
from app.models.enums import HTTPStatus
from loguru import logger as log
from fastapi import HTTPException
Expand Down Expand Up @@ -88,44 +87,3 @@ async def get_drone(db: Connection, drone_id: int):
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Retrieval failed"
) from e


async def create_drone(db: Connection, drone_info: drone_schemas.DroneIn):
"""
Creates a new drone record in the database.

Args:
db (Database): The database connection object.
drone (drone_schemas.DroneIn): The schema object containing drone details.

Returns:
The ID of the newly created drone record.
"""
try:
insert_query = """
INSERT INTO drones (
model, manufacturer, camera_model, sensor_width, sensor_height,
max_battery_health, focal_length, image_width, image_height,
max_altitude, max_speed, weight, created
) VALUES (
:model, :manufacturer, :camera_model, :sensor_width, :sensor_height,
:max_battery_health, :focal_length, :image_width, :image_height,
:max_altitude, :max_speed, :weight, CURRENT_TIMESTAMP
)
RETURNING id
"""
result = await db.execute(insert_query, drone_info.__dict__)
return result

# except UniqueViolationError as e:
# log.exception("Unique constraint violation: %s", e)
# raise HTTPException(
# status_code=HTTPStatus.CONFLICT,
# detail="A drone with this model already exists",
# )

except Exception as e:
log.exception(e)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Drone creation failed"
) from e
20 changes: 20 additions & 0 deletions src/backend/app/drones/drone_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Annotated
from fastapi import Depends, HTTPException, Path
from psycopg import Connection
from app.db import database
from app.drones.drone_schemas import DbDrone
from app.models.enums import HTTPStatus


async def get_drone_by_id(
drone_id: Annotated[
int,
Path(description="Drone ID."),
],
db: Annotated[Connection, Depends(database.get_db)],
) -> DbDrone:
"""Get a single project by id."""
try:
return await DbDrone.one(db, drone_id)
except KeyError as e:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) from e
75 changes: 25 additions & 50 deletions src/backend/app/drones/drone_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,42 @@
from fastapi import APIRouter, Depends, HTTPException
from app.db import database
from app.config import settings
from app.drones import drone_schemas
from app.drones import drone_schemas, drone_deps
from psycopg import Connection
from app.drones import drone_crud
from typing import List


router = APIRouter(
prefix=f"{settings.API_PREFIX}/drones",
tags=["Drones"],
responses={404: {"description": "Not found"}},
)


@router.get("/", tags=["Drones"], response_model=List[drone_schemas.DroneOut])
@router.get("/", response_model=list[drone_schemas.DroneOut])
async def read_drones(
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Retrieves all drone records from the database.
"""Get all drones."""
try:
return await drone_schemas.DbDrone.all(db)
except KeyError as e:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) from e

Args:
db (Database, optional): The database session object.
user_data (AuthUser, optional): The authenticated user data.

Returns:
List[drone_schemas.DroneOut]: A list of all drone records.
"""
drones = await drone_crud.read_all_drones(db)
return drones
@router.post("/create_drone")
async def create_drone(
drone_info: drone_schemas.DroneIn,
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""Create a new drone in database"""
drone_id = await drone_schemas.DbDrone.create(db, drone_info)
return {"message": "Drone created successfully", "drone_id": drone_id}


@router.delete("/{drone_id}", tags=["Drones"])
@router.delete("/{drone_id}")
async def delete_drone(
drone_id: int,
drone: Annotated[drone_schemas.DbDrone, Depends(drone_deps.get_drone_by_id)],
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
Expand All @@ -53,40 +55,16 @@ async def delete_drone(
Returns:
dict: A success message if the drone was deleted.
"""
success = await drone_crud.delete_drone(db, drone_id)
if not success:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Drone not found")
return {"message": "Drone deleted successfully"}


@router.post("/create_drone", tags=["Drones"])
async def create_drone(
drone_info: drone_schemas.DroneIn,
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Creates a new drone record in the database.

Args:
drone_info (drone_schemas.DroneIn): The schema object containing drone details.
db (Database, optional): The database session object.
user_data (AuthUser, optional): The authenticated user data.

Returns:
dict: A dictionary containing a success message and the ID of the newly created drone.
"""
drone_id = await drone_crud.create_drone(db, drone_info)
if not drone_id:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Drone creation failed"
)
return {"message": "Drone created successfully", "drone_id": drone_id}
# TODO: Check user role, Admin can only do this.
# After user roles introduction
drone_id = await drone_schemas.DbDrone.delete(db, drone.id)
return {"message": f"Drone successfully deleted {drone_id}"}


@router.get("/{drone_id}", tags=["Drones"], response_model=drone_schemas.DroneOut)
@router.get("/{drone_id}", response_model=drone_schemas.DbDrone)
async def read_drone(
drone_id: int,
drone: Annotated[drone_schemas.DbDrone, Depends(drone_deps.get_drone_by_id)],
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
Expand All @@ -101,7 +79,4 @@ async def read_drone(
Returns:
dict: The drone record if found.
"""
drone = await drone_crud.get_drone(db, drone_id)
if not drone:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Drone not found")
return drone
108 changes: 107 additions & 1 deletion src/backend/app/drones/drone_schemas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from pydantic import BaseModel
from fastapi import HTTPException
from app.models.enums import HTTPStatus
from psycopg import Connection
from psycopg.rows import class_row


class DroneIn(BaseModel):
class BaseDrone(BaseModel):
model: str
manufacturer: str
camera_model: str
Expand All @@ -16,6 +20,108 @@ class DroneIn(BaseModel):
weight: float


class DroneIn(BaseDrone):
"""Model for drone creation"""


class DroneOut(BaseModel):
id: int
model: str


class DbDrone(BaseDrone):
id: int

@staticmethod
async def one(db: Connection, drone_id: int):
"""Get a single drone by it's ID"""
print("drone_id = ", drone_id)
async with db.cursor(row_factory=class_row(DbDrone)) as cur:
await cur.execute(
"""
SELECT * FROM drones
WHERE id = %(drone_id)s;
""",
{"drone_id": drone_id},
)
drone = await cur.fetchone()

if not drone:
raise KeyError(f"Drone {drone_id} not found")

return drone

@staticmethod
async def all(db: Connection):
"""Get all drones"""
async with db.cursor(row_factory=class_row(DbDrone)) as cur:
await cur.execute(
"""
SELECT * FROM drones d
GROUP BY d.id;
"""
)
drones = await cur.fetchall()

if not drones:
raise KeyError("No drones found")
return drones

@staticmethod
async def delete(db: Connection, drone_id: int):
"""Delete a single drone by its ID."""
async with db.cursor() as cur:
await cur.execute(
"""
DELETE FROM drones
WHERE id = %(drone_id)s
RETURNING id;
""",
{"drone_id": drone_id},
)
deleted_drone_id = await cur.fetchone()

if not deleted_drone_id:
raise KeyError(f"Drone {drone_id} not found or could not be deleted")

return deleted_drone_id[0]

@staticmethod
async def create(db: Connection, drone: DroneIn):
"""Create a single drone."""
# NOTE we first check if a drone with this model name exists
async with db.cursor() as cur:
sql = """
SELECT EXISTS (
SELECT 1
FROM drones
WHERE LOWER(model) = %(model_name)s
)
"""
await cur.execute(sql, {"model_name": drone.model.lower()})
project_exists = await cur.fetchone()
if project_exists[0]:
msg = f"Drone ({drone.model}) already exists!"
raise HTTPException(status_code=HTTPStatus.CONFLICT, detail=msg)

# If drone with the same model does not already exists, add a new one.
model_dump = drone.model_dump()
columns = ", ".join(model_dump.keys())
value_placeholders = ", ".join(f"%({key})s" for key in model_dump.keys())

sql = f"""
INSERT INTO drones ({columns}, created)
VALUES ({value_placeholders}, NOW())
RETURNING id;
"""

async with db.cursor() as cur:
await cur.execute(sql, model_dump)
new_drone_id = await cur.fetchone()

if not new_drone_id:
msg = f"Unknown SQL error for data: {model_dump}"
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=msg
)
return new_drone_id[0]