Skip to content

Commit

Permalink
Change raw SQL queries to prisma queries
Browse files Browse the repository at this point in the history
  • Loading branch information
rjambrecic committed Jul 3, 2024
1 parent 6eaf673 commit 05243ff
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 22 deletions.
5 changes: 3 additions & 2 deletions captn/captn_agents/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import httpx
import openai
import pandas as pd
import prisma
from autogen.io.websockets import IOWebsockets
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from prometheus_client import Counter
Expand Down Expand Up @@ -240,7 +241,7 @@ async def get_smart_suggestions(
user_id: Annotated[int, Query(description="The user id")],
) -> List[str]:
user_initial_team = await get_initial_team(user_id)
if isinstance(user_initial_team, dict):
return user_initial_team["smart_suggestions"] # type: ignore[no-any-return]
if isinstance(user_initial_team, prisma.models.UserInitialTeam):
return user_initial_team.initial_team.smart_suggestions # type: ignore[no-any-return]

return DEFAULT_SMART_SUGGESTIONS
22 changes: 7 additions & 15 deletions captn/captn_agents/db_queries.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
from typing import Any, Union
from typing import Optional, Union

from prisma.models import UserInitialTeam

from .helpers import get_db_connection


async def get_initial_team(user_id: Union[int, str]) -> Any:
async def get_initial_team(user_id: Union[int, str]) -> Optional[UserInitialTeam]:
async with get_db_connection() as db:
query = f"""SELECT
uit.id AS user_initial_team_id,
uit.user_id,
it.id AS initial_team_id,
it.name AS initial_team_name,
it.smart_suggestions
FROM
"UserInitialTeam" uit
JOIN
"InitialTeam" it ON uit.initial_team_id = it.id
WHERE
uit.user_id = {int(user_id)}""" # nosec: [B608]
user_initial_team = await db.query_first(query)
user_initial_team = await db.userinitialteam.find_first(
where={"user_id": int(user_id)}, include={"initial_team": True}
)

return user_initial_team
5 changes: 3 additions & 2 deletions openai_agent/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from os import environ
from typing import Dict, List, Optional, Union

import prisma
from fastapi import APIRouter, BackgroundTasks
from openai import AsyncAzureOpenAI
from pydantic import BaseModel
Expand Down Expand Up @@ -273,10 +274,10 @@ async def chat(
request: AzureOpenAIRequest, background_tasks: BackgroundTasks
) -> Dict[str, Union[Optional[str], int, Union[str, Optional[SmartSuggestions]]]]:
user_initial_team = await get_initial_team(request.user_id)
if isinstance(user_initial_team, dict):
if isinstance(user_initial_team, prisma.models.UserInitialTeam):
return {
"team_status": "inprogress",
"team_name": user_initial_team["initial_team_name"],
"team_name": user_initial_team.initial_team.name,
"team_id": request.chat_id,
"customer_brief": "This is my customer brief.",
"conversation_name": "Team of Experts",
Expand Down
12 changes: 11 additions & 1 deletion tests/ci/captn/captn_agents/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import autogen
import pandas as pd
import prisma
import pytest
from autogen.io.websockets import IOWebsockets
from fastapi import HTTPException
Expand Down Expand Up @@ -513,11 +514,20 @@ def test_upload_csv_or_xlsx_file(


class TestGetSmartSuggestions:
initial_team = prisma.models.InitialTeam(
id=1,
name="test_team",
smart_suggestions=["Boost sales", "Increase brand awareness"],
)
user_initial_team = prisma.models.UserInitialTeam(
id=1, user_id=123, initial_team_id=1, initial_team=initial_team
)

@pytest.mark.parametrize(
("return_value", "expected"),
[
(
{"smart_suggestions": ["Boost sales", "Increase brand awareness"]},
user_initial_team,
["Boost sales", "Increase brand awareness"],
),
(None, DEFAULT_SMART_SUGGESTIONS),
Expand Down
11 changes: 9 additions & 2 deletions tests/ci/openai_agent/test_openai_agent_application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

import prisma
import pytest
from fastapi.testclient import TestClient

Expand Down Expand Up @@ -30,8 +31,14 @@ def test_chat(self) -> None:

with unittest.mock.patch(
"openai_agent.application.get_initial_team"
) as mock_get_initial_team:
mock_get_initial_team.return_value = {"initial_team_name": "test_team"}
) as mock_get_user_initial_team:
initial_team = prisma.models.InitialTeam(
id=1, name="test_team", smart_suggestions=["test suggestion"]
)
user_initial_team = prisma.models.UserInitialTeam(
id=1, user_id=123, initial_team_id=1, initial_team=initial_team
)
mock_get_user_initial_team.return_value = user_initial_team

response = client.post("/chat", json=request.model_dump())

Expand Down

0 comments on commit 05243ff

Please sign in to comment.