diff --git a/app/exceptions/UserException.py b/app/exceptions/UserException.py index ac630fe..0b31397 100644 --- a/app/exceptions/UserException.py +++ b/app/exceptions/UserException.py @@ -17,7 +17,6 @@ def __init__(self): class InvalidURL(HTTPException): - def __init__(self, id: int): + def __init__(self, detail: str): status_code = status.HTTP_400_BAD_REQUEST - detail = "Invalid URL" super().__init__(status_code=status_code, detail=detail) diff --git a/app/main.py b/app/main.py index 2ecab21..634affa 100644 --- a/app/main.py +++ b/app/main.py @@ -14,7 +14,7 @@ @app.get("/") def root(): - return {"message": "Hello World"} + return {"message": "users service"} @app.get("/users/{user_id}") diff --git a/app/repository/Users.py b/app/repository/Users.py index d913bb0..cca3467 100644 --- a/app/repository/Users.py +++ b/app/repository/Users.py @@ -5,6 +5,7 @@ from models.database import Base from models.users import User from datetime import date +from .sql_exception_handling import withSQLExceptionsHandle class UsersRepository: @@ -25,14 +26,17 @@ def shutdown(self): def rollback(self): self.session.rollback() + @withSQLExceptionsHandle() def add(self, record: Base): self.session.add(record) self.session.commit() + @withSQLExceptionsHandle() def get_user(self, user_id: int): user = self.session.query(User).filter_by(id=user_id).first() return user.__dict__ if user else None + @withSQLExceptionsHandle() def get_user_by_email(self, email: str): user = self.session.query(User).filter_by(email=email).first() return user.__dict__ if user else None @@ -45,6 +49,7 @@ def get_users_by_ids(self, ids: list): users = self.session.query(User).filter(User.id.in_(ids)).all() return self.__parse_result(users) + @withSQLExceptionsHandle() def create_user( self, email: str, @@ -78,6 +83,7 @@ def create_user( self.session.commit() return new_user + @withSQLExceptionsHandle() def edit_user(self, user_id: int, data_to_edit: dict): user = self.session.query(User).filter_by(id=user_id).first() for field, value in data_to_edit.items(): diff --git a/app/repository/sql_exception_handling.py b/app/repository/sql_exception_handling.py new file mode 100644 index 0000000..fa81f24 --- /dev/null +++ b/app/repository/sql_exception_handling.py @@ -0,0 +1,59 @@ +from psycopg2.errors import UniqueViolation +from sqlalchemy.exc import PendingRollbackError, IntegrityError, NoResultFound +from fastapi import status, HTTPException +import logging + +logger = logging.getLogger("app") +logger.setLevel("DEBUG") + + +def handle_common_errors(err): + if isinstance(err, IntegrityError): + if isinstance(err.orig, UniqueViolation): + parsed_error = err.orig.pgerror.split("\n") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": parsed_error[0], + "detail": parsed_error[1] + }) + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=format(err)) + + if isinstance(err, PendingRollbackError): + logger.warning(format(err)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=format(err) + ) + + if isinstance(err, NoResultFound): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=format(err) + ) + + logger.error(format(err)) + raise err + + +def withSQLExceptionsHandle(async_mode: bool = False): + def decorator(func): + async def handleAsyncSQLException(*args, **kwargs): + try: + return await func(*args, **kwargs) + except Exception as err: + return handle_common_errors(err) + + def handleSyncSQLException(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as err: + return handle_common_errors(err) + + return ( + handleAsyncSQLException if async_mode else handleSyncSQLException + ) + + return decorator diff --git a/app/service/Users.py b/app/service/Users.py index 8f71524..24959cb 100644 --- a/app/service/Users.py +++ b/app/service/Users.py @@ -27,7 +27,12 @@ def get_users_by_ids(self, ids: list): def create_user(self, user_data: dict): if not self._validate_location(user_data.get("location")): raise InvalidData() - return self.user_repository.create_user(**user_data) + try: + user = self.user_repository.create_user(**user_data) + return user + except Exception as e: + self.user_repository.rollback() + raise e def update_user(self, user_id: int, update_data: dict): # TODO: aca habria que chequear a partir del token, session o algo que @@ -41,7 +46,11 @@ def update_user(self, user_id: int, update_data: dict): if not re.match(r'^https?://(?:[a-zA-Z0-9-]+\.)+[a-zA-Z]{2,6}' r'(?:/[^/#?]+)+(?:\?.*)?$', photo_url): raise InvalidURL("Invalid photo URL") - self.user_repository.edit_user(user_id, filtered_update_data) + try: + self.user_repository.edit_user(user_id, filtered_update_data) + except Exception as e: + self.user_repository.rollback() + raise e def login(self, auth_code: str): access_token = self._get_access_token(auth_code)