Skip to content

Commit

Permalink
Merge pull request #128 from hotosm/feat/psycopg-pydantic
Browse files Browse the repository at this point in the history
Replace encode/databases with psycopg & pydantic model validation
  • Loading branch information
nrjadkry authored Aug 26, 2024
2 parents 497f715 + 588f1eb commit d89cff1
Show file tree
Hide file tree
Showing 37 changed files with 2,237 additions and 2,347 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ db.sqlite3

# ignore python environments
venv
fmtm-env

# project related
temp_webmaps/local_only
Expand Down
6 changes: 4 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ services:
env_file: .env
environment:
LANG: en-GB.utf8
# POSTGRES_INITDB_ARGS: "--locale-provider=icu --icu-locale=en-GB"
POSTGRES_INITDB_ARGS: "--locale-provider=icu --icu-locale=en-GB"
ports:
- "5467:5432"
networks:
- dtm-network
restart: unless-stopped
Expand Down Expand Up @@ -106,5 +108,5 @@ services:
- .env
networks:
- dtm-network
entrypoint: ["pdm", "run", "alembic", "upgrade", "head"]
entrypoint: ["alembic", "upgrade", "head"]
restart: "no"
81 changes: 61 additions & 20 deletions src/backend/app/db/database.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,73 @@
"""Config for the DTM database connection."""

from databases import Database
from typing import AsyncGenerator
from fastapi import Request
from psycopg import Connection
from psycopg_pool import AsyncConnectionPool
from app.config import settings


class DatabaseConnection:
"""Manages database connection (sqlalchemy & encode databases)"""
async def get_db_connection_pool() -> AsyncConnectionPool:
"""Get the connection pool for psycopg."""
return AsyncConnectionPool(conninfo=settings.DTM_DB_URL.unicode_string())

def __init__(self):
self.database = Database(
settings.DTM_DB_URL.unicode_string(),
min_size=5,
max_size=20,
)

async def connect(self):
"""Connect to the database."""
await self.database.connect()
async def get_db(request: Request) -> AsyncGenerator[Connection, None]:
"""Get a connection from the psycopg pool.
async def disconnect(self):
"""Disconnect from the database."""
await self.database.disconnect()
Info on connections vs cursors:
https://www.psycopg.org/psycopg3/docs/advanced/async.html
Here we are getting a connection from the pool, which will be returned
after the session ends / endpoint finishes processing.
db_connection = DatabaseConnection()
In summary:
- Connection is created on endpoint call.
- Cursors are used to execute commands throughout endpoint.
Note it is possible to create multiple cursors from the connection,
but all will be executed in the same db 'transaction'.
- Connection is closed on endpoint finish.
-----------------------------------
To use the connection in endpoints:
-----------------------------------
async def get_db():
"""Get the encode database connection"""
await db_connection.connect()
yield db_connection.database
@app.get("/something/")
async def do_stuff(db = Depends(get_db)):
async with db.cursor() as cursor:
await cursor.execute("SELECT * FROM items")
result = await cursor.fetchall()
return result
-----------------------------------
Additionally, the connection could be passed through to a function to
utilise the Pydantic model serialisation on the cursor:
-----------------------------------
from psycopg.rows import class_row
async def get_user_by_id(db: Connection, id: int):
async with conn.cursor(row_factory=class_row(User)) as cur:
await cur.execute(
'''
SELECT id, first_name, last_name, dob
FROM (VALUES
(1, 'John', 'Doe', '2000-01-01'::date),
(2, 'Jane', 'White', NULL)
) AS data (id, first_name, last_name, dob)
WHERE id = %(id)s;
''',
{"id": id},
)
obj = await cur.fetchone()
# reveal_type(obj) would return 'Optional[User]' here
if not obj:
raise KeyError(f"user {id} not found")
# reveal_type(obj) would return 'User' here
return obj
"""
async with request.app.state.db_pool.connection() as conn:
yield conn
20 changes: 20 additions & 0 deletions src/backend/app/drones/drone_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Annotated
from fastapi import Depends, HTTPException, Path
from psycopg import Connection
from app.db import database
from app.drones.drone_schemas import DbDrone
from app.models.enums import HTTPStatus


async def get_drone_by_id(
drone_id: Annotated[
int,
Path(description="Drone ID."),
],
db: Annotated[Connection, Depends(database.get_db)],
) -> DbDrone:
"""Get a single project by id."""
try:
return await DbDrone.one(db, drone_id)
except KeyError as e:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) from e
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from app.drones import drone_schemas
from app.models.enums import HTTPStatus
from databases import Database
from loguru import logger as log
from fastapi import HTTPException
from asyncpg import UniqueViolationError
from psycopg import Connection

# from asyncpg import UniqueViolationError
from typing import List
from app.drones.drone_schemas import DroneOut


async def read_all_drones(db: Database) -> List[DroneOut]:
async def read_all_drones(db: Connection) -> List[DroneOut]:
"""
Retrieves all drone records from the database.
Expand All @@ -32,7 +32,7 @@ async def read_all_drones(db: Database) -> List[DroneOut]:
) from e


async def delete_drone(db: Database, drone_id: int) -> bool:
async def delete_drone(db: Connection, drone_id: int) -> bool:
"""
Deletes a drone record from the database, along with associated drone flights.
Expand Down Expand Up @@ -63,7 +63,7 @@ async def delete_drone(db: Database, drone_id: int) -> bool:
) from e


async def get_drone(db: Database, drone_id: int):
async def get_drone(db: Connection, drone_id: int):
"""
Retrieves a drone record from the database.
Expand All @@ -87,44 +87,3 @@ async def get_drone(db: Database, drone_id: int):
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Retrieval failed"
) from e


async def create_drone(db: Database, drone_info: drone_schemas.DroneIn):
"""
Creates a new drone record in the database.
Args:
db (Database): The database connection object.
drone (drone_schemas.DroneIn): The schema object containing drone details.
Returns:
The ID of the newly created drone record.
"""
try:
insert_query = """
INSERT INTO drones (
model, manufacturer, camera_model, sensor_width, sensor_height,
max_battery_health, focal_length, image_width, image_height,
max_altitude, max_speed, weight, created
) VALUES (
:model, :manufacturer, :camera_model, :sensor_width, :sensor_height,
:max_battery_health, :focal_length, :image_width, :image_height,
:max_altitude, :max_speed, :weight, CURRENT_TIMESTAMP
)
RETURNING id
"""
result = await db.execute(insert_query, drone_info.__dict__)
return result

except UniqueViolationError as e:
log.exception("Unique constraint violation: %s", e)
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="A drone with this model already exists",
)

except Exception as e:
log.exception(e)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Drone creation failed"
) from e
90 changes: 33 additions & 57 deletions src/backend/app/drones/drone_routes.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,48 @@
from typing import Annotated
from app.users.user_deps import login_required
from app.users.user_schemas import AuthUser
from app.models.enums import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException
from app.db.database import get_db
from app.db import database
from app.config import settings
from app.drones import drone_schemas
from databases import Database
from app.drones import drone_crud
from typing import List
from app.drones import drone_schemas, drone_deps
from psycopg import Connection


router = APIRouter(
prefix=f"{settings.API_PREFIX}/drones",
tags=["Drones"],
responses={404: {"description": "Not found"}},
)


@router.get("/", tags=["Drones"], response_model=List[drone_schemas.DroneOut])
@router.get("/", response_model=list[drone_schemas.DroneOut])
async def read_drones(
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
db: Annotated[Connection, Depends(database.get_db)],
):
"""
Retrieves all drone records from the database.
"""Get all drones."""
try:
return await drone_schemas.DbDrone.all(db)
except KeyError as e:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) from e

Args:
db (Database, optional): The database session object.
user_data (AuthUser, optional): The authenticated user data.

Returns:
List[drone_schemas.DroneOut]: A list of all drone records.
"""
drones = await drone_crud.read_all_drones(db)
return drones
@router.post("/create_drone")
async def create_drone(
drone_info: drone_schemas.DroneIn,
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""Create a new drone in database"""
drone_id = await drone_schemas.DbDrone.create(db, drone_info)
return {"message": "Drone created successfully", "drone_id": drone_id}


@router.delete("/{drone_id}", tags=["Drones"])
@router.delete("/{drone_id}")
async def delete_drone(
drone_id: int,
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
drone: Annotated[drone_schemas.DbDrone, Depends(drone_deps.get_drone_by_id)],
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Deletes a drone record from the database.
Expand All @@ -52,42 +55,18 @@ async def delete_drone(
Returns:
dict: A success message if the drone was deleted.
"""
success = await drone_crud.delete_drone(db, drone_id)
if not success:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Drone not found")
return {"message": "Drone deleted successfully"}


@router.post("/create_drone", tags=["Drones"])
async def create_drone(
drone_info: drone_schemas.DroneIn,
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
):
"""
Creates a new drone record in the database.
Args:
drone_info (drone_schemas.DroneIn): The schema object containing drone details.
db (Database, optional): The database session object.
user_data (AuthUser, optional): The authenticated user data.
Returns:
dict: A dictionary containing a success message and the ID of the newly created drone.
"""
drone_id = await drone_crud.create_drone(db, drone_info)
if not drone_id:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Drone creation failed"
)
return {"message": "Drone created successfully", "drone_id": drone_id}
# TODO: Check user role, Admin can only do this.
# After user roles introduction
drone_id = await drone_schemas.DbDrone.delete(db, drone.id)
return {"message": f"Drone successfully deleted {drone_id}"}


@router.get("/{drone_id}", tags=["Drones"], response_model=drone_schemas.DroneOut)
@router.get("/{drone_id}", response_model=drone_schemas.DbDrone)
async def read_drone(
drone_id: int,
db: Database = Depends(get_db),
user_data: AuthUser = Depends(login_required),
drone: Annotated[drone_schemas.DbDrone, Depends(drone_deps.get_drone_by_id)],
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Retrieves a drone record from the database.
Expand All @@ -100,7 +79,4 @@ async def read_drone(
Returns:
dict: The drone record if found.
"""
drone = await drone_crud.get_drone(db, drone_id)
if not drone:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Drone not found")
return drone
Loading

0 comments on commit d89cff1

Please sign in to comment.