diff --git a/app/api/v1/endpoints/project.py b/app/api/v1/endpoints/project.py index b033164f..5ccfd2b3 100644 --- a/app/api/v1/endpoints/project.py +++ b/app/api/v1/endpoints/project.py @@ -1,8 +1,10 @@ from typing import Optional from uuid import uuid4 -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from app.api.deps import api_key_header from app.api.v1.schemas.project import ( ExperimentStatus, ModelSummary, @@ -15,6 +17,8 @@ ProjectsResponse, ProjectSummaryPayload, ) +from app.services.project import project_service +from netspresso.utils.db.session import get_db router = APIRouter() @@ -22,14 +26,13 @@ @router.post("", response_model=ProjectResponse) def create_project( *, + db: Session = Depends(get_db), + api_key: str = Depends(api_key_header), request_body: ProjectCreate, ) -> ProjectResponse: + project = project_service.create_project(db=db, project_name=request_body.project_name, api_key=api_key) - project = ProjectSummaryPayload( - project_id=str(uuid4()), - project_name=request_body.project_name, - user_id=str(uuid4()), - ) + project = ProjectSummaryPayload.model_validate(project) return ProjectResponse(data=project) diff --git a/app/api/v1/schemas/project.py b/app/api/v1/schemas/project.py index de722904..66f5bb30 100644 --- a/app/api/v1/schemas/project.py +++ b/app/api/v1/schemas/project.py @@ -4,6 +4,7 @@ from app.api.v1.schemas.base import ResponseItem, ResponsePaginationItems from netspresso.enums import Status +from netspresso.exceptions.project import ProjectNameTooLongException class ProjectCreate(BaseModel): @@ -12,7 +13,7 @@ class ProjectCreate(BaseModel): @field_validator("project_name") def validate_length_of_project_name(cls, project_name: str) -> str: if len(project_name) > 30: - raise ValueError("The project_name can't exceed 30 characters.") + raise ProjectNameTooLongException(max_length=30, actual_length=len(project_name)) return project_name diff --git a/app/main.py b/app/main.py index 98788f13..92011135 100644 --- a/app/main.py +++ b/app/main.py @@ -1,17 +1,27 @@ from typing import List -from fastapi import FastAPI +from fastapi import FastAPI, Request, status from fastapi.middleware import Middleware +from fastapi.responses import JSONResponse from starlette.middleware.cors import CORSMiddleware from app.api.api import api_router from app.configs.settings import settings +from netspresso.exceptions.common import PyNPException +from netspresso.exceptions.status import STATUS_MAP def init_routers(app: FastAPI) -> None: app.include_router(api_router, prefix=settings.API_PREFIX) +def init_exceptions(app: FastAPI) -> None: + @app.exception_handler(PyNPException) + async def http_exception_handler(request: Request, exc: PyNPException): + status_code = STATUS_MAP.get(exc.detail["error_code"], status.HTTP_500_INTERNAL_SERVER_ERROR) + + return JSONResponse(status_code=status_code, content=exc.detail) + def make_middleware() -> List[Middleware]: origins = ["*"] middleware = [ @@ -36,6 +46,7 @@ def create_app(): middleware=make_middleware(), ) init_routers(app=app) + init_exceptions(app=app) return app diff --git a/app/services/project.py b/app/services/project.py new file mode 100644 index 00000000..79f0014c --- /dev/null +++ b/app/services/project.py @@ -0,0 +1,19 @@ +from sqlalchemy.orm import Session + +from app.services.user import user_service +from netspresso.netspresso import NetsPresso +from netspresso.utils.db.models.project import Project + + +class ProjectService: + def create_project(self, db: Session, project_name: str, api_key: str) -> Project: + user = user_service.get_user_by_api_key(db=db, api_key=api_key) + + netspresso = NetsPresso(email=user.email, password=user.password) + + project = netspresso.create_project(project_name=project_name) + + return project + + +project_service = ProjectService() diff --git a/app/services/user.py b/app/services/user.py index a5ade6a7..baca7933 100644 --- a/app/services/user.py +++ b/app/services/user.py @@ -8,7 +8,7 @@ class UserService: - def create_user(self, db: Session, email: str, password: str, api_key: str): + def create_user(self, db: Session, email: str, password: str, api_key: str) -> User: user = User( email=email, password=password, @@ -40,9 +40,14 @@ def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayl return api_key - def get_user_info(self, db: Session, api_key: str) -> UserPayload: + def get_user_by_api_key(self, db: Session, api_key: str) -> User: user = user_repository.get_by_api_key(db=db, api_key=api_key) + return user + + def get_user_info(self, db: Session, api_key: str) -> UserPayload: + user = self.get_user_by_api_key(db=db, api_key=api_key) + netspresso = NetsPresso(email=user.email, password=user.password) user = UserPayload( diff --git a/netspresso/exceptions/common.py b/netspresso/exceptions/common.py index ead673be..d1ddc482 100644 --- a/netspresso/exceptions/common.py +++ b/netspresso/exceptions/common.py @@ -3,6 +3,15 @@ from typing import List, Optional +class Origin(str, Enum): + ROUTER = "router" + SERVICE = "service" + REPOSITORY = "repository" + CORE = "core" + CLIENT = "client" + LIBRARY = "library" + + class LinkType(str, Enum): DOCS = "docs" CONTACT = "contact" diff --git a/netspresso/exceptions/project.py b/netspresso/exceptions/project.py new file mode 100644 index 00000000..633a7f06 --- /dev/null +++ b/netspresso/exceptions/project.py @@ -0,0 +1,34 @@ +from netspresso.exceptions.common import AdditionalData, Origin, PyNPException + + +class ProjectNameTooLongException(PyNPException): + def __init__(self, max_length: int, actual_length: int): + message = f"The project_name exceeds maximum length. Max: {max_length}, Actual: {actual_length}" + super().__init__( + data=AdditionalData(origin=Origin.CORE), + error_code="PROJECT40001", + name=self.__class__.__name__, + message=message, + ) + + +class ProjectAlreadyExistsException(PyNPException): + def __init__(self, project_name: str, project_path: str): + message = f"The project_name '{project_name}' already exists at '{project_path}'." + super().__init__( + data=AdditionalData(origin=Origin.CORE), + error_code="PROJECT40901", + name=self.__class__.__name__, + message=message, + ) + + +class ProjectSaveException(PyNPException): + def __init__(self, error: Exception, project_name: str): + message = f"Failed to save project '{project_name}' to the database: {str(error)}" + super().__init__( + data=AdditionalData(origin=Origin.REPOSITORY), + error_code="PROJECT50001", + name=self.__class__.__name__, + message=message, + ) diff --git a/netspresso/exceptions/status.py b/netspresso/exceptions/status.py new file mode 100644 index 00000000..31d0cbfe --- /dev/null +++ b/netspresso/exceptions/status.py @@ -0,0 +1,7 @@ +from fastapi import status + +STATUS_MAP = { + "PROJECT40001": status.HTTP_400_BAD_REQUEST, + "PROJECT40901": status.HTTP_409_CONFLICT, + "PROJECT50001": status.HTTP_500_INTERNAL_SERVER_ERROR, +} diff --git a/netspresso/netspresso.py b/netspresso/netspresso.py index 7d962213..8efa8ee4 100644 --- a/netspresso/netspresso.py +++ b/netspresso/netspresso.py @@ -11,13 +11,18 @@ from netspresso.constant.project import SUB_FOLDERS from netspresso.converter import ConverterV2 from netspresso.enums import Task +from netspresso.exceptions.project import ( + ProjectAlreadyExistsException, + ProjectNameTooLongException, + ProjectSaveException, +) from netspresso.inferencer.inferencer import CustomInferencer, NPInferencer from netspresso.quantizer import Quantizer from netspresso.tao import TAOTrainer from netspresso.trainer import Trainer from netspresso.utils.db.models.project import Project from netspresso.utils.db.repositories.project import project_repository -from netspresso.utils.db.session import get_db +from netspresso.utils.db.session import SessionLocal class NetsPresso: @@ -46,8 +51,33 @@ def get_user(self) -> UserResponse: return user_info def create_project(self, project_name: str, project_path: str = "./projects") -> Project: + """ + Create a new project with the specified name and path. + + This method creates a project directory structure on the file system + and saves the project information in the database. It also handles + scenarios where the project name is too long or already exists. + + Args: + project_name (str): The name of the project to create. + Must not exceed 30 characters. + project_path (str, optional): The base path where the project + will be created. Defaults to "./projects". + + Returns: + Project: The created project object containing information + such as project name, user ID, and absolute path. + + Raises: + ProjectNameTooLongException: If the `project_name` exceeds the + maximum allowed length of 30 characters. + ProjectAlreadyExistsException: If a project with the same name + already exists at the specified `project_path`. + ProjectSaveException: If an error occurs while saving the project + to the database. + """ if len(project_name) > 30: - raise ValueError("The project_name can't exceed 30 characters.") + raise ProjectNameTooLongException(max_length=30, actual_length=len(project_name)) # Create the main project folder project_folder_path = Path(project_path) / project_name @@ -55,6 +85,10 @@ def create_project(self, project_name: str, project_path: str = "./projects") -> # Check if the project folder already exists if project_folder_path.exists(): logger.warning(f"Project '{project_name}' already exists at {project_folder_path.resolve()}.") + raise ProjectAlreadyExistsException( + project_name=project_name, + project_path=project_folder_path.resolve().as_posix() + ) else: project_folder_path.mkdir(parents=True, exist_ok=True) project_abs_path = project_folder_path.resolve() @@ -65,31 +99,49 @@ def create_project(self, project_name: str, project_path: str = "./projects") -> logger.info(f"Project '{project_name}' created at {project_abs_path}.") + db = None try: - with get_db() as db: - project = Project( - project_name=project_name, - user_id=self.user_info.user_id, - project_abs_path=project_abs_path.as_posix(), - ) - project = project_repository.save(db=db, model=project) + db = SessionLocal() + project = Project( + project_name=project_name, + user_id=self.user_info.user_id, + project_abs_path=project_abs_path.as_posix(), + ) + project = project_repository.save(db=db, model=project) - return project + return project except Exception as e: logger.error(f"Failed to save project '{project_name}' to the database: {e}") - raise + raise ProjectSaveException(error=e, project_name=project_name) + finally: + db and db.close() def get_projects(self) -> List[Project]: + """ + Retrieve all projects associated with the current user. + + This method fetches project information from the database for + the user identified by `self.user_info.user_id`. + + Returns: + List[Project]: A list of projects associated with the current user. + + Raises: + Exception: If an error occurs while querying the database. + """ + db = None try: - with get_db() as db: - projects = project_repository.get_all_by_user_id(db=db, user_id=self.user_info.user_id) + db = SessionLocal() + projects = project_repository.get_all_by_user_id(db=db, user_id=self.user_info.user_id) - return projects + return projects except Exception as e: logger.error(f"Failed to get project list from the database: {e}") raise + finally: + db and db.close() def trainer( self, task: Optional[Union[str, Task]] = None, yaml_path: Optional[str] = None diff --git a/netspresso/utils/db/repositories/user.py b/netspresso/utils/db/repositories/user.py index e6023d6f..ee68c947 100644 --- a/netspresso/utils/db/repositories/user.py +++ b/netspresso/utils/db/repositories/user.py @@ -8,7 +8,7 @@ class UserRepository(BaseRepository[User]): def get_by_email(self, db: Session, email: str) -> Optional[User]: - user = db.query(User).filter(User.email == email).first() + user = db.query(self.model).filter(self.model.email == email).first() return user diff --git a/netspresso/utils/db/session.py b/netspresso/utils/db/session.py index b240a36d..e84d320e 100644 --- a/netspresso/utils/db/session.py +++ b/netspresso/utils/db/session.py @@ -23,15 +23,12 @@ Base = declarative_base() -@contextmanager def get_db() -> Generator: - db = None try: db = SessionLocal() yield db finally: - if db: - db.close() + db.close() def check_database(engine): diff --git a/pyproject.toml b/pyproject.toml index a88b0ed4..0a77204c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ ignore = [ "C901", "B008", "SIM115", + "B904", ] [tool.ruff.per-file-ignores]