diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index 580fc72f..d7dfb154 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -35,12 +35,6 @@ search: - [hash_auth_token](api/fastagency/auth_token/auth/hash_auth_token.md) - [parse_expiry](api/fastagency/auth_token/auth/parse_expiry.md) - [verify_auth_token](api/fastagency/auth_token/auth/verify_auth_token.md) - - db - - helpers - - [find_model_using_raw](api/fastagency/db/helpers/find_model_using_raw.md) - - [get_db_connection](api/fastagency/db/helpers/get_db_connection.md) - - [get_user](api/fastagency/db/helpers/get_user.md) - - [get_wasp_db_url](api/fastagency/db/helpers/get_wasp_db_url.md) - faststream_app - [ping_handler](api/fastagency/faststream_app/ping_handler.md) - helpers @@ -135,6 +129,11 @@ search: - fastapi_code_generator_helpers - [ArgumentWithDescription](api/fastagency/openapi/fastapi_code_generator_helpers/ArgumentWithDescription.md) - [patch_get_parameter_type](api/fastagency/openapi/fastapi_code_generator_helpers/patch_get_parameter_type.md) + - protocols + - base + - [BaseProtocol](api/fastagency/protocols/base/BaseProtocol.md) + - prisma + - [PrismaProtocol](api/fastagency/protocols/prisma/PrismaProtocol.md) - saas_app_generator - [InvalidFlyTokenError](api/fastagency/saas_app_generator/InvalidFlyTokenError.md) - [InvalidGHTokenError](api/fastagency/saas_app_generator/InvalidGHTokenError.md) diff --git a/docs/docs/en/api/fastagency/db/helpers/get_db_connection.md b/docs/docs/en/api/fastagency/db/helpers/get_db_connection.md deleted file mode 100644 index e36504b2..00000000 --- a/docs/docs/en/api/fastagency/db/helpers/get_db_connection.md +++ /dev/null @@ -1,11 +0,0 @@ ---- -# 0.5 - API -# 2 - Release -# 3 - Contributing -# 5 - Template Page -# 10 - Default -search: - boost: 0.5 ---- - -::: fastagency.db.helpers.get_db_connection diff --git a/docs/docs/en/api/fastagency/db/helpers/get_wasp_db_url.md b/docs/docs/en/api/fastagency/db/helpers/get_wasp_db_url.md deleted file mode 100644 index 27fdb6f7..00000000 --- a/docs/docs/en/api/fastagency/db/helpers/get_wasp_db_url.md +++ /dev/null @@ -1,11 +0,0 @@ ---- -# 0.5 - API -# 2 - Release -# 3 - Contributing -# 5 - Template Page -# 10 - Default -search: - boost: 0.5 ---- - -::: fastagency.db.helpers.get_wasp_db_url diff --git a/docs/docs/en/api/fastagency/db/helpers/get_user.md b/docs/docs/en/api/fastagency/protocols/base/BaseProtocol.md similarity index 71% rename from docs/docs/en/api/fastagency/db/helpers/get_user.md rename to docs/docs/en/api/fastagency/protocols/base/BaseProtocol.md index 6f21bd72..1f52cc89 100644 --- a/docs/docs/en/api/fastagency/db/helpers/get_user.md +++ b/docs/docs/en/api/fastagency/protocols/base/BaseProtocol.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: fastagency.db.helpers.get_user +::: fastagency.protocols.base.BaseProtocol diff --git a/docs/docs/en/api/fastagency/db/helpers/find_model_using_raw.md b/docs/docs/en/api/fastagency/protocols/prisma/PrismaProtocol.md similarity index 70% rename from docs/docs/en/api/fastagency/db/helpers/find_model_using_raw.md rename to docs/docs/en/api/fastagency/protocols/prisma/PrismaProtocol.md index 85afa70d..dddc7f2c 100644 --- a/docs/docs/en/api/fastagency/db/helpers/find_model_using_raw.md +++ b/docs/docs/en/api/fastagency/protocols/prisma/PrismaProtocol.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: fastagency.db.helpers.find_model_using_raw +::: fastagency.protocols.prisma.PrismaProtocol diff --git a/fastagency/app.py b/fastagency/app.py index 03155ff1..6c71a211 100644 --- a/fastagency/app.py +++ b/fastagency/app.py @@ -12,7 +12,6 @@ from pydantic import BaseModel, TypeAdapter, ValidationError from .auth_token.auth import DeploymentAuthToken, create_deployment_auth_token -from .db.helpers import find_model_using_raw, get_db_connection, get_user from .helpers import ( add_model_to_user, create_model, @@ -21,6 +20,9 @@ from .models.registry import Registry, Schemas from .models.toolboxes.toolbox import Toolbox +# from .db.helpers import find_model_using_raw, get_db_connection, get_user +from .protocols.prisma import PrismaProtocol + logging.basicConfig(level=logging.INFO) app = FastAPI() @@ -78,7 +80,7 @@ async def validate_secret_model( ) -> Dict[str, Any]: type: str = "secret" - found_model = await find_model_using_raw(model_uuid=model_uuid) + found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid) if "api_key" in found_model["json_str"]: model["api_key"] = found_model["json_str"]["api_key"] try: @@ -136,7 +138,7 @@ async def add_model( async def create_toolbox_for_new_user(user_uuid: Union[str, UUID]) -> Dict[str, Any]: - await get_user(user_uuid=user_uuid) # type: ignore[arg-type] + await PrismaProtocol().get_user(user_uuid=user_uuid) # type: ignore[arg-type] domain = environ.get("DOMAIN", "localhost") toolbox_openapi_url = ( @@ -181,8 +183,8 @@ async def update_model( registry = Registry.get_default() validated_model = registry.validate(type_name, model_name, model) - async with get_db_connection() as db: - found_model = await find_model_using_raw(model_uuid=model_uuid) + async with PrismaProtocol().get_db_connection() as db: + found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid) await db.model.update( where={"uuid": found_model["uuid"]}, # type: ignore[arg-type] @@ -201,8 +203,8 @@ async def update_model( async def models_delete( user_uuid: str, type_name: str, model_uuid: str ) -> Dict[str, Any]: - async with get_db_connection() as db: - found_model = await find_model_using_raw(model_uuid=model_uuid) + async with PrismaProtocol().get_db_connection() as db: + found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid) model = await db.model.delete( where={"uuid": found_model["uuid"]} # type: ignore[arg-type] ) @@ -340,7 +342,9 @@ async def chat(request: ChatRequest) -> Dict[str, Any]: @app.post("/deployment/{deployment_uuid}/chat") async def deployment_chat(deployment_uuid: str) -> Dict[str, Any]: - found_model = await find_model_using_raw(model_uuid=deployment_uuid) + found_model = await PrismaProtocol().find_model_using_raw( + model_uuid=deployment_uuid + ) team_name = found_model["json_str"]["name"] team_uuid = found_model["json_str"]["team"]["uuid"] @@ -381,15 +385,15 @@ class DeploymentAuthTokenInfo(BaseModel): async def get_all_deployment_auth_tokens( user_uuid: str, deployment_uuid: str ) -> List[DeploymentAuthTokenInfo]: - user = await get_user(user_uuid=user_uuid) - deployment = await find_model_using_raw(model_uuid=deployment_uuid) + user = await PrismaProtocol().get_user(user_uuid=user_uuid) + deployment = await PrismaProtocol().find_model_using_raw(model_uuid=deployment_uuid) if user["uuid"] != deployment["user_uuid"]: raise HTTPException( # pragma: no cover status_code=403, detail="User does not have access to this deployment" ) - async with get_db_connection() as db: + async with PrismaProtocol().get_db_connection() as db: auth_tokens = await db.authtoken.find_many( where={"deployment_uuid": deployment_uuid, "user_uuid": user_uuid}, ) @@ -407,15 +411,15 @@ async def delete_deployment_auth_token( deployment_uuid: str, auth_token_uuid: str, ) -> DeploymentAuthTokenInfo: - user = await get_user(user_uuid=user_uuid) - deployment = await find_model_using_raw(model_uuid=deployment_uuid) + user = await PrismaProtocol().get_user(user_uuid=user_uuid) + deployment = await PrismaProtocol().find_model_using_raw(model_uuid=deployment_uuid) if user["uuid"] != deployment["user_uuid"]: raise HTTPException( # pragma: no cover status_code=403, detail="User does not have access to this deployment" ) - async with get_db_connection() as db: + async with PrismaProtocol().get_db_connection() as db: auth_token = await db.authtoken.delete( where={ # type: ignore[typeddict-unknown-key] "uuid": auth_token_uuid, diff --git a/fastagency/auth_token/auth.py b/fastagency/auth_token/auth.py index 15ea21aa..3dc48252 100644 --- a/fastagency/auth_token/auth.py +++ b/fastagency/auth_token/auth.py @@ -8,7 +8,8 @@ from fastapi import HTTPException from pydantic import BaseModel -from fastagency.db.helpers import find_model_using_raw, get_db_connection, get_user +# from fastagency.db.helpers import find_model_using_raw, get_db_connection, get_user +from ..protocols.prisma import PrismaProtocol def generate_auth_token(length: int = 32) -> str: @@ -80,8 +81,8 @@ async def create_deployment_auth_token( name: str = "Default deployment token", expiry: str = "99999d", ) -> DeploymentAuthToken: - user = await get_user(user_uuid=user_uuid) - deployment = await find_model_using_raw(model_uuid=deployment_uuid) + user = await PrismaProtocol().get_user(user_uuid=user_uuid) + deployment = await PrismaProtocol().find_model_using_raw(model_uuid=deployment_uuid) if user["uuid"] != deployment["user_uuid"]: raise HTTPException( @@ -92,7 +93,7 @@ async def create_deployment_auth_token( auth_token = generate_auth_token() hashed_token = hash_auth_token(auth_token) - async with get_db_connection() as db: + async with PrismaProtocol().get_db_connection() as db: await db.authtoken.create( # type: ignore[attr-defined] data={ "uuid": str(uuid.uuid4()), diff --git a/fastagency/db/helpers.py b/fastagency/db/helpers.py index 221e0d75..40fa5552 100644 --- a/fastagency/db/helpers.py +++ b/fastagency/db/helpers.py @@ -1,63 +1,63 @@ -from contextlib import asynccontextmanager -from os import environ -from typing import Any, AsyncGenerator, Dict, Optional, Union -from uuid import UUID - -from fastapi import HTTPException -from prisma import Prisma # type: ignore[attr-defined] - - -@asynccontextmanager -async def get_db_connection( - db_url: Optional[str] = None, -) -> AsyncGenerator[Prisma, None]: - if not db_url: - db_url = environ.get("PY_DATABASE_URL", None) - if not db_url: - raise ValueError( - "No database URL provided nor set as environment variable 'PY_DATABASE_URL'" - ) # pragma: no cover - if "connect_timeout" not in db_url: - db_url += "?connect_timeout=60" - db = Prisma(datasource={"url": db_url}) - await db.connect() - try: - yield db - finally: - await db.disconnect() - - -async def get_wasp_db_url() -> str: - wasp_db_url: str = environ.get("DATABASE_URL") # type: ignore[assignment] - if "connect_timeout" not in wasp_db_url: - wasp_db_url += "?connect_timeout=60" - return wasp_db_url - - -async def find_model_using_raw(model_uuid: Union[str, UUID]) -> Dict[str, Any]: - if isinstance(model_uuid, UUID): - model_uuid = str(model_uuid) - - async with get_db_connection() as db: - model: Optional[Dict[str, Any]] = await db.query_first( - 'SELECT * from "Model" where uuid=' # nosec: [B608] - + f"'{model_uuid}'" - ) - - if not model: - raise HTTPException( - status_code=404, detail="Something went wrong. Please try again later." - ) - return model - - -async def get_user(user_uuid: Union[int, str]) -> Any: - wasp_db_url = await get_wasp_db_url() - async with get_db_connection(db_url=wasp_db_url) as db: - select_query = 'SELECT * from "User" where uuid=' + f"'{user_uuid}'" # nosec: [B608] - user = await db.query_first( - select_query # nosec: [B608] - ) - if not user: - raise HTTPException(status_code=404, detail=f"user_uuid {user_uuid} not found") - return user +# from contextlib import asynccontextmanager +# from os import environ +# from typing import Any, AsyncGenerator, Dict, Optional, Union +# from uuid import UUID + +# from fastapi import HTTPException +# from prisma import Prisma # type: ignore[attr-defined] + + +# @asynccontextmanager +# async def get_db_connection( +# db_url: Optional[str] = None, +# ) -> AsyncGenerator[Prisma, None]: +# if not db_url: +# db_url = environ.get("PY_DATABASE_URL", None) +# if not db_url: +# raise ValueError( +# "No database URL provided nor set as environment variable 'PY_DATABASE_URL'" +# ) # pragma: no cover +# if "connect_timeout" not in db_url: +# db_url += "?connect_timeout=60" +# db = Prisma(datasource={"url": db_url}) +# await db.connect() +# try: +# yield db +# finally: +# await db.disconnect() + + +# async def get_wasp_db_url() -> str: +# wasp_db_url: str = environ.get("DATABASE_URL") # type: ignore[assignment] +# if "connect_timeout" not in wasp_db_url: +# wasp_db_url += "?connect_timeout=60" +# return wasp_db_url + + +# async def find_model_using_raw(model_uuid: Union[str, UUID]) -> Dict[str, Any]: +# if isinstance(model_uuid, UUID): +# model_uuid = str(model_uuid) + +# async with get_db_connection() as db: +# model: Optional[Dict[str, Any]] = await db.query_first( +# 'SELECT * from "Model" where uuid=' # nosec: [B608] +# + f"'{model_uuid}'" +# ) + +# if not model: +# raise HTTPException( +# status_code=404, detail="Something went wrong. Please try again later." +# ) +# return model + + +# async def get_user(user_uuid: Union[int, str]) -> Any: +# wasp_db_url = await get_wasp_db_url() +# async with get_db_connection(db_url=wasp_db_url) as db: +# select_query = 'SELECT * from "User" where uuid=' + f"'{user_uuid}'" # nosec: [B608] +# user = await db.query_first( +# select_query # nosec: [B608] +# ) +# if not user: +# raise HTTPException(status_code=404, detail=f"user_uuid {user_uuid} not found") +# return user diff --git a/fastagency/helpers.py b/fastagency/helpers.py index 0788c434..a7b3a992 100644 --- a/fastagency/helpers.py +++ b/fastagency/helpers.py @@ -13,17 +13,18 @@ ) from .auth_token.auth import create_deployment_auth_token - -# from fastagency.app import add_model -from .db.helpers import find_model_using_raw, get_db_connection, get_user from .models.base import Model, ObjectReference from .models.registry import Registry +# from fastagency.app import add_model +# from .db.helpers import find_model_using_raw, get_db_connection, get_user +from .protocols.prisma import PrismaProtocol + T = TypeVar("T", bound=Model) async def get_model_by_uuid(model_uuid: Union[str, UUID]) -> Model: - model_dict = await find_model_using_raw(model_uuid=model_uuid) + model_dict = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid) registry = Registry.get_default() model = registry.validate( @@ -43,11 +44,11 @@ async def validate_tokens_and_create_gh_repo( model: Dict[str, Any], model_uuid: str, ) -> SaasAppGenerator: - async with get_db_connection(): - found_gh_token = await find_model_using_raw( + async with PrismaProtocol().get_db_connection(): + found_gh_token = await PrismaProtocol().find_model_using_raw( model_uuid=model["gh_token"]["uuid"] ) - found_fly_token = await find_model_using_raw( + found_fly_token = await PrismaProtocol().find_model_using_raw( model_uuid=model["fly_token"]["uuid"] ) @@ -81,8 +82,8 @@ async def deploy_saas_app( await asyncify(saas_app.execute)() - async with get_db_connection() as db: - found_model = await find_model_using_raw(model_uuid=model_uuid) + async with PrismaProtocol().get_db_connection() as db: + found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid) found_model["json_str"]["app_deploy_status"] = "completed" await db.model.update( @@ -125,8 +126,8 @@ async def add_model_to_user( updated_validated_model_dict["gh_repo_url"] = saas_app.gh_repo_url validated_model_json = json.dumps(updated_validated_model_dict) - await get_user(user_uuid=user_uuid) - async with get_db_connection() as db: + await PrismaProtocol().get_user(user_uuid=user_uuid) + async with PrismaProtocol().get_db_connection() as db: await db.model.create( data={ "uuid": model_uuid, @@ -209,7 +210,7 @@ async def get_all_models_for_user( if type_name: filters["type_name"] = type_name - async with get_db_connection() as db: + async with PrismaProtocol().get_db_connection() as db: models = await db.model.find_many(where=filters) # type: ignore[arg-type] return models # type: ignore[no-any-return] diff --git a/fastagency/io/ionats.py b/fastagency/io/ionats.py index cf680d07..515b5850 100644 --- a/fastagency/io/ionats.py +++ b/fastagency/io/ionats.py @@ -13,9 +13,10 @@ from nats.js import api from pydantic import BaseModel -from ..db.helpers import find_model_using_raw +# from ..db.helpers import find_model_using_raw from ..models.teams.multi_agent_team import MultiAgentTeam from ..models.teams.two_agent_teams import TwoAgentTeam +from ..protocols.prisma import PrismaProtocol from .app import app, broker, stream # noqa if TYPE_CHECKING: @@ -166,7 +167,7 @@ class InitiateModel(BaseModel): async def create_team( team_id: UUID, user_id: UUID ) -> Callable[[str], List[Dict[str, Any]]]: - team_dict = await find_model_using_raw(team_id) + team_dict = await PrismaProtocol().find_model_using_raw(team_id) team_model: Union[TwoAgentTeam, MultiAgentTeam] if "initial_agent" in team_dict["json_str"]: diff --git a/fastagency/models/agents/base.py b/fastagency/models/agents/base.py index 4c1b046c..18e23fa3 100644 --- a/fastagency/models/agents/base.py +++ b/fastagency/models/agents/base.py @@ -6,7 +6,8 @@ from fastagency.openapi.client import Client -from ...db.helpers import find_model_using_raw +# from ...db.helpers import find_model_using_raw +from ...protocols.prisma import PrismaProtocol from ..base import Model from ..registry import Registry from ..toolboxes.toolbox import ToolboxRef @@ -59,7 +60,9 @@ async def get_clients_from_toolboxes(self, user_id: UUID) -> List[Client]: if toolbox_property is None: continue - toolbox_dict = await find_model_using_raw(toolbox_property.uuid) + toolbox_dict = await PrismaProtocol().find_model_using_raw( + toolbox_property.uuid + ) toolbox_model = toolbox_property.get_data_model()( **toolbox_dict["json_str"] ) diff --git a/fastagency/models/base.py b/fastagency/models/base.py index 36ae7ece..72aaa60e 100644 --- a/fastagency/models/base.py +++ b/fastagency/models/base.py @@ -5,7 +5,8 @@ from pydantic import BaseModel, Field, create_model, model_validator from typing_extensions import TypeAlias -from ..db.helpers import find_model_using_raw +# from ..db.helpers import find_model_using_raw +from ..protocols.prisma import PrismaProtocol M = TypeVar("M", bound="Model") @@ -39,7 +40,7 @@ async def create_autogen( @classmethod async def from_db(cls: Type[T], model_id: UUID) -> T: - my_model_dict = await find_model_using_raw(model_id) + my_model_dict = await PrismaProtocol().find_model_using_raw(model_id) my_model = cls(**my_model_dict["json_str"]) return my_model diff --git a/fastagency/protocols/__init__.py b/fastagency/protocols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastagency/protocols/base.py b/fastagency/protocols/base.py new file mode 100644 index 00000000..8a135fc3 --- /dev/null +++ b/fastagency/protocols/base.py @@ -0,0 +1,24 @@ +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, Dict, Optional, Union +from uuid import UUID + +from prisma import Prisma # type: ignore[attr-defined] + + +class BaseProtocol: + @asynccontextmanager # type: ignore[arg-type] + async def get_db_connection( + self, db_url: Optional[str] = None + ) -> AsyncGenerator[Prisma, None]: + raise NotImplementedError() + + async def get_wasp_db_url(self) -> str: + raise NotImplementedError() + + async def find_model_using_raw( + self, model_uuid: Union[str, UUID] + ) -> Dict[str, Any]: + raise NotImplementedError() + + async def get_user(self, user_uuid: Union[int, str]) -> Any: + raise NotImplementedError() diff --git a/fastagency/protocols/prisma.py b/fastagency/protocols/prisma.py new file mode 100644 index 00000000..5066fb92 --- /dev/null +++ b/fastagency/protocols/prisma.py @@ -0,0 +1,66 @@ +from contextlib import asynccontextmanager +from os import environ +from typing import Any, AsyncGenerator, Dict, Optional, Union +from uuid import UUID + +from fastapi import HTTPException +from prisma import Prisma # type: ignore[attr-defined] + +from .base import BaseProtocol + + +class PrismaProtocol(BaseProtocol): + @asynccontextmanager + async def get_db_connection( # type: ignore[override] + self, + db_url: Optional[str] = None, + ) -> AsyncGenerator[Prisma, None]: + if not db_url: + db_url = environ.get("PY_DATABASE_URL", None) + if not db_url: + raise ValueError( + "No database URL provided nor set as environment variable 'PY_DATABASE_URL'" + ) # pragma: no cover + if "connect_timeout" not in db_url: + db_url += "?connect_timeout=60" + db = Prisma(datasource={"url": db_url}) + await db.connect() + try: + yield db + finally: + await db.disconnect() + + async def get_wasp_db_url(self) -> str: + wasp_db_url: str = environ.get("DATABASE_URL") # type: ignore[assignment] + if "connect_timeout" not in wasp_db_url: + wasp_db_url += "?connect_timeout=60" + return wasp_db_url + + async def find_model_using_raw( + self, model_uuid: Union[str, UUID] + ) -> Dict[str, Any]: + if isinstance(model_uuid, UUID): + model_uuid = str(model_uuid) + async with self.get_db_connection() as db: + model: Optional[Dict[str, Any]] = await db.query_first( + 'SELECT * from "Model" where uuid=' # nosec: [B608] + + f"'{model_uuid}'" + ) + if not model: + raise HTTPException( + status_code=404, detail="Something went wrong. Please try again later." + ) + return model + + async def get_user(self, user_uuid: Union[int, str]) -> Any: + wasp_db_url = await self.get_wasp_db_url() + async with self.get_db_connection(db_url=wasp_db_url) as db: + select_query = 'SELECT * from "User" where uuid=' + f"'{user_uuid}'" # nosec: [B608] + user = await db.query_first( + select_query # nosec: [B608] + ) + if not user: + raise HTTPException( + status_code=404, detail=f"user_uuid {user_uuid} not found" + ) + return user diff --git a/tests/app/test_model_routes.py b/tests/app/test_model_routes.py index 43ab215c..88399b72 100644 --- a/tests/app/test_model_routes.py +++ b/tests/app/test_model_routes.py @@ -302,8 +302,11 @@ async def test_background_task_not_called_on_error(self, user_uuid: str) -> None model_uuid = str(uuid.uuid4()) with ( - patch("fastagency.app.get_user", side_effect=Exception()), - patch("fastagency.db.helpers.get_db_connection", side_effect=Exception()), + patch("fastagency.app.PrismaProtocol.get_user", side_effect=Exception()), + patch( + "fastagency.protocols.prisma.PrismaProtocol.get_db_connection", + side_effect=Exception(), + ), patch("fastagency.helpers.deploy_saas_app") as mock_task, ): response = client.post( diff --git a/tests/auth_token/test_auth_token.py b/tests/auth_token/test_auth_token.py index 04581065..cde74276 100644 --- a/tests/auth_token/test_auth_token.py +++ b/tests/auth_token/test_auth_token.py @@ -8,6 +8,8 @@ import fastagency.app import fastagency.auth_token.auth +import fastagency.protocols +import fastagency.protocols.prisma from fastagency.app import app from fastagency.auth_token.auth import ( create_deployment_auth_token, @@ -85,7 +87,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.protocols.prisma.PrismaProtocol, + "find_model_using_raw", + mock_find_model_using_raw, ) token = await create_deployment_auth_token(user_uuid, deployment_uuid) @@ -107,7 +111,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.protocols.prisma.PrismaProtocol, + "find_model_using_raw", + mock_find_model_using_raw, ) with pytest.raises(HTTPException) as e: @@ -131,7 +137,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.protocols.prisma.PrismaProtocol, + "find_model_using_raw", + mock_find_model_using_raw, ) response = client.post( @@ -157,7 +165,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.protocols.prisma.PrismaProtocol, + "find_model_using_raw", + mock_find_model_using_raw, ) response = client.post( @@ -167,7 +177,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] assert response.status_code == 200 monkeypatch.setattr( - fastagency.app, "find_model_using_raw", mock_find_model_using_raw + fastagency.protocols.prisma.PrismaProtocol, + "find_model_using_raw", + mock_find_model_using_raw, ) response = client.get(f"/user/{user_uuid}/deployment/{deployment_uuid}") assert response.status_code == 200 @@ -192,7 +204,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.protocols.prisma.PrismaProtocol, + "find_model_using_raw", + mock_find_model_using_raw, ) response = client.post( @@ -202,7 +216,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] assert response.status_code == 200 monkeypatch.setattr( - fastagency.app, "find_model_using_raw", mock_find_model_using_raw + fastagency.protocols.prisma.PrismaProtocol, + "find_model_using_raw", + mock_find_model_using_raw, ) response = client.get(f"/user/{user_uuid}/deployment/{deployment_uuid}") assert len(response.json()) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index c6579f48..747d8243 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,10 +24,6 @@ from fastapi import FastAPI, Path from pydantic import BaseModel -from fastagency.db.helpers import ( - get_db_connection, - get_wasp_db_url, -) from fastagency.helpers import create_autogen, create_model_ref, get_model_by_ref from fastagency.models.agents.assistant import AssistantAgent from fastagency.models.agents.user_proxy import UserProxyAgent @@ -41,6 +37,12 @@ from fastagency.models.teams.two_agent_teams import TwoAgentTeam from fastagency.models.toolboxes.toolbox import OpenAPIAuth, Toolbox +# from fastagency.db.helpers import ( +# get_db_connection, +# get_wasp_db_url, +# ) +from fastagency.protocols.prisma import PrismaProtocol + from .helpers import add_random_sufix, expand_fixture, get_by_tag, tag, tag_list F = TypeVar("F", bound=Callable[..., Any]) @@ -51,8 +53,8 @@ async def user_uuid() -> AsyncIterator[str]: try: random_id = random.randint(1, 1_000_000) generated_uuid = str(uuid.uuid4()) - wasp_db_url = await get_wasp_db_url() - async with get_db_connection(db_url=wasp_db_url) as db: + wasp_db_url = await PrismaProtocol().get_wasp_db_url() + async with PrismaProtocol().get_db_connection(db_url=wasp_db_url) as db: insert_query = ( 'INSERT INTO "User" (email, username, uuid) VALUES (' + f"'user{random_id}@airt.ai', 'user{random_id}', '{generated_uuid}')"