Skip to content

Commit

Permalink
map sql exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
violetaperezandrade committed Apr 19, 2024
1 parent 32d4b00 commit b234521
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
9 changes: 8 additions & 1 deletion app/repository/Users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from models.database import Base
from models.users import User
from datetime import date
from .sql_exception_handling import withSQLExceptionsHandle


class UsersRepository:
Expand All @@ -27,22 +28,27 @@ def shutdown(self):
def rollback(self):
self.session.rollback()

@withSQLExceptionsHandle()
def add(self, record: Base):
self.session.add(record)
self.session.commit()

@withSQLExceptionsHandle()
def get_user(self, user_id: int):
user = self.session.query(User).filter_by(id=user_id).first()
return user.__dict__ if user else None

@withSQLExceptionsHandle()
def get_user_by_email(self, email: str):
user = self.session.query(User).filter_by(email=email).first()
return user.__dict__ if user else None

@withSQLExceptionsHandle()
def get_all_users(self):
users = self.session.query(User).all()
return self.__parse_result(users)

@withSQLExceptionsHandle()
def create_user(
self,
email: str,
Expand Down Expand Up @@ -76,6 +82,7 @@ def create_user(
self.session.commit()
return new_user

@withSQLExceptionsHandle()
def edit_user(self, user_id: int, data_to_edit: dict):
user = self.session.query(User).filter_by(id=user_id).first()
for field, value in data_to_edit.items():
Expand All @@ -87,4 +94,4 @@ def edit_user(self, user_id: int, data_to_edit: dict):
def __parse_result(self, result):
if not result:
return []
return [r.__dict__ for r in result]
return [r.__dict__ for r in result]
59 changes: 59 additions & 0 deletions app/repository/sql_exception_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from psycopg2.errors import UniqueViolation
from sqlalchemy.exc import PendingRollbackError, IntegrityError, NoResultFound
from fastapi import status, HTTPException
import logging

logger = logging.getLogger("app")
logger.setLevel("DEBUG")


def handle_common_errors(err):
if isinstance(err, IntegrityError):
if isinstance(err.orig, UniqueViolation):
parsed_error = err.orig.pgerror.split("\n")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": parsed_error[0],
"detail": parsed_error[1]
})

raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=format(err))

if isinstance(err, PendingRollbackError):
logger.warning(format(err))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=format(err)
)

if isinstance(err, NoResultFound):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=format(err)
)

logger.error(format(err))
raise err


def withSQLExceptionsHandle(async_mode: bool = False):
def decorator(func):
async def handleAsyncSQLException(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception as err:
return handle_common_errors(err)

def handleSyncSQLException(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as err:
return handle_common_errors(err)

return (
handleAsyncSQLException if async_mode else handleSyncSQLException
)

return decorator

0 comments on commit b234521

Please sign in to comment.