Skip to content

Commit

Permalink
#420 Add api for creating a project (#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Dec 7, 2024
1 parent 3f0c8ea commit 6b229f2
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 29 deletions.
15 changes: 9 additions & 6 deletions app/api/v1/endpoints/project.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional
from uuid import uuid4

from fastapi import APIRouter
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from app.api.deps import api_key_header
from app.api.v1.schemas.project import (
ExperimentStatus,
ModelSummary,
Expand All @@ -15,21 +17,22 @@
ProjectsResponse,
ProjectSummaryPayload,
)
from app.services.project import project_service
from netspresso.utils.db.session import get_db

router = APIRouter()


@router.post("", response_model=ProjectResponse)
def create_project(
*,
db: Session = Depends(get_db),
api_key: str = Depends(api_key_header),
request_body: ProjectCreate,
) -> ProjectResponse:
project = project_service.create_project(db=db, project_name=request_body.project_name, api_key=api_key)

project = ProjectSummaryPayload(
project_id=str(uuid4()),
project_name=request_body.project_name,
user_id=str(uuid4()),
)
project = ProjectSummaryPayload.model_validate(project)

return ProjectResponse(data=project)

Expand Down
3 changes: 2 additions & 1 deletion app/api/v1/schemas/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from app.api.v1.schemas.base import ResponseItem, ResponsePaginationItems
from netspresso.enums import Status
from netspresso.exceptions.project import ProjectNameTooLongException


class ProjectCreate(BaseModel):
Expand All @@ -12,7 +13,7 @@ class ProjectCreate(BaseModel):
@field_validator("project_name")
def validate_length_of_project_name(cls, project_name: str) -> str:
if len(project_name) > 30:
raise ValueError("The project_name can't exceed 30 characters.")
raise ProjectNameTooLongException(max_length=30, actual_length=len(project_name))
return project_name


Expand Down
13 changes: 12 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from typing import List

from fastapi import FastAPI
from fastapi import FastAPI, Request, status
from fastapi.middleware import Middleware
from fastapi.responses import JSONResponse
from starlette.middleware.cors import CORSMiddleware

from app.api.api import api_router
from app.configs.settings import settings
from netspresso.exceptions.common import PyNPException
from netspresso.exceptions.status import STATUS_MAP


def init_routers(app: FastAPI) -> None:
app.include_router(api_router, prefix=settings.API_PREFIX)


def init_exceptions(app: FastAPI) -> None:
@app.exception_handler(PyNPException)
async def http_exception_handler(request: Request, exc: PyNPException):
status_code = STATUS_MAP.get(exc.detail["error_code"], status.HTTP_500_INTERNAL_SERVER_ERROR)

return JSONResponse(status_code=status_code, content=exc.detail)

def make_middleware() -> List[Middleware]:
origins = ["*"]
middleware = [
Expand All @@ -36,6 +46,7 @@ def create_app():
middleware=make_middleware(),
)
init_routers(app=app)
init_exceptions(app=app)

return app

Expand Down
19 changes: 19 additions & 0 deletions app/services/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from sqlalchemy.orm import Session

from app.services.user import user_service
from netspresso.netspresso import NetsPresso
from netspresso.utils.db.models.project import Project


class ProjectService:
def create_project(self, db: Session, project_name: str, api_key: str) -> Project:
user = user_service.get_user_by_api_key(db=db, api_key=api_key)

netspresso = NetsPresso(email=user.email, password=user.password)

project = netspresso.create_project(project_name=project_name)

return project


project_service = ProjectService()
9 changes: 7 additions & 2 deletions app/services/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class UserService:
def create_user(self, db: Session, email: str, password: str, api_key: str):
def create_user(self, db: Session, email: str, password: str, api_key: str) -> User:
user = User(
email=email,
password=password,
Expand Down Expand Up @@ -40,9 +40,14 @@ def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayl

return api_key

def get_user_info(self, db: Session, api_key: str) -> UserPayload:
def get_user_by_api_key(self, db: Session, api_key: str) -> User:
user = user_repository.get_by_api_key(db=db, api_key=api_key)

return user

def get_user_info(self, db: Session, api_key: str) -> UserPayload:
user = self.get_user_by_api_key(db=db, api_key=api_key)

netspresso = NetsPresso(email=user.email, password=user.password)

user = UserPayload(
Expand Down
9 changes: 9 additions & 0 deletions netspresso/exceptions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@
from typing import List, Optional


class Origin(str, Enum):
ROUTER = "router"
SERVICE = "service"
REPOSITORY = "repository"
CORE = "core"
CLIENT = "client"
LIBRARY = "library"


class LinkType(str, Enum):
DOCS = "docs"
CONTACT = "contact"
Expand Down
34 changes: 34 additions & 0 deletions netspresso/exceptions/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from netspresso.exceptions.common import AdditionalData, Origin, PyNPException


class ProjectNameTooLongException(PyNPException):
def __init__(self, max_length: int, actual_length: int):
message = f"The project_name exceeds maximum length. Max: {max_length}, Actual: {actual_length}"
super().__init__(
data=AdditionalData(origin=Origin.CORE),
error_code="PROJECT40001",
name=self.__class__.__name__,
message=message,
)


class ProjectAlreadyExistsException(PyNPException):
def __init__(self, project_name: str, project_path: str):
message = f"The project_name '{project_name}' already exists at '{project_path}'."
super().__init__(
data=AdditionalData(origin=Origin.CORE),
error_code="PROJECT40901",
name=self.__class__.__name__,
message=message,
)


class ProjectSaveException(PyNPException):
def __init__(self, error: Exception, project_name: str):
message = f"Failed to save project '{project_name}' to the database: {str(error)}"
super().__init__(
data=AdditionalData(origin=Origin.REPOSITORY),
error_code="PROJECT50001",
name=self.__class__.__name__,
message=message,
)
7 changes: 7 additions & 0 deletions netspresso/exceptions/status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from fastapi import status

STATUS_MAP = {
"PROJECT40001": status.HTTP_400_BAD_REQUEST,
"PROJECT40901": status.HTTP_409_CONFLICT,
"PROJECT50001": status.HTTP_500_INTERNAL_SERVER_ERROR,
}
80 changes: 66 additions & 14 deletions netspresso/netspresso.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
from netspresso.constant.project import SUB_FOLDERS
from netspresso.converter import ConverterV2
from netspresso.enums import Task
from netspresso.exceptions.project import (
ProjectAlreadyExistsException,
ProjectNameTooLongException,
ProjectSaveException,
)
from netspresso.inferencer.inferencer import CustomInferencer, NPInferencer
from netspresso.quantizer import Quantizer
from netspresso.tao import TAOTrainer
from netspresso.trainer import Trainer
from netspresso.utils.db.models.project import Project
from netspresso.utils.db.repositories.project import project_repository
from netspresso.utils.db.session import get_db
from netspresso.utils.db.session import SessionLocal


class NetsPresso:
Expand Down Expand Up @@ -46,15 +51,44 @@ def get_user(self) -> UserResponse:
return user_info

def create_project(self, project_name: str, project_path: str = "./projects") -> Project:
"""
Create a new project with the specified name and path.
This method creates a project directory structure on the file system
and saves the project information in the database. It also handles
scenarios where the project name is too long or already exists.
Args:
project_name (str): The name of the project to create.
Must not exceed 30 characters.
project_path (str, optional): The base path where the project
will be created. Defaults to "./projects".
Returns:
Project: The created project object containing information
such as project name, user ID, and absolute path.
Raises:
ProjectNameTooLongException: If the `project_name` exceeds the
maximum allowed length of 30 characters.
ProjectAlreadyExistsException: If a project with the same name
already exists at the specified `project_path`.
ProjectSaveException: If an error occurs while saving the project
to the database.
"""
if len(project_name) > 30:
raise ValueError("The project_name can't exceed 30 characters.")
raise ProjectNameTooLongException(max_length=30, actual_length=len(project_name))

# Create the main project folder
project_folder_path = Path(project_path) / project_name

# Check if the project folder already exists
if project_folder_path.exists():
logger.warning(f"Project '{project_name}' already exists at {project_folder_path.resolve()}.")
raise ProjectAlreadyExistsException(
project_name=project_name,
project_path=project_folder_path.resolve().as_posix()
)
else:
project_folder_path.mkdir(parents=True, exist_ok=True)
project_abs_path = project_folder_path.resolve()
Expand All @@ -65,31 +99,49 @@ def create_project(self, project_name: str, project_path: str = "./projects") ->

logger.info(f"Project '{project_name}' created at {project_abs_path}.")

db = None
try:
with get_db() as db:
project = Project(
project_name=project_name,
user_id=self.user_info.user_id,
project_abs_path=project_abs_path.as_posix(),
)
project = project_repository.save(db=db, model=project)
db = SessionLocal()
project = Project(
project_name=project_name,
user_id=self.user_info.user_id,
project_abs_path=project_abs_path.as_posix(),
)
project = project_repository.save(db=db, model=project)

return project
return project

except Exception as e:
logger.error(f"Failed to save project '{project_name}' to the database: {e}")
raise
raise ProjectSaveException(error=e, project_name=project_name)
finally:
db and db.close()

def get_projects(self) -> List[Project]:
"""
Retrieve all projects associated with the current user.
This method fetches project information from the database for
the user identified by `self.user_info.user_id`.
Returns:
List[Project]: A list of projects associated with the current user.
Raises:
Exception: If an error occurs while querying the database.
"""
db = None
try:
with get_db() as db:
projects = project_repository.get_all_by_user_id(db=db, user_id=self.user_info.user_id)
db = SessionLocal()
projects = project_repository.get_all_by_user_id(db=db, user_id=self.user_info.user_id)

return projects
return projects

except Exception as e:
logger.error(f"Failed to get project list from the database: {e}")
raise
finally:
db and db.close()

def trainer(
self, task: Optional[Union[str, Task]] = None, yaml_path: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion netspresso/utils/db/repositories/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class UserRepository(BaseRepository[User]):
def get_by_email(self, db: Session, email: str) -> Optional[User]:
user = db.query(User).filter(User.email == email).first()
user = db.query(self.model).filter(self.model.email == email).first()

return user

Expand Down
5 changes: 1 addition & 4 deletions netspresso/utils/db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,12 @@
Base = declarative_base()


@contextmanager
def get_db() -> Generator:
db = None
try:
db = SessionLocal()
yield db
finally:
if db:
db.close()
db.close()


def check_database(engine):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ ignore = [
"C901",
"B008",
"SIM115",
"B904",
]

[tool.ruff.per-file-ignores]
Expand Down

0 comments on commit 6b229f2

Please sign in to comment.