Skip to content

Commit

Permalink
feat: sql pairs endpoint spec
Browse files Browse the repository at this point in the history
  • Loading branch information
paopa committed Mar 5, 2025
1 parent 0ac1078 commit 7deb1c9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 39 deletions.
53 changes: 25 additions & 28 deletions wren-ai-service/src/web/v1/routers/sql_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import asdict
from typing import List, Literal, Optional

from fastapi import APIRouter, BackgroundTasks, Depends
from fastapi import APIRouter, BackgroundTasks, Depends, Response
from pydantic import BaseModel

from src.globals import (
Expand Down Expand Up @@ -85,7 +85,7 @@ class PostRequest(BaseModel):


class PostResponse(BaseModel):
id: str
event_id: str


@router.post("/sql-pairs")
Expand All @@ -95,67 +95,64 @@ async def prepare(
service_container: ServiceContainer = Depends(get_service_container),
service_metadata: ServiceMetadata = Depends(get_service_metadata),
) -> PostResponse:
id = str(uuid.uuid4())
event_id = str(uuid.uuid4())
service = service_container.sql_pairs_service
service[id] = SqlPairsService.Resource(id=id, status="indexing")
service[event_id] = SqlPairsService.Event(id=event_id, status="indexing")

index_request = SqlPairsService.IndexRequest(id=id, **request.model_dump())
index_request = SqlPairsService.IndexRequest(id=event_id, **request.model_dump())

background_tasks.add_task(
service.index,
index_request,
service_metadata=asdict(service_metadata),
)
return PostResponse(id=id)
return PostResponse(event_id=event_id)


class DeleteRequest(BaseModel):
sql_pair_ids: List[str]
project_id: Optional[str] = None


class DeleteResponse(BaseModel):
id: str


@router.delete("/sql-pairs")
async def delete(
request: DeleteRequest,
background_tasks: BackgroundTasks,
response: Response,
service_container: ServiceContainer = Depends(get_service_container),
service_metadata: ServiceMetadata = Depends(get_service_metadata),
) -> DeleteResponse:
id = str(uuid.uuid4())
) -> None | SqlPairsService.Event.Error:
event_id = str(uuid.uuid4())
service = service_container.sql_pairs_service
service[id] = SqlPairsService.Resource(id=id, status="deleting")
service[event_id] = SqlPairsService.Event(id=event_id, status="deleting")

delete_request = SqlPairsService.DeleteRequest(
id=id,
id=event_id,
**request.model_dump(),
)

background_tasks.add_task(
service.delete,
delete_request,
service_metadata=asdict(service_metadata),
)
return DeleteResponse(id=id)
await service.delete(delete_request, service_metadata=asdict(service_metadata))

event: SqlPairsService.Event = service[event_id]

if event.status == "failed":
response.status_code = 500
return event.error


class GetResponse(BaseModel):
id: str
event_id: str
status: Literal["indexing", "deleting", "finished", "failed"]
error: Optional[dict]


@router.get("/sql-pairs/{id}")
@router.get("/sql-pairs/{event_id}")
async def get(
id: str,
event_id: str,
container: ServiceContainer = Depends(get_service_container),
) -> GetResponse:
resource = container.sql_pairs_service[id]
event: SqlPairsService.Event = container.sql_pairs_service[event_id]
return GetResponse(
id=resource.id,
status=resource.status,
error=resource.error and resource.error.model_dump(),
event_id=event.id,
status=event.status,
error=event.error and event.error.model_dump(),
)
22 changes: 11 additions & 11 deletions wren-ai-service/src/web/v1/services/sql_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class SqlPairsService:
class Resource(BaseModel, MetadataTraceable):
class Event(BaseModel, MetadataTraceable):
class Error(BaseModel):
code: Literal["OTHERS"]
message: str
Expand All @@ -30,18 +30,18 @@ def __init__(
ttl: int = 120,
):
self._pipelines = pipelines
self._cache: Dict[str, self.Resource] = TTLCache(maxsize=maxsize, ttl=ttl)
self._cache: Dict[str, self.Event] = TTLCache(maxsize=maxsize, ttl=ttl)

def _handle_exception(
self,
id: str,
error_message: str,
code: str = "OTHERS",
):
self._cache[id] = self.Resource(
self._cache[id] = self.Event(
id=id,
status="failed",
error=self.Resource.Error(code=code, message=error_message),
error=self.Event.Error(code=code, message=error_message),
)
logger.error(error_message)

Expand Down Expand Up @@ -71,7 +71,7 @@ async def index(
}
await self._pipelines["sql_pairs"].run(**input)

self._cache[request.id] = self.Resource(id=request.id, status="finished")
self._cache[request.id] = self.Event(id=request.id, status="finished")

except Exception as e:
self._handle_exception(
Expand Down Expand Up @@ -101,7 +101,7 @@ async def delete(
sql_pairs=sql_pairs, project_id=request.project_id
)

self._cache[request.id] = self.Resource(id=request.id, status="finished")
self._cache[request.id] = self.Event(id=request.id, status="finished")
except Exception as e:
self._handle_exception(
request.id,
Expand All @@ -110,19 +110,19 @@ async def delete(

return self._cache[request.id].with_metadata()

def __getitem__(self, id: str) -> Resource:
def __getitem__(self, id: str) -> Event:
response = self._cache.get(id)

if response is None:
message = f"SQL Pairs Resource with ID '{id}' not found."
message = f"SQL Pairs Event with ID '{id}' not found."
logger.exception(message)
return self.Resource(
return self.Event(
id=id,
status="failed",
error=self.Resource.Error(code="OTHERS", message=message),
error=self.Event.Error(code="OTHERS", message=message),
)

return response

def __setitem__(self, id: str, value: Resource):
def __setitem__(self, id: str, value: Event):
self._cache[id] = value

0 comments on commit 7deb1c9

Please sign in to comment.