From 7422ac5a4c59c6149f53e9d8223b14656430ff55 Mon Sep 17 00:00:00 2001 From: Kumaran Rajendhiran Date: Tue, 13 Aug 2024 11:48:56 +0530 Subject: [PATCH] Add prisma protocol class (#6) * Add prisma protocol class * Move protocol files to db directory * Update docs * Remove unnecessary commented out lines * Split protocol * Remove unnecessary code and rearrange * Update docs * Add methods for getting prisma actions objects * Use protocol * Optimize code * Update docs * Update BaseFrontendProtocol * Rename prisma backend db class * Delete unnecessary class * Remove unnecessary methods * Use default find model * Remove get_authtoken_connection method * Remove get_model_connection method * Fix mypy issues * Fix mypy issue * Rename protocol class * Add an empty line between methods in protocol * Use sync get default and set default methods * Refactor code * Add a internal method to create user * Add tests for prisma backend classes * Add inmemory protocol implementation * Update prisma test * Update docs * Use inmemory db in tests * wip * wip * Use class defaultdb * Remove unnecessary type adapter * Move lifespan * Reuse from_db in agents/base.py * Use Union[str, UUID] * Raise common exceptions and handle it in fastapi app * Update docs * Use exception args * Update together model string * Update openai models list * Update openai schema in test * Update openai tests --------- Co-authored-by: Davor Runje --- docs/docs/SUMMARY.md | 21 +- .../app/handle_keynotfounderror_middleware.md | 11 + .../BackendDBProtocol.md} | 2 +- .../get_user.md => base/DefaultDB.md} | 2 +- .../FrontendDBProtocol.md} | 2 +- .../KeyExistsError.md} | 2 +- .../fastagency/db/base/KeyNotFoundError.md | 11 + .../db/inmemory/InMemoryBackendDB.md | 11 + .../db/inmemory/InMemoryFrontendDB.md | 11 + .../fastagency/db/prisma/PrismaBackendDB.md | 11 + .../api/fastagency/db/prisma/PrismaBaseDB.md | 11 + .../fastagency/db/prisma/PrismaFrontendDB.md | 11 + .../fastagency/db/prisma/fastapi_lifespan.md | 11 + .../db/prisma/faststream_lifespan.md | 11 + fastagency/app.py | 114 +++++---- fastagency/auth_token/auth.py | 32 ++- fastagency/db/base.py | 111 +++++++++ fastagency/db/helpers.py | 63 ----- fastagency/db/inmemory.py | 145 ++++++++++++ fastagency/db/prisma.py | 217 ++++++++++++++++++ fastagency/helpers.py | 71 +++--- fastagency/io/app.py | 5 +- fastagency/io/ionats.py | 4 +- fastagency/models/agents/base.py | 6 +- fastagency/models/base.py | 4 +- fastagency/models/llms/openai.py | 17 +- fastagency/models/llms/together.py | 14 +- tests/app/test_model_routes.py | 26 ++- tests/app/test_openai_extensively.py | 4 +- tests/auth_token/test_auth_token.py | 54 +++-- tests/conftest.py | 37 +-- tests/db/__init__.py | 0 tests/db/test_inmemory.py | 185 +++++++++++++++ tests/db/test_prisma.py | 186 +++++++++++++++ tests/models/llms/test_openai.py | 17 +- 35 files changed, 1184 insertions(+), 256 deletions(-) create mode 100644 docs/docs/en/api/fastagency/app/handle_keynotfounderror_middleware.md rename docs/docs/en/api/fastagency/db/{helpers/get_wasp_db_url.md => base/BackendDBProtocol.md} (72%) rename docs/docs/en/api/fastagency/db/{helpers/get_user.md => base/DefaultDB.md} (75%) rename docs/docs/en/api/fastagency/db/{helpers/get_db_connection.md => base/FrontendDBProtocol.md} (71%) rename docs/docs/en/api/fastagency/db/{helpers/find_model_using_raw.md => base/KeyExistsError.md} (70%) create mode 100644 docs/docs/en/api/fastagency/db/base/KeyNotFoundError.md create mode 100644 docs/docs/en/api/fastagency/db/inmemory/InMemoryBackendDB.md create mode 100644 docs/docs/en/api/fastagency/db/inmemory/InMemoryFrontendDB.md create mode 100644 docs/docs/en/api/fastagency/db/prisma/PrismaBackendDB.md create mode 100644 docs/docs/en/api/fastagency/db/prisma/PrismaBaseDB.md create mode 100644 docs/docs/en/api/fastagency/db/prisma/PrismaFrontendDB.md create mode 100644 docs/docs/en/api/fastagency/db/prisma/fastapi_lifespan.md create mode 100644 docs/docs/en/api/fastagency/db/prisma/faststream_lifespan.md create mode 100644 fastagency/db/base.py delete mode 100644 fastagency/db/helpers.py create mode 100644 fastagency/db/inmemory.py create mode 100644 fastagency/db/prisma.py create mode 100644 tests/db/__init__.py create mode 100644 tests/db/test_inmemory.py create mode 100644 tests/db/test_prisma.py diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index 580fc72f..6f5612d0 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -20,6 +20,7 @@ search: - [get_all_models](api/fastagency/app/get_all_models.md) - [get_azure_llm_client](api/fastagency/app/get_azure_llm_client.md) - [get_models_schemas](api/fastagency/app/get_models_schemas.md) + - [handle_keynotfounderror_middleware](api/fastagency/app/handle_keynotfounderror_middleware.md) - [mask](api/fastagency/app/mask.md) - [models_delete](api/fastagency/app/models_delete.md) - [setup_user](api/fastagency/app/setup_user.md) @@ -36,11 +37,21 @@ search: - [parse_expiry](api/fastagency/auth_token/auth/parse_expiry.md) - [verify_auth_token](api/fastagency/auth_token/auth/verify_auth_token.md) - db - - helpers - - [find_model_using_raw](api/fastagency/db/helpers/find_model_using_raw.md) - - [get_db_connection](api/fastagency/db/helpers/get_db_connection.md) - - [get_user](api/fastagency/db/helpers/get_user.md) - - [get_wasp_db_url](api/fastagency/db/helpers/get_wasp_db_url.md) + - base + - [BackendDBProtocol](api/fastagency/db/base/BackendDBProtocol.md) + - [DefaultDB](api/fastagency/db/base/DefaultDB.md) + - [FrontendDBProtocol](api/fastagency/db/base/FrontendDBProtocol.md) + - [KeyExistsError](api/fastagency/db/base/KeyExistsError.md) + - [KeyNotFoundError](api/fastagency/db/base/KeyNotFoundError.md) + - inmemory + - [InMemoryBackendDB](api/fastagency/db/inmemory/InMemoryBackendDB.md) + - [InMemoryFrontendDB](api/fastagency/db/inmemory/InMemoryFrontendDB.md) + - prisma + - [PrismaBackendDB](api/fastagency/db/prisma/PrismaBackendDB.md) + - [PrismaBaseDB](api/fastagency/db/prisma/PrismaBaseDB.md) + - [PrismaFrontendDB](api/fastagency/db/prisma/PrismaFrontendDB.md) + - [fastapi_lifespan](api/fastagency/db/prisma/fastapi_lifespan.md) + - [faststream_lifespan](api/fastagency/db/prisma/faststream_lifespan.md) - faststream_app - [ping_handler](api/fastagency/faststream_app/ping_handler.md) - helpers diff --git a/docs/docs/en/api/fastagency/app/handle_keynotfounderror_middleware.md b/docs/docs/en/api/fastagency/app/handle_keynotfounderror_middleware.md new file mode 100644 index 00000000..cae8d69d --- /dev/null +++ b/docs/docs/en/api/fastagency/app/handle_keynotfounderror_middleware.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.app.handle_keynotfounderror_middleware diff --git a/docs/docs/en/api/fastagency/db/helpers/get_wasp_db_url.md b/docs/docs/en/api/fastagency/db/base/BackendDBProtocol.md similarity index 72% rename from docs/docs/en/api/fastagency/db/helpers/get_wasp_db_url.md rename to docs/docs/en/api/fastagency/db/base/BackendDBProtocol.md index 27fdb6f7..3d38a08d 100644 --- a/docs/docs/en/api/fastagency/db/helpers/get_wasp_db_url.md +++ b/docs/docs/en/api/fastagency/db/base/BackendDBProtocol.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: fastagency.db.helpers.get_wasp_db_url +::: fastagency.db.base.BackendDBProtocol diff --git a/docs/docs/en/api/fastagency/db/helpers/get_user.md b/docs/docs/en/api/fastagency/db/base/DefaultDB.md similarity index 75% rename from docs/docs/en/api/fastagency/db/helpers/get_user.md rename to docs/docs/en/api/fastagency/db/base/DefaultDB.md index 6f21bd72..ea64bfe7 100644 --- a/docs/docs/en/api/fastagency/db/helpers/get_user.md +++ b/docs/docs/en/api/fastagency/db/base/DefaultDB.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: fastagency.db.helpers.get_user +::: fastagency.db.base.DefaultDB diff --git a/docs/docs/en/api/fastagency/db/helpers/get_db_connection.md b/docs/docs/en/api/fastagency/db/base/FrontendDBProtocol.md similarity index 71% rename from docs/docs/en/api/fastagency/db/helpers/get_db_connection.md rename to docs/docs/en/api/fastagency/db/base/FrontendDBProtocol.md index e36504b2..3c1f6581 100644 --- a/docs/docs/en/api/fastagency/db/helpers/get_db_connection.md +++ b/docs/docs/en/api/fastagency/db/base/FrontendDBProtocol.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: fastagency.db.helpers.get_db_connection +::: fastagency.db.base.FrontendDBProtocol diff --git a/docs/docs/en/api/fastagency/db/helpers/find_model_using_raw.md b/docs/docs/en/api/fastagency/db/base/KeyExistsError.md similarity index 70% rename from docs/docs/en/api/fastagency/db/helpers/find_model_using_raw.md rename to docs/docs/en/api/fastagency/db/base/KeyExistsError.md index 85afa70d..d87ffcc2 100644 --- a/docs/docs/en/api/fastagency/db/helpers/find_model_using_raw.md +++ b/docs/docs/en/api/fastagency/db/base/KeyExistsError.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: fastagency.db.helpers.find_model_using_raw +::: fastagency.db.base.KeyExistsError diff --git a/docs/docs/en/api/fastagency/db/base/KeyNotFoundError.md b/docs/docs/en/api/fastagency/db/base/KeyNotFoundError.md new file mode 100644 index 00000000..c4dcae0e --- /dev/null +++ b/docs/docs/en/api/fastagency/db/base/KeyNotFoundError.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.base.KeyNotFoundError diff --git a/docs/docs/en/api/fastagency/db/inmemory/InMemoryBackendDB.md b/docs/docs/en/api/fastagency/db/inmemory/InMemoryBackendDB.md new file mode 100644 index 00000000..bed3b861 --- /dev/null +++ b/docs/docs/en/api/fastagency/db/inmemory/InMemoryBackendDB.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.inmemory.InMemoryBackendDB diff --git a/docs/docs/en/api/fastagency/db/inmemory/InMemoryFrontendDB.md b/docs/docs/en/api/fastagency/db/inmemory/InMemoryFrontendDB.md new file mode 100644 index 00000000..0875b0ef --- /dev/null +++ b/docs/docs/en/api/fastagency/db/inmemory/InMemoryFrontendDB.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.inmemory.InMemoryFrontendDB diff --git a/docs/docs/en/api/fastagency/db/prisma/PrismaBackendDB.md b/docs/docs/en/api/fastagency/db/prisma/PrismaBackendDB.md new file mode 100644 index 00000000..b693e07e --- /dev/null +++ b/docs/docs/en/api/fastagency/db/prisma/PrismaBackendDB.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.prisma.PrismaBackendDB diff --git a/docs/docs/en/api/fastagency/db/prisma/PrismaBaseDB.md b/docs/docs/en/api/fastagency/db/prisma/PrismaBaseDB.md new file mode 100644 index 00000000..7b7f61f1 --- /dev/null +++ b/docs/docs/en/api/fastagency/db/prisma/PrismaBaseDB.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.prisma.PrismaBaseDB diff --git a/docs/docs/en/api/fastagency/db/prisma/PrismaFrontendDB.md b/docs/docs/en/api/fastagency/db/prisma/PrismaFrontendDB.md new file mode 100644 index 00000000..040f7e79 --- /dev/null +++ b/docs/docs/en/api/fastagency/db/prisma/PrismaFrontendDB.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.prisma.PrismaFrontendDB diff --git a/docs/docs/en/api/fastagency/db/prisma/fastapi_lifespan.md b/docs/docs/en/api/fastagency/db/prisma/fastapi_lifespan.md new file mode 100644 index 00000000..8df34c3c --- /dev/null +++ b/docs/docs/en/api/fastagency/db/prisma/fastapi_lifespan.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.prisma.fastapi_lifespan diff --git a/docs/docs/en/api/fastagency/db/prisma/faststream_lifespan.md b/docs/docs/en/api/fastagency/db/prisma/faststream_lifespan.md new file mode 100644 index 00000000..f9c2b398 --- /dev/null +++ b/docs/docs/en/api/fastagency/db/prisma/faststream_lifespan.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: fastagency.db.prisma.faststream_lifespan diff --git a/fastagency/app.py b/fastagency/app.py index 03155ff1..9c88d6ad 100644 --- a/fastagency/app.py +++ b/fastagency/app.py @@ -1,18 +1,30 @@ import json import logging from os import environ -from typing import Annotated, Any, Dict, List, Optional, Tuple, Union +from typing import ( + Annotated, + Any, + Callable, + Coroutine, + Dict, + List, + Optional, + Tuple, + Union, +) from uuid import UUID import httpx import yaml from fastapi import BackgroundTasks, Body, FastAPI, HTTPException, Path +from fastapi.requests import Request +from fastapi.responses import JSONResponse, Response from openai import AsyncAzureOpenAI -from prisma.models import Model -from pydantic import BaseModel, TypeAdapter, ValidationError +from pydantic import BaseModel, ValidationError from .auth_token.auth import DeploymentAuthToken, create_deployment_auth_token -from .db.helpers import find_model_using_raw, get_db_connection, get_user +from .db.base import DefaultDB, KeyNotFoundError +from .db.prisma import fastapi_lifespan from .helpers import ( add_model_to_user, create_model, @@ -23,7 +35,18 @@ logging.basicConfig(level=logging.INFO) -app = FastAPI() + +app = FastAPI(lifespan=fastapi_lifespan) + + +@app.middleware("http") +async def handle_keynotfounderror_middleware( + request: Request, call_next: Callable[[Request], Coroutine[Any, Any, Response]] +) -> Response: + try: + return await call_next(request) + except KeyNotFoundError as e: + return JSONResponse(status_code=404, content={"detail": e.args[0]}) @app.get("/models/schemas") @@ -78,7 +101,7 @@ async def validate_secret_model( ) -> Dict[str, Any]: type: str = "secret" - found_model = await find_model_using_raw(model_uuid=model_uuid) + found_model = await DefaultDB.backend().find_model(model_uuid=model_uuid) if "api_key" in found_model["json_str"]: model["api_key"] = found_model["json_str"]["api_key"] try: @@ -100,10 +123,9 @@ async def get_all_models( user_uuid: str, type_name: Optional[str] = None, ) -> List[Any]: - models = await get_all_models_for_user(user_uuid=user_uuid, type_name=type_name) - - ta = TypeAdapter(List[Model]) - ret_val_without_mask = ta.dump_python(models, serialize_as_any=True) # type: ignore[call-arg] + ret_val_without_mask = await get_all_models_for_user( + user_uuid=user_uuid, type_name=type_name + ) ret_val = [] for model in ret_val_without_mask: @@ -136,7 +158,7 @@ async def add_model( async def create_toolbox_for_new_user(user_uuid: Union[str, UUID]) -> Dict[str, Any]: - await get_user(user_uuid=user_uuid) # type: ignore[arg-type] + await DefaultDB.frontend().get_user(user_uuid=user_uuid) # type: ignore[arg-type] domain = environ.get("DOMAIN", "localhost") toolbox_openapi_url = ( @@ -181,18 +203,14 @@ async def update_model( registry = Registry.get_default() validated_model = registry.validate(type_name, model_name, model) - async with get_db_connection() as db: - found_model = await find_model_using_raw(model_uuid=model_uuid) - - await db.model.update( - where={"uuid": found_model["uuid"]}, # type: ignore[arg-type] - data={ # type: ignore[typeddict-unknown-key] - "type_name": type_name, - "model_name": model_name, - "json_str": validated_model.model_dump_json(), # type: ignore[typeddict-item] - "user_uuid": user_uuid, - }, - ) + found_model = await DefaultDB.backend().find_model(model_uuid=model_uuid) + await DefaultDB.backend().update_model( + model_uuid=found_model["uuid"], + user_uuid=user_uuid, + type_name=type_name, + model_name=model_name, + json_str=validated_model.model_dump_json(), + ) return validated_model.model_dump() @@ -201,13 +219,9 @@ async def update_model( async def models_delete( user_uuid: str, type_name: str, model_uuid: str ) -> Dict[str, Any]: - async with get_db_connection() as db: - found_model = await find_model_using_raw(model_uuid=model_uuid) - model = await db.model.delete( - where={"uuid": found_model["uuid"]} # type: ignore[arg-type] - ) - - return model.json_str # type: ignore + found_model = await DefaultDB.backend().find_model(model_uuid=model_uuid) + model = await DefaultDB.backend().delete_model(model_uuid=found_model["uuid"]) + return model["json_str"] # type: ignore def get_azure_llm_client() -> Tuple[AsyncAzureOpenAI, str]: @@ -340,7 +354,7 @@ async def chat(request: ChatRequest) -> Dict[str, Any]: @app.post("/deployment/{deployment_uuid}/chat") async def deployment_chat(deployment_uuid: str) -> Dict[str, Any]: - found_model = await find_model_using_raw(model_uuid=deployment_uuid) + found_model = await DefaultDB.backend().find_model(model_uuid=deployment_uuid) team_name = found_model["json_str"]["name"] team_uuid = found_model["json_str"]["team"]["uuid"] @@ -381,21 +395,22 @@ class DeploymentAuthTokenInfo(BaseModel): async def get_all_deployment_auth_tokens( user_uuid: str, deployment_uuid: str ) -> List[DeploymentAuthTokenInfo]: - user = await get_user(user_uuid=user_uuid) - deployment = await find_model_using_raw(model_uuid=deployment_uuid) + user = await DefaultDB.frontend().get_user(user_uuid=user_uuid) + deployment = await DefaultDB.backend().find_model(model_uuid=deployment_uuid) if user["uuid"] != deployment["user_uuid"]: raise HTTPException( # pragma: no cover status_code=403, detail="User does not have access to this deployment" ) - async with get_db_connection() as db: - auth_tokens = await db.authtoken.find_many( - where={"deployment_uuid": deployment_uuid, "user_uuid": user_uuid}, - ) + auth_tokens = await DefaultDB.backend().find_many_auth_token( + user_uuid=user_uuid, deployment_uuid=deployment_uuid + ) return [ DeploymentAuthTokenInfo( - uuid=auth_token.uuid, name=auth_token.name, expiry=auth_token.expiry + uuid=auth_token["uuid"], + name=auth_token["name"], + expiry=auth_token["expiry"], ) for auth_token in auth_tokens ] @@ -407,24 +422,21 @@ async def delete_deployment_auth_token( deployment_uuid: str, auth_token_uuid: str, ) -> DeploymentAuthTokenInfo: - user = await get_user(user_uuid=user_uuid) - deployment = await find_model_using_raw(model_uuid=deployment_uuid) + user = await DefaultDB.frontend().get_user(user_uuid=user_uuid) + deployment = await DefaultDB.backend().find_model(model_uuid=deployment_uuid) if user["uuid"] != deployment["user_uuid"]: raise HTTPException( # pragma: no cover status_code=403, detail="User does not have access to this deployment" ) - async with get_db_connection() as db: - auth_token = await db.authtoken.delete( - where={ # type: ignore[typeddict-unknown-key] - "uuid": auth_token_uuid, - "deployment_uuid": deployment_uuid, - "user_uuid": user_uuid, - }, - ) + auth_token = await DefaultDB.backend().delete_auth_token( + auth_token_uuid=auth_token_uuid, + deployment_uuid=deployment_uuid, + user_uuid=user_uuid, + ) return DeploymentAuthTokenInfo( - uuid=auth_token.uuid, # type: ignore[union-attr] - name=auth_token.name, # type: ignore[union-attr] - expiry=auth_token.expiry, # type: ignore[union-attr] + uuid=auth_token["uuid"], # type: ignore[union-attr] + name=auth_token["name"], # type: ignore[union-attr] + expiry=auth_token["expiry"], # type: ignore[union-attr] ) diff --git a/fastagency/auth_token/auth.py b/fastagency/auth_token/auth.py index 15ea21aa..452e21fb 100644 --- a/fastagency/auth_token/auth.py +++ b/fastagency/auth_token/auth.py @@ -4,11 +4,12 @@ import string import uuid from datetime import datetime, timedelta +from typing import Union from fastapi import HTTPException from pydantic import BaseModel -from fastagency.db.helpers import find_model_using_raw, get_db_connection, get_user +from ..db.base import DefaultDB def generate_auth_token(length: int = 32) -> str: @@ -75,13 +76,13 @@ async def parse_expiry(expiry: str) -> datetime: async def create_deployment_auth_token( - user_uuid: str, - deployment_uuid: str, + user_uuid: Union[str, uuid.UUID], + deployment_uuid: Union[str, uuid.UUID], name: str = "Default deployment token", expiry: str = "99999d", ) -> DeploymentAuthToken: - user = await get_user(user_uuid=user_uuid) - deployment = await find_model_using_raw(model_uuid=deployment_uuid) + user = await DefaultDB.frontend().get_user(user_uuid=user_uuid) + deployment = await DefaultDB.backend().find_model(model_uuid=deployment_uuid) if user["uuid"] != deployment["user_uuid"]: raise HTTPException( @@ -92,17 +93,14 @@ async def create_deployment_auth_token( auth_token = generate_auth_token() hashed_token = hash_auth_token(auth_token) - async with get_db_connection() as db: - await db.authtoken.create( # type: ignore[attr-defined] - data={ - "uuid": str(uuid.uuid4()), - "name": name, - "user_uuid": user_uuid, - "deployment_uuid": deployment_uuid, - "auth_token": hashed_token, - "expiry": expiry, - "expires_at": expires_at, - } - ) + await DefaultDB.backend().create_auth_token( + auth_token_uuid=uuid.uuid4(), + name=name, + user_uuid=user_uuid, + deployment_uuid=deployment_uuid, + hashed_auth_token=hashed_token, + expiry=expiry, + expires_at=expires_at, + ) return DeploymentAuthToken(auth_token=auth_token) diff --git a/fastagency/db/base.py b/fastagency/db/base.py new file mode 100644 index 00000000..3550694b --- /dev/null +++ b/fastagency/db/base.py @@ -0,0 +1,111 @@ +from contextlib import contextmanager +from datetime import datetime +from typing import ( + Any, + Dict, + Generator, + List, + Optional, + Protocol, + Union, + runtime_checkable, +) +from uuid import UUID + + +class KeyNotFoundError(ValueError): + pass + + +class KeyExistsError(ValueError): + pass + + +@runtime_checkable +class BackendDBProtocol(Protocol): + async def create_model( + self, + model_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + type_name: str, + model_name: str, + json_str: str, + ) -> Dict[str, Any]: ... + + async def find_model(self, model_uuid: Union[str, UUID]) -> Dict[str, Any]: ... + + async def find_many_model( + self, user_uuid: Union[str, UUID], type_name: Optional[str] = None + ) -> List[Dict[str, Any]]: ... + + async def update_model( + self, + model_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + type_name: str, + model_name: str, + json_str: str, + ) -> Dict[str, Any]: ... + + async def delete_model(self, model_uuid: Union[str, UUID]) -> Dict[str, Any]: ... + + async def create_auth_token( + self, + auth_token_uuid: Union[str, UUID], + name: str, + user_uuid: Union[str, UUID], + deployment_uuid: Union[str, UUID], + hashed_auth_token: str, + expiry: str, + expires_at: datetime, + ) -> Dict[str, Any]: ... + + async def find_many_auth_token( + self, user_uuid: Union[str, UUID], deployment_uuid: Union[str, UUID] + ) -> List[Dict[str, Any]]: ... + + async def delete_auth_token( + self, + auth_token_uuid: Union[str, UUID], + deployment_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + ) -> Dict[str, Any]: ... + + +@runtime_checkable +class FrontendDBProtocol(Protocol): + async def get_user(self, user_uuid: Union[str, UUID]) -> Dict[str, Any]: ... + + async def _create_user( + self, user_uuid: Union[str, UUID], email: str, username: str + ) -> Union[str, UUID]: ... + + +class DefaultDB: + _backend_db: Optional[BackendDBProtocol] = None + _frontend_db: Optional[FrontendDBProtocol] = None + + @staticmethod + @contextmanager + def set( + *, + backend_db: BackendDBProtocol, + frontend_db: FrontendDBProtocol, + ) -> Generator[None, None, None]: + old_backend_default = DefaultDB._backend_db + old_frontend_default = DefaultDB._frontend_db + try: + DefaultDB._backend_db = backend_db + DefaultDB._frontend_db = frontend_db + yield + finally: + DefaultDB._backend_db = old_backend_default + DefaultDB._frontend_db = old_frontend_default + + @staticmethod + def backend() -> BackendDBProtocol: + return DefaultDB._backend_db # type: ignore[return-value] + + @staticmethod + def frontend() -> FrontendDBProtocol: + return DefaultDB._frontend_db # type: ignore[return-value] diff --git a/fastagency/db/helpers.py b/fastagency/db/helpers.py deleted file mode 100644 index 221e0d75..00000000 --- a/fastagency/db/helpers.py +++ /dev/null @@ -1,63 +0,0 @@ -from contextlib import asynccontextmanager -from os import environ -from typing import Any, AsyncGenerator, Dict, Optional, Union -from uuid import UUID - -from fastapi import HTTPException -from prisma import Prisma # type: ignore[attr-defined] - - -@asynccontextmanager -async def get_db_connection( - db_url: Optional[str] = None, -) -> AsyncGenerator[Prisma, None]: - if not db_url: - db_url = environ.get("PY_DATABASE_URL", None) - if not db_url: - raise ValueError( - "No database URL provided nor set as environment variable 'PY_DATABASE_URL'" - ) # pragma: no cover - if "connect_timeout" not in db_url: - db_url += "?connect_timeout=60" - db = Prisma(datasource={"url": db_url}) - await db.connect() - try: - yield db - finally: - await db.disconnect() - - -async def get_wasp_db_url() -> str: - wasp_db_url: str = environ.get("DATABASE_URL") # type: ignore[assignment] - if "connect_timeout" not in wasp_db_url: - wasp_db_url += "?connect_timeout=60" - return wasp_db_url - - -async def find_model_using_raw(model_uuid: Union[str, UUID]) -> Dict[str, Any]: - if isinstance(model_uuid, UUID): - model_uuid = str(model_uuid) - - async with get_db_connection() as db: - model: Optional[Dict[str, Any]] = await db.query_first( - 'SELECT * from "Model" where uuid=' # nosec: [B608] - + f"'{model_uuid}'" - ) - - if not model: - raise HTTPException( - status_code=404, detail="Something went wrong. Please try again later." - ) - return model - - -async def get_user(user_uuid: Union[int, str]) -> Any: - wasp_db_url = await get_wasp_db_url() - async with get_db_connection(db_url=wasp_db_url) as db: - select_query = 'SELECT * from "User" where uuid=' + f"'{user_uuid}'" # nosec: [B608] - user = await db.query_first( - select_query # nosec: [B608] - ) - if not user: - raise HTTPException(status_code=404, detail=f"user_uuid {user_uuid} not found") - return user diff --git a/fastagency/db/inmemory.py b/fastagency/db/inmemory.py new file mode 100644 index 00000000..8a24c7fe --- /dev/null +++ b/fastagency/db/inmemory.py @@ -0,0 +1,145 @@ +import json +from datetime import datetime +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +from .base import BackendDBProtocol, FrontendDBProtocol, KeyNotFoundError + + +class InMemoryBackendDB(BackendDBProtocol): + def __init__(self) -> None: + """In memory backend database.""" + self._models: List[Dict[str, Any]] = [] + self._auth_tokens: List[Dict[str, Any]] = [] + + async def create_model( + self, + model_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + type_name: str, + model_name: str, + json_str: str, + ) -> Dict[str, Any]: + model = { + "uuid": str(model_uuid), + "user_uuid": str(user_uuid), + "type_name": type_name, + "model_name": model_name, + "json_str": json.loads(json_str), + "created_at": datetime.now(), + "updated_at": datetime.now(), + } + self._models.append(model) + return model + + async def find_model(self, model_uuid: Union[str, UUID]) -> Dict[str, Any]: + for model in self._models: + if model["uuid"] == str(model_uuid): + return model + raise KeyNotFoundError(f"model_uuid {model_uuid} not found") + + async def find_many_model( + self, user_uuid: Union[str, UUID], type_name: Optional[str] = None + ) -> List[Dict[str, Any]]: + return [model for model in self._models if model["user_uuid"] == str(user_uuid)] + + async def update_model( + self, + model_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + type_name: str, + model_name: str, + json_str: str, + ) -> Dict[str, Any]: + for model in self._models: + if model["uuid"] == str(model_uuid): + model["type_name"] = type_name + model["model_name"] = model_name + model["json_str"] = json.loads(json_str) + model["updated_at"] = datetime.now() + return model + raise KeyNotFoundError(f"model_uuid {model_uuid} not found") + + async def delete_model(self, model_uuid: Union[str, UUID]) -> Dict[str, Any]: + for model in self._models: + if model["uuid"] == str(model_uuid): + self._models.remove(model) + return model + raise KeyNotFoundError(f"model_uuid {model_uuid} not found") + + async def create_auth_token( + self, + auth_token_uuid: Union[str, UUID], + name: str, + user_uuid: Union[str, UUID], + deployment_uuid: Union[str, UUID], + hashed_auth_token: str, + expiry: str, + expires_at: datetime, + ) -> Dict[str, Any]: + auth_token = { + "uuid": str(auth_token_uuid), + "name": name, + "user_uuid": str(user_uuid), + "deployment_uuid": str(deployment_uuid), + "hashed_auth_token": hashed_auth_token, + "expiry": expiry, + "expires_at": expires_at, + "created_at": datetime.now(), + "updated_at": datetime.now(), + } + self._auth_tokens.append(auth_token) + return auth_token + + async def find_many_auth_token( + self, user_uuid: Union[str, UUID], deployment_uuid: Union[str, UUID] + ) -> List[Dict[str, Any]]: + return [ + auth_token + for auth_token in self._auth_tokens + if auth_token["user_uuid"] == str(user_uuid) + and auth_token["deployment_uuid"] == str(deployment_uuid) + ] + + async def delete_auth_token( + self, + auth_token_uuid: Union[str, UUID], + deployment_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + ) -> Dict[str, Any]: + for auth_token in self._auth_tokens: + if ( + auth_token["uuid"] == str(auth_token_uuid) + and auth_token["user_uuid"] == str(user_uuid) + and auth_token["deployment_uuid"] == str(deployment_uuid) + ): + self._auth_tokens.remove(auth_token) + return auth_token + raise KeyNotFoundError(f"auth_token_uuid {auth_token_uuid} not found") + + +class InMemoryFrontendDB(FrontendDBProtocol): + def __init__(self) -> None: + """In memory frontend database.""" + self._users: List[Dict[str, Any]] = [] + + async def get_user(self, user_uuid: Union[str, UUID]) -> Any: + for user in self._users: + if user["uuid"] == str(user_uuid): + return user + raise KeyNotFoundError(f"user_uuid {user_uuid} not found") + + async def _create_user( + self, user_uuid: Union[str, UUID], email: str, username: str + ) -> Union[str, UUID]: + """Only to create user in testing.""" + self._users.append( + { + "uuid": str(user_uuid), + "email": email, + "username": username, + "created_at": datetime.now(), + "updated_at": datetime.now(), + } + ) + return user_uuid diff --git a/fastagency/db/prisma.py b/fastagency/db/prisma.py new file mode 100644 index 00000000..da473d6a --- /dev/null +++ b/fastagency/db/prisma.py @@ -0,0 +1,217 @@ +from contextlib import asynccontextmanager +from datetime import datetime +from os import environ +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union +from uuid import UUID + +from prisma import Prisma # type: ignore[attr-defined] + +from .base import BackendDBProtocol, DefaultDB, FrontendDBProtocol, KeyNotFoundError + +if TYPE_CHECKING: + from fastapi import FastAPI + from faststream import ContextRepo + + +class PrismaBaseDB: + ENV_VAR: str + + @staticmethod + async def _get_db_url(env_var: str) -> str: + db_url: Optional[str] = environ.get(env_var, None) + if not db_url: + raise ValueError( + f"No database URL provided nor set as environment variable '{env_var}'" + ) + if "connect_timeout" not in db_url: + db_url += "?connect_timeout=60" + return db_url + + @asynccontextmanager + async def _get_db_connection(self) -> AsyncGenerator[Prisma, None]: + db_url = await self._get_db_url(self.ENV_VAR) + db = Prisma(datasource={"url": db_url}) + await db.connect() + try: + yield db + finally: + await db.disconnect() + + +class PrismaBackendDB(BackendDBProtocol, PrismaBaseDB): + ENV_VAR = "PY_DATABASE_URL" + + async def create_model( + self, + model_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + type_name: str, + model_name: str, + json_str: str, + ) -> Dict[str, Any]: + async with self._get_db_connection() as db: + created_model = await db.model.create( + data={ + "uuid": str(model_uuid), + "user_uuid": str(user_uuid), + "type_name": type_name, + "model_name": model_name, + "json_str": json_str, # type: ignore[typeddict-item] + } + ) + return created_model.model_dump() # type: ignore[no-any-return] + + async def find_model(self, model_uuid: Union[str, UUID]) -> Dict[str, Any]: + model_uuid = str(model_uuid) + async with self._get_db_connection() as db: + model: Optional[Dict[str, Any]] = await db.query_first( + 'SELECT * from "Model" where uuid=' # nosec: [B608] + + f"'{model_uuid}'" + ) + if not model: + raise KeyNotFoundError(f"model_uuid {model_uuid} not found") + return model + + async def find_many_model( + self, user_uuid: Union[str, UUID], type_name: Optional[str] = None + ) -> List[Dict[str, Any]]: + filters: Dict[str, Any] = {"user_uuid": str(user_uuid)} + if type_name: + filters["type_name"] = type_name + + async with self._get_db_connection() as db: + models = await db.model.find_many(where=filters) # type: ignore[arg-type] + return [model.model_dump() for model in models] + + async def update_model( + self, + model_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + type_name: str, + model_name: str, + json_str: str, + ) -> Dict[str, Any]: + async with self._get_db_connection() as db: + updated_model = await db.model.update( + where={"uuid": str(model_uuid)}, # type: ignore[arg-type] + data={ # type: ignore[typeddict-unknown-key] + "type_name": type_name, + "model_name": model_name, + "json_str": json_str, # type: ignore[typeddict-item] + "user_uuid": str(user_uuid), + }, + ) + if updated_model is None: + raise KeyNotFoundError(f"model_uuid {model_uuid} not found") + return updated_model.model_dump() # type: ignore[no-any-return,union-attr] + + async def delete_model(self, model_uuid: Union[str, UUID]) -> Dict[str, Any]: + async with self._get_db_connection() as db: + deleted_model = await db.model.delete(where={"uuid": str(model_uuid)}) + if deleted_model is None: + raise KeyNotFoundError(f"model_uuid {model_uuid} not found") + return deleted_model.model_dump() # type: ignore[no-any-return,union-attr] + + async def create_auth_token( + self, + auth_token_uuid: Union[str, UUID], + name: str, + user_uuid: Union[str, UUID], + deployment_uuid: Union[str, UUID], + hashed_auth_token: str, + expiry: str, + expires_at: datetime, + ) -> Dict[str, Any]: + async with self._get_db_connection() as db: + created_auth_token = await db.authtoken.create( # type: ignore[attr-defined] + data={ + "uuid": str(auth_token_uuid), + "name": name, + "user_uuid": str(user_uuid), + "deployment_uuid": str(deployment_uuid), + "auth_token": hashed_auth_token, + "expiry": expiry, + "expires_at": expires_at, + } + ) + return created_auth_token.model_dump() # type: ignore[no-any-return,union-attr] + + async def find_many_auth_token( + self, user_uuid: Union[str, UUID], deployment_uuid: Union[str, UUID] + ) -> List[Dict[str, Any]]: + async with self._get_db_connection() as db: + auth_tokens = await db.authtoken.find_many( + where={ + "deployment_uuid": str(deployment_uuid), + "user_uuid": str(user_uuid), + }, + ) + return [auth_token.model_dump() for auth_token in auth_tokens] + + async def delete_auth_token( + self, + auth_token_uuid: Union[str, UUID], + deployment_uuid: Union[str, UUID], + user_uuid: Union[str, UUID], + ) -> Dict[str, Any]: + async with self._get_db_connection() as db: + deleted_auth_token = await db.authtoken.delete( + where={ # type: ignore[typeddict-unknown-key] + "uuid": str(auth_token_uuid), + "deployment_uuid": str(deployment_uuid), + "user_uuid": str(user_uuid), + }, + ) + if deleted_auth_token is None: + raise KeyNotFoundError(f"auth_token_uuid {auth_token_uuid} not found") + return deleted_auth_token.model_dump() # type: ignore[no-any-return,union-attr] + + +class PrismaFrontendDB(FrontendDBProtocol, PrismaBaseDB): # type: ignore[misc] + ENV_VAR = "DATABASE_URL" + + async def get_user(self, user_uuid: Union[str, UUID]) -> Any: + async with self._get_db_connection() as db: + select_query = 'SELECT * from "User" where uuid=' + f"'{user_uuid}'" # nosec: [B608] + user = await db.query_first( + select_query # nosec: [B608] + ) + if not user: + raise KeyNotFoundError(f"user_uuid {user_uuid} not found") + return user + + async def _create_user( + self, user_uuid: Union[str, UUID], email: str, username: str + ) -> Union[str, UUID]: + """Only to create user in testing.""" + async with self._get_db_connection() as db: + insert_query = ( + 'INSERT INTO "User" (email, username, uuid) VALUES (' # nosec: [B608] + + f"'{email}', '{username}', '{user_uuid}')" + ) + await db.execute_raw(insert_query) + + return user_uuid + + +@asynccontextmanager +async def _lifespan() -> AsyncGenerator[None, None]: + prisma_backend_db = PrismaBackendDB() + prisma_frontend_db = PrismaFrontendDB() + + with ( + DefaultDB.set(backend_db=prisma_backend_db, frontend_db=prisma_frontend_db), + ): + yield + + +@asynccontextmanager +async def fastapi_lifespan(app: "FastAPI") -> AsyncGenerator[None, None]: + async with _lifespan(): + yield + + +@asynccontextmanager +async def faststream_lifespan(context: "ContextRepo") -> AsyncGenerator[None, None]: + async with _lifespan(): + yield diff --git a/fastagency/helpers.py b/fastagency/helpers.py index 0788c434..a22a2e12 100644 --- a/fastagency/helpers.py +++ b/fastagency/helpers.py @@ -13,9 +13,7 @@ ) from .auth_token.auth import create_deployment_auth_token - -# from fastagency.app import add_model -from .db.helpers import find_model_using_raw, get_db_connection, get_user +from .db.base import DefaultDB from .models.base import Model, ObjectReference from .models.registry import Registry @@ -23,7 +21,7 @@ async def get_model_by_uuid(model_uuid: Union[str, UUID]) -> Model: - model_dict = await find_model_using_raw(model_uuid=model_uuid) + model_dict = await DefaultDB.backend().find_model(model_uuid=model_uuid) registry = Registry.get_default() model = registry.validate( @@ -43,13 +41,12 @@ async def validate_tokens_and_create_gh_repo( model: Dict[str, Any], model_uuid: str, ) -> SaasAppGenerator: - async with get_db_connection(): - found_gh_token = await find_model_using_raw( - model_uuid=model["gh_token"]["uuid"] - ) - found_fly_token = await find_model_using_raw( - model_uuid=model["fly_token"]["uuid"] - ) + found_gh_token = await DefaultDB.backend().find_model( + model_uuid=model["gh_token"]["uuid"] + ) + found_fly_token = await DefaultDB.backend().find_model( + model_uuid=model["fly_token"]["uuid"] + ) found_gh_token_uuid = found_gh_token["json_str"]["gh_token"] found_fly_token_uuid = found_fly_token["json_str"]["fly_token"] @@ -81,19 +78,15 @@ async def deploy_saas_app( await asyncify(saas_app.execute)() - async with get_db_connection() as db: - found_model = await find_model_using_raw(model_uuid=model_uuid) - found_model["json_str"]["app_deploy_status"] = "completed" - - await db.model.update( - where={"uuid": found_model["uuid"]}, # type: ignore[arg-type] - data={ # type: ignore[typeddict-unknown-key] - "type_name": type_name, - "model_name": model_name, - "json_str": json.dumps(found_model["json_str"]), # type: ignore[typeddict-item] - "user_uuid": user_uuid, - }, - ) + found_model = await DefaultDB.backend().find_model(model_uuid=model_uuid) + found_model["json_str"]["app_deploy_status"] = "completed" + await DefaultDB.backend().update_model( + model_uuid=found_model["uuid"], + user_uuid=user_uuid, + type_name=type_name, + model_name=model_name, + json_str=json.dumps(found_model["json_str"]), + ) async def add_model_to_user( @@ -125,17 +118,14 @@ async def add_model_to_user( updated_validated_model_dict["gh_repo_url"] = saas_app.gh_repo_url validated_model_json = json.dumps(updated_validated_model_dict) - await get_user(user_uuid=user_uuid) - async with get_db_connection() as db: - await db.model.create( - data={ - "uuid": model_uuid, - "user_uuid": user_uuid, - "type_name": type_name, - "model_name": model_name, - "json_str": validated_model_json, # type: ignore[typeddict-item] - } - ) + await DefaultDB.frontend().get_user(user_uuid=user_uuid) + await DefaultDB.backend().create_model( + model_uuid=model_uuid, + user_uuid=user_uuid, + type_name=type_name, + model_name=model_name, + json_str=validated_model_json, + ) if saas_app is not None: background_tasks.add_task( @@ -204,13 +194,10 @@ async def create_model_ref( async def get_all_models_for_user( user_uuid: Union[str, UUID], type_name: Optional[str] = None, -) -> List[Any]: - filters: Dict[str, Any] = {"user_uuid": user_uuid} - if type_name: - filters["type_name"] = type_name - - async with get_db_connection() as db: - models = await db.model.find_many(where=filters) # type: ignore[arg-type] +) -> List[Dict[str, Any]]: + models = await DefaultDB.backend().find_many_model( + user_uuid=user_uuid, type_name=type_name + ) return models # type: ignore[no-any-return] diff --git a/fastagency/io/app.py b/fastagency/io/app.py index 2a829831..add9a975 100644 --- a/fastagency/io/app.py +++ b/fastagency/io/app.py @@ -4,6 +4,8 @@ from faststream import FastStream from faststream.nats import JStream, NatsBroker +from ..db.prisma import faststream_lifespan + nats_url: Optional[str] = environ.get("NATS_URL", None) # type: ignore[assignment] if nats_url is None: domain: str = environ.get("DOMAIN") # type: ignore[assignment] @@ -15,8 +17,9 @@ print(f"{nats_url=}") # noqa print("Starting IONats faststream app...") # noqa + broker = NatsBroker(nats_url, user=username, password=password) -app = FastStream(broker) +app = FastStream(broker, lifespan=faststream_lifespan) stream = JStream( name="FastAgency", diff --git a/fastagency/io/ionats.py b/fastagency/io/ionats.py index cf680d07..294ae3cf 100644 --- a/fastagency/io/ionats.py +++ b/fastagency/io/ionats.py @@ -13,7 +13,7 @@ from nats.js import api from pydantic import BaseModel -from ..db.helpers import find_model_using_raw +from ..db.base import DefaultDB from ..models.teams.multi_agent_team import MultiAgentTeam from ..models.teams.two_agent_teams import TwoAgentTeam from .app import app, broker, stream # noqa @@ -166,7 +166,7 @@ class InitiateModel(BaseModel): async def create_team( team_id: UUID, user_id: UUID ) -> Callable[[str], List[Dict[str, Any]]]: - team_dict = await find_model_using_raw(team_id) + team_dict = await DefaultDB.backend().find_model(team_id) team_model: Union[TwoAgentTeam, MultiAgentTeam] if "initial_agent" in team_dict["json_str"]: diff --git a/fastagency/models/agents/base.py b/fastagency/models/agents/base.py index 4c1b046c..9fb947ff 100644 --- a/fastagency/models/agents/base.py +++ b/fastagency/models/agents/base.py @@ -6,7 +6,6 @@ from fastagency.openapi.client import Client -from ...db.helpers import find_model_using_raw from ..base import Model from ..registry import Registry from ..toolboxes.toolbox import ToolboxRef @@ -59,9 +58,8 @@ async def get_clients_from_toolboxes(self, user_id: UUID) -> List[Client]: if toolbox_property is None: continue - toolbox_dict = await find_model_using_raw(toolbox_property.uuid) - toolbox_model = toolbox_property.get_data_model()( - **toolbox_dict["json_str"] + toolbox_model = await toolbox_property.get_data_model().from_db( + toolbox_property.uuid ) client = await toolbox_model.create_autogen(toolbox_property.uuid, user_id) clients.append(client) diff --git a/fastagency/models/base.py b/fastagency/models/base.py index 36ae7ece..68d42319 100644 --- a/fastagency/models/base.py +++ b/fastagency/models/base.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, create_model, model_validator from typing_extensions import TypeAlias -from ..db.helpers import find_model_using_raw +from ..db.base import DefaultDB M = TypeVar("M", bound="Model") @@ -39,7 +39,7 @@ async def create_autogen( @classmethod async def from_db(cls: Type[T], model_id: UUID) -> T: - my_model_dict = await find_model_using_raw(model_id) + my_model_dict = await DefaultDB.backend().find_model(model_id) my_model = cls(**my_model_dict["json_str"]) return my_model diff --git a/fastagency/models/llms/openai.py b/fastagency/models/llms/openai.py index e08e59b6..e5b522a5 100644 --- a/fastagency/models/llms/openai.py +++ b/fastagency/models/llms/openai.py @@ -9,23 +9,24 @@ from ..registry import register OpenAIModels: TypeAlias = Literal[ - "gpt-4-turbo-2024-04-09", "gpt-4-1106-preview", - "gpt-4-turbo", "gpt-4-turbo-preview", + "gpt-4o-mini", "gpt-4-0125-preview", - "gpt-4o-2024-05-13", - "gpt-3.5-turbo", - "gpt-3.5-turbo-instruct", - "gpt-3.5-turbo-instruct-0914", "gpt-4o-mini-2024-07-18", - "gpt-4o-mini", + "gpt-3.5-turbo", "gpt-3.5-turbo-16k", + "gpt-4-turbo-2024-04-09", "gpt-3.5-turbo-0125", + "gpt-4-turbo", "gpt-3.5-turbo-1106", - "gpt-4-0613", + "gpt-3.5-turbo-instruct-0914", + "gpt-3.5-turbo-instruct", "gpt-4o", + "gpt-4-0613", + "gpt-4o-2024-05-13", "gpt-4", + "gpt-4o-2024-08-06", ] __all__ = [ diff --git a/fastagency/models/llms/together.py b/fastagency/models/llms/together.py index df1197b0..cf1912a0 100644 --- a/fastagency/models/llms/together.py +++ b/fastagency/models/llms/together.py @@ -17,6 +17,7 @@ "WizardLM v1.2 (13B)": "WizardLM/WizardLM-13B-V1.2", "Code Llama Instruct (34B)": "togethercomputer/CodeLlama-34b-Instruct", "Upstage SOLAR Instruct v1 (11B)": "upstage/SOLAR-10.7B-Instruct-v1.0", + "Meta Llama 3 70B Reference": "meta-llama/Llama-3-70b-chat-hf", "OpenHermes-2-Mistral (7B)": "teknium/OpenHermes-2-Mistral-7B", "LLaMA-2-7B-32K-Instruct (7B)": "togethercomputer/Llama-2-7B-32K-Instruct", "ReMM SLERP L2 (13B)": "Undi95/ReMM-SLERP-L2-13B", @@ -40,24 +41,23 @@ "Llama3 8B Chat HF INT4": "togethercomputer/Llama-3-8b-chat-hf-int4", "OpenHermes-2.5-Mistral (7B)": "teknium/OpenHermes-2p5-Mistral-7B", "Nous Capybara v1.9 (7B)": "NousResearch/Nous-Capybara-7B-V1p9", + "Meta Llama 3.1 70B Instruct Turbo": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", "Nous Hermes 2 - Mistral DPO (7B)": "NousResearch/Nous-Hermes-2-Mistral-7B-DPO", "StripedHyena Nous (7B)": "togethercomputer/StripedHyena-Nous-7B", "Alpaca (7B)": "togethercomputer/alpaca-7b", "Platypus2 Instruct (70B)": "garage-bAInd/Platypus2-70B-instruct", "Gemma Instruct (2B)": "google/gemma-2b-it", "Gemma Instruct (7B)": "google/gemma-7b-it", + "LLaMA-2 Chat (7B)": "togethercomputer/llama-2-7b-chat", "OLMo Instruct (7B)": "allenai/OLMo-7B-Instruct", "Qwen 1.5 Chat (4B)": "Qwen/Qwen1.5-4B-Chat", "MythoMax-L2 (13B)": "Gryphe/MythoMax-L2-13b", - "Meta Llama 3 70B Reference": "meta-llama/Llama-3-70b-chat-hf", "Mistral (7B) Instruct": "mistralai/Mistral-7B-Instruct-v0.1", "Mistral (7B) Instruct v0.2": "mistralai/Mistral-7B-Instruct-v0.2", "Meta Llama 3.1 8B Instruct Turbo": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "OpenOrca Mistral (7B) 8K": "Open-Orca/Mistral-7B-OpenOrca", "Nous Hermes LLaMA-2 (7B)": "NousResearch/Nous-Hermes-llama-2-7b", "Qwen 1.5 Chat (32B)": "Qwen/Qwen1.5-32B-Chat", - "Meta Llama 3.1 405B Instruct Turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", - "Meta Llama 3.1 70B Instruct Turbo": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", "Qwen 2 Instruct (72B)": "Qwen/Qwen2-72B-Instruct", "Qwen 1.5 Chat (72B)": "Qwen/Qwen1.5-72B-Chat", "DeepSeek LLM Chat (67B)": "deepseek-ai/deepseek-llm-67b-chat", @@ -74,27 +74,24 @@ "Gemma-2 Instruct (9B)": "google/gemma-2-9b-it", "Meta Llama 3 8B Reference": "meta-llama/Llama-3-8b-chat-hf", "Mixtral-8x7B Instruct v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1", + "Code Llama Instruct (70B)": "codellama/CodeLlama-70b-Instruct-hf", + "Meta Llama 3.1 405B Instruct Turbo": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", "DBRX Instruct": "databricks/dbrx-instruct", "Meta Llama 3.1 8B Instruct": "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference", "Meta Llama 3 8B Instruct Turbo": "meta-llama/Meta-Llama-3-8B-Instruct-Turbo", "Dolphin 2.5 Mixtral 8x7b": "cognitivecomputations/dolphin-2.5-mixtral-8x7b", "Mixtral-8x22B Instruct v0.1": "mistralai/Mixtral-8x22B-Instruct-v0.1", - "Code Llama Instruct (70B)": "codellama/CodeLlama-70b-Instruct-hf", "Meta Llama 3 8B Instruct Lite": "meta-llama/Meta-Llama-3-8B-Instruct-Lite", - "LLaMA-2 Chat (7B)": "togethercomputer/llama-2-7b-chat", "LLaMA-2 Chat (70B)": "togethercomputer/llama-2-70b-chat", "Koala (7B)": "togethercomputer/Koala-7B", "Qwen 2 Instruct (1.5B)": "Qwen/Qwen2-1.5B-Instruct", "Qwen 2 Instruct (7B)": "Qwen/Qwen2-7B-Instruct", "Guanaco (65B) ": "togethercomputer/guanaco-65b", "Vicuna v1.3 (7B)": "lmsys/vicuna-7b-v1.3", - "Qwen 2 (72B)": "Qwen/Qwen2-72B", "Nous Hermes LLaMA-2 (70B)": "NousResearch/Nous-Hermes-Llama2-70b", "Vicuna v1.5 16K (13B)": "lmsys/vicuna-13b-v1.5-16k", "Zephyr-7B-ß": "HuggingFaceH4/zephyr-7b-beta", "Guanaco (13B) ": "togethercomputer/guanaco-13b", - "Qwen 2 (7B)": "Qwen/Qwen2-7B", - "Qwen 2 (1.5B)": "Qwen/Qwen2-1.5B", "Vicuna v1.3 (13B)": "lmsys/vicuna-13b-v1.3", "Guanaco (33B) ": "togethercomputer/guanaco-33b", "Koala (13B)": "togethercomputer/Koala-13B", @@ -109,6 +106,7 @@ "carson ml318br": "carson/ml318br", "Llama-3 70B Instruct Gradient 1048K": "gradientai/Llama-3-70B-Instruct-Gradient-1048k", "Meta Llama 3.1 70B Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct-Reference", + "Meta Llama 3.1 70B": "meta-llama/Meta-Llama-3.1-70B-Reference", } TogetherModels: TypeAlias = Literal[tuple(together_model_string.keys())] # type: ignore[valid-type] diff --git a/tests/app/test_model_routes.py b/tests/app/test_model_routes.py index 43ab215c..894e10db 100644 --- a/tests/app/test_model_routes.py +++ b/tests/app/test_model_routes.py @@ -1,3 +1,4 @@ +import random import uuid from typing import List, Optional from unittest.mock import AsyncMock, patch @@ -6,6 +7,7 @@ from fastapi.testclient import TestClient from fastagency.app import app, mask +from fastagency.db.base import DefaultDB from fastagency.models.llms.azure import AzureOAIAPIKey from fastagency.saas_app_generator import SaasAppGenerator @@ -104,10 +106,16 @@ async def test_get_all_models( assert actual[i][key] == expected[i][key] @pytest.mark.asyncio() - async def test_setup_user(self, user_uuid: str) -> None: + async def test_setup_user(self) -> None: + random_id = random.randint(1, 1_000_000) + user_uuid = await DefaultDB.frontend()._create_user( + user_uuid=uuid.uuid4(), + email=f"user{random_id}@airt.ai", + username=f"user{random_id}", + ) # Call setup route for user response = client.get(f"/user/{user_uuid}/setup") - assert response.status_code == 200, response + assert response.status_code == 200, response.text expected_setup = { "name": "WeatherToolbox", "openapi_url": "https://weather.tools.staging.fastagency.ai/openapi.json", @@ -122,7 +130,7 @@ async def test_setup_user(self, user_uuid: str) -> None: ) assert response.status_code == 200 expected_toolbox_model = { - "user_uuid": user_uuid, + "user_uuid": str(user_uuid), "type_name": "toolbox", "model_name": "Toolbox", "json_str": { @@ -299,11 +307,17 @@ async def test_background_task_not_called_on_error(self, user_uuid: str) -> None } type_name = "deployment" model_name = "Deployment" - model_uuid = str(uuid.uuid4()) + model_uuid = uuid.uuid4() with ( - patch("fastagency.app.get_user", side_effect=Exception()), - patch("fastagency.db.helpers.get_db_connection", side_effect=Exception()), + patch( + "fastagency.app.DefaultDB._frontend_db.get_user", + side_effect=Exception(), + ), + patch( + "fastagency.db.prisma.PrismaBackendDB._get_db_connection", + side_effect=Exception(), + ), patch("fastagency.helpers.deploy_saas_app") as mock_task, ): response = client.post( diff --git a/tests/app/test_openai_extensively.py b/tests/app/test_openai_extensively.py index 2645c4f8..d41e0416 100644 --- a/tests/app/test_openai_extensively.py +++ b/tests/app/test_openai_extensively.py @@ -144,9 +144,9 @@ def test_validate_incorrect_model(self, model_dict: Dict[str, Any]) -> None: expected = { "type": "literal_error", "loc": ["model"], - "msg": "Input should be 'gpt-4-turbo-2024-04-09', 'gpt-4-1106-preview', 'gpt-4-turbo', 'gpt-4-turbo-preview', 'gpt-4-0125-preview', 'gpt-4o-2024-05-13', 'gpt-3.5-turbo', 'gpt-3.5-turbo-instruct', 'gpt-3.5-turbo-instruct-0914', 'gpt-4o-mini-2024-07-18', 'gpt-4o-mini', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0125', 'gpt-3.5-turbo-1106', 'gpt-4-0613', 'gpt-4o' or 'gpt-4'", + "msg": "Input should be 'gpt-4-1106-preview', 'gpt-4-turbo-preview', 'gpt-4o-mini', 'gpt-4-0125-preview', 'gpt-4o-mini-2024-07-18', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-turbo-2024-04-09', 'gpt-3.5-turbo-0125', 'gpt-4-turbo', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-instruct-0914', 'gpt-3.5-turbo-instruct', 'gpt-4o', 'gpt-4-0613', 'gpt-4o-2024-05-13', 'gpt-4' or 'gpt-4o-2024-08-06'", "ctx": { - "expected": "'gpt-4-turbo-2024-04-09', 'gpt-4-1106-preview', 'gpt-4-turbo', 'gpt-4-turbo-preview', 'gpt-4-0125-preview', 'gpt-4o-2024-05-13', 'gpt-3.5-turbo', 'gpt-3.5-turbo-instruct', 'gpt-3.5-turbo-instruct-0914', 'gpt-4o-mini-2024-07-18', 'gpt-4o-mini', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0125', 'gpt-3.5-turbo-1106', 'gpt-4-0613', 'gpt-4o' or 'gpt-4'" + "expected": "'gpt-4-1106-preview', 'gpt-4-turbo-preview', 'gpt-4o-mini', 'gpt-4-0125-preview', 'gpt-4o-mini-2024-07-18', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-4-turbo-2024-04-09', 'gpt-3.5-turbo-0125', 'gpt-4-turbo', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-instruct-0914', 'gpt-3.5-turbo-instruct', 'gpt-4o', 'gpt-4-0613', 'gpt-4o-2024-05-13', 'gpt-4' or 'gpt-4o-2024-08-06'" }, } # print(f"{msg_dict=}") diff --git a/tests/auth_token/test_auth_token.py b/tests/auth_token/test_auth_token.py index 04581065..9fe3f6f6 100644 --- a/tests/auth_token/test_auth_token.py +++ b/tests/auth_token/test_auth_token.py @@ -1,6 +1,7 @@ import uuid from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, Union +from uuid import UUID import pytest from fastapi import HTTPException @@ -8,6 +9,9 @@ import fastagency.app import fastagency.auth_token.auth +import fastagency.db +import fastagency.db.inmemory +import fastagency.db.prisma from fastagency.app import app from fastagency.auth_token.auth import ( create_deployment_auth_token, @@ -76,16 +80,18 @@ async def test_parse_expiry_with_invalid_expiry(expiry_str: str, expected: str) async def test_create_deployment_token( user_uuid: str, monkeypatch: pytest.MonkeyPatch ) -> None: - deployment_uuid = str(uuid.uuid4()) + deployment_uuid = uuid.uuid4() - async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str]: + async def mock_find_model(*args: Any, **kwargs: Any) -> Dict[str, Union[str, UUID]]: return { "user_uuid": user_uuid, "uuid": deployment_uuid, } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, ) token = await create_deployment_auth_token(user_uuid, deployment_uuid) @@ -98,16 +104,18 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] async def test_create_deployment_token_with_wrong_user_uuid( user_uuid: str, monkeypatch: pytest.MonkeyPatch ) -> None: - deployment_uuid = str(uuid.uuid4()) + deployment_uuid = uuid.uuid4() - async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str]: + async def mock_find_model(*args: Any, **kwargs: Any) -> Dict[str, Union[str, UUID]]: return { "user_uuid": "random_wrong_uuid", "uuid": deployment_uuid, } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, ) with pytest.raises(HTTPException) as e: @@ -122,16 +130,18 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] async def test_create_deployment_auth_token_route( user_uuid: str, monkeypatch: pytest.MonkeyPatch ) -> None: - deployment_uuid = str(uuid.uuid4()) + deployment_uuid = uuid.uuid4() - async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str]: + async def mock_find_model(*args: Any, **kwargs: Any) -> Dict[str, Union[str, UUID]]: return { "user_uuid": user_uuid, "uuid": deployment_uuid, } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, ) response = client.post( @@ -148,16 +158,18 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] async def test_get_all_deployment_auth_tokens( user_uuid: str, monkeypatch: pytest.MonkeyPatch ) -> None: - deployment_uuid = str(uuid.uuid4()) + deployment_uuid = uuid.uuid4() - async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str]: + async def mock_find_model(*args: Any, **kwargs: Any) -> Dict[str, Union[str, UUID]]: return { "user_uuid": user_uuid, "uuid": deployment_uuid, } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, ) response = client.post( @@ -167,7 +179,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] assert response.status_code == 200 monkeypatch.setattr( - fastagency.app, "find_model_using_raw", mock_find_model_using_raw + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, ) response = client.get(f"/user/{user_uuid}/deployment/{deployment_uuid}") assert response.status_code == 200 @@ -183,16 +197,18 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] async def test_delete_deployment_auth_token( user_uuid: str, monkeypatch: pytest.MonkeyPatch ) -> None: - deployment_uuid = str(uuid.uuid4()) + deployment_uuid = uuid.uuid4() - async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str]: + async def mock_find_model(*args: Any, **kwargs: Any) -> Dict[str, Union[str, UUID]]: return { "user_uuid": user_uuid, "uuid": deployment_uuid, } monkeypatch.setattr( - fastagency.auth_token.auth, "find_model_using_raw", mock_find_model_using_raw + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, ) response = client.post( @@ -202,7 +218,9 @@ async def mock_find_model_using_raw(*args: Any, **kwargs: Any) -> Dict[str, str] assert response.status_code == 200 monkeypatch.setattr( - fastagency.app, "find_model_using_raw", mock_find_model_using_raw + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, ) response = client.get(f"/user/{user_uuid}/deployment/{deployment_uuid}") assert len(response.json()) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index c6579f48..fba08a02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from typing import ( Annotated, Any, + AsyncGenerator, AsyncIterator, Callable, Dict, @@ -24,10 +25,8 @@ from fastapi import FastAPI, Path from pydantic import BaseModel -from fastagency.db.helpers import ( - get_db_connection, - get_wasp_db_url, -) +from fastagency.db.base import DefaultDB +from fastagency.db.inmemory import InMemoryBackendDB, InMemoryFrontendDB from fastagency.helpers import create_autogen, create_model_ref, get_model_by_ref from fastagency.models.agents.assistant import AssistantAgent from fastagency.models.agents.user_proxy import UserProxyAgent @@ -46,21 +45,29 @@ F = TypeVar("F", bound=Callable[..., Any]) +@pytest_asyncio.fixture(scope="session", autouse=True) # type: ignore[misc] +async def set_default_db() -> AsyncGenerator[None, None]: + backend_db = InMemoryBackendDB() + frontend_db = InMemoryFrontendDB() + + with ( + DefaultDB.set(backend_db=backend_db, frontend_db=frontend_db), + ): + yield + + @pytest_asyncio.fixture(scope="session") # type: ignore[misc] async def user_uuid() -> AsyncIterator[str]: try: random_id = random.randint(1, 1_000_000) - generated_uuid = str(uuid.uuid4()) - wasp_db_url = await get_wasp_db_url() - async with get_db_connection(db_url=wasp_db_url) as db: - insert_query = ( - 'INSERT INTO "User" (email, username, uuid) VALUES (' - + f"'user{random_id}@airt.ai', 'user{random_id}', '{generated_uuid}')" - ) - await db.execute_raw(insert_query) - - select_query = 'SELECT * FROM "User" WHERE uuid=' + f"'{generated_uuid}'" - user = await db.query_first(select_query) + generated_uuid = uuid.uuid4() + email = f"user{random_id}@airt.ai" + username = f"user{random_id}" + + await DefaultDB.frontend()._create_user( + user_uuid=generated_uuid, email=email, username=username + ) + user = await DefaultDB.frontend().get_user(user_uuid=generated_uuid) yield user["uuid"] finally: diff --git a/tests/db/__init__.py b/tests/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/db/test_inmemory.py b/tests/db/test_inmemory.py new file mode 100644 index 00000000..6c22d194 --- /dev/null +++ b/tests/db/test_inmemory.py @@ -0,0 +1,185 @@ +import random +import uuid +from datetime import datetime, timedelta +from typing import Any, Dict, Union +from uuid import UUID + +import pytest + +import fastagency.db +import fastagency.db.inmemory +from fastagency.db.base import DefaultDB, KeyNotFoundError +from fastagency.db.inmemory import InMemoryBackendDB, InMemoryFrontendDB +from fastagency.models.llms.azure import AzureOAIAPIKey + + +@pytest.mark.asyncio() +class TestInMemoryFrontendDB: + async def test_set(self) -> None: + frontend_db = InMemoryFrontendDB() + backend_db = InMemoryBackendDB() + with DefaultDB.set(backend_db=backend_db, frontend_db=frontend_db): + assert DefaultDB._frontend_db == frontend_db + assert DefaultDB._backend_db == backend_db + + async def test_db(self) -> None: + frontend_db = InMemoryFrontendDB() + backend_db = InMemoryBackendDB() + with DefaultDB.set(backend_db=backend_db, frontend_db=frontend_db): + assert DefaultDB.frontend() == frontend_db + assert DefaultDB.backend() == backend_db + + async def test_create_user_get_user(self) -> None: + frontend_db = InMemoryFrontendDB() + + random_id = random.randint(1, 1_000_000) + generated_uuid = uuid.uuid4() + email = f"user{random_id}@airt.ai" + username = f"user{random_id}" + + user_uuid = await frontend_db._create_user(generated_uuid, email, username) + assert user_uuid == generated_uuid + + user = await frontend_db.get_user(user_uuid) + assert user["uuid"] == str(user_uuid) + assert user["email"] == email + assert user["username"] == username + + async def test_user_exception(self) -> None: + frontend_db = InMemoryFrontendDB() + user_uuid = uuid.uuid4() + with pytest.raises(KeyNotFoundError) as e: + await frontend_db.get_user(user_uuid) + assert f"user_uuid {user_uuid} not found" == str(e.value) + + +@pytest.mark.db() +@pytest.mark.asyncio() +class TestInMemoryBackendDB: + async def test_model_CRUD(self) -> None: # noqa: N802 + # Setup + frontend_db = InMemoryFrontendDB() + backend_db = InMemoryBackendDB() + random_id = random.randint(1, 1_000_000) + user_uuid = await frontend_db._create_user( + uuid.uuid4(), f"user{random_id}@airt.ai", f"user{random_id}" + ) + model_uuid = uuid.uuid4() + azure_oai_api_key = AzureOAIAPIKey(api_key="whatever", name="who cares?") + + # Tests + model = await backend_db.create_model( + user_uuid=user_uuid, + model_uuid=model_uuid, + type_name="secret", + model_name="AzureOAIAPIKey", + json_str=azure_oai_api_key.model_dump_json(), + ) + assert model["uuid"] == str(model_uuid) + assert model["user_uuid"] == str(user_uuid) + assert model["type_name"] == "secret" + assert model["model_name"] == "AzureOAIAPIKey" + assert model["json_str"] == azure_oai_api_key.model_dump() + + found_model = await backend_db.find_model(model_uuid) + assert found_model["uuid"] == str(model_uuid) + + many_model = await backend_db.find_many_model(user_uuid) + assert len(many_model) == 1 + assert many_model[0]["uuid"] == str(model_uuid) + + updated_model = await backend_db.update_model( + model_uuid=model_uuid, + user_uuid=user_uuid, + type_name="secret", + model_name="AzureOAIAPIKey2", + json_str=azure_oai_api_key.model_dump_json(), + ) + assert updated_model["uuid"] == str(model_uuid) + assert updated_model["model_name"] == "AzureOAIAPIKey2" + + deleted_model = await backend_db.delete_model(model_uuid) + assert deleted_model["uuid"] == str(model_uuid) + + async def test_auth_token_CRUD(self, monkeypatch: pytest.MonkeyPatch) -> None: # noqa: N802 + # Setup + frontend_db = InMemoryFrontendDB() + backend_db = InMemoryBackendDB() + random_id = random.randint(1, 1_000_000) + user_uuid = await frontend_db._create_user( + uuid.uuid4(), f"user{random_id}@airt.ai", f"user{random_id}" + ) + deployment_uuid = uuid.uuid4() + auth_token_uuid = uuid.uuid4() + + async def mock_find_model( + *args: Any, **kwargs: Any + ) -> Dict[str, Union[str, UUID]]: + return { + "user_uuid": user_uuid, + "uuid": deployment_uuid, + } + + monkeypatch.setattr( + fastagency.db.inmemory.InMemoryBackendDB, + "find_model", + mock_find_model, + ) + + # Tests + auth_token = await backend_db.create_auth_token( + auth_token_uuid=auth_token_uuid, + name="Test token", + user_uuid=user_uuid, + deployment_uuid=deployment_uuid, + hashed_auth_token="whatever", + expiry="99d", + expires_at=datetime.utcnow() + timedelta(days=99), + ) + assert auth_token["uuid"] == str(auth_token_uuid) + assert auth_token["name"] == "Test token" + + many_auth_token = await backend_db.find_many_auth_token( + user_uuid, deployment_uuid + ) + assert len(many_auth_token) == 1 + assert many_auth_token[0]["uuid"] == str(auth_token_uuid) + + deleted_auth_token = await backend_db.delete_auth_token( + auth_token_uuid, deployment_uuid, user_uuid + ) + assert deleted_auth_token["uuid"] == str(auth_token_uuid) + + async def test_model_exception(self) -> None: + backend_db = InMemoryBackendDB() + model_uuid = uuid.uuid4() + user_uuid = uuid.uuid4() + with pytest.raises(KeyNotFoundError) as e: + await backend_db.find_model(model_uuid) + assert f"model_uuid {model_uuid} not found" == str(e.value) + + with pytest.raises(KeyNotFoundError) as e: + await backend_db.update_model( + model_uuid=model_uuid, + user_uuid=user_uuid, + type_name="secret", + model_name="AzureOAIAPIKey2", + json_str="[]", + ) + assert f"model_uuid {model_uuid} not found" == str(e.value) + + with pytest.raises(KeyNotFoundError) as e: + await backend_db.delete_model(model_uuid) + assert f"model_uuid {model_uuid} not found" == str(e.value) + + async def test_auth_token_exception(self) -> None: + backend_db = InMemoryBackendDB() + auth_token_uuid = uuid.uuid4() + deployment_uuid = uuid.uuid4() + user_uuid = uuid.uuid4() + + with pytest.raises(KeyNotFoundError) as e: + await backend_db.delete_auth_token( + auth_token_uuid, deployment_uuid, user_uuid + ) + assert f"auth_token_uuid {auth_token_uuid} not found" == str(e.value) diff --git a/tests/db/test_prisma.py b/tests/db/test_prisma.py new file mode 100644 index 00000000..9b52fb30 --- /dev/null +++ b/tests/db/test_prisma.py @@ -0,0 +1,186 @@ +import random +import uuid +from datetime import datetime, timedelta +from typing import Any, Dict, Union +from uuid import UUID + +import pytest + +import fastagency.db +import fastagency.db.prisma +from fastagency.db.base import DefaultDB, KeyNotFoundError +from fastagency.db.prisma import PrismaBackendDB, PrismaFrontendDB +from fastagency.models.llms.azure import AzureOAIAPIKey + + +@pytest.mark.db() +@pytest.mark.asyncio() +class TestPrismaFrontendDB: + async def test_set(self) -> None: + frontend_db = PrismaFrontendDB() + backend_db = PrismaBackendDB() + with DefaultDB.set(backend_db=backend_db, frontend_db=frontend_db): + assert DefaultDB._frontend_db == frontend_db + assert DefaultDB._backend_db == backend_db + + async def test_db(self) -> None: + frontend_db = PrismaFrontendDB() + backend_db = PrismaBackendDB() + with DefaultDB.set(backend_db=backend_db, frontend_db=frontend_db): + assert DefaultDB.frontend() == frontend_db + assert DefaultDB.backend() == backend_db + + async def test_create_user_get_user(self) -> None: + frontend_db = PrismaFrontendDB() + + random_id = random.randint(1, 1_000_000) + generated_uuid = uuid.uuid4() + email = f"user{random_id}@airt.ai" + username = f"user{random_id}" + + user_uuid = await frontend_db._create_user(generated_uuid, email, username) + assert user_uuid == generated_uuid + + user = await frontend_db.get_user(user_uuid) + assert user["uuid"] == str(user_uuid) + assert user["email"] == email + assert user["username"] == username + + async def test_user_exception(self) -> None: + frontend_db = PrismaFrontendDB() + user_uuid = uuid.uuid4() + with pytest.raises(KeyNotFoundError) as e: + await frontend_db.get_user(user_uuid) + assert f"user_uuid {user_uuid} not found" == str(e.value) + + +@pytest.mark.db() +@pytest.mark.asyncio() +class TestPrismaBackendDB: + async def test_model_CRUD(self) -> None: # noqa: N802 + # Setup + frontend_db = PrismaFrontendDB() + backend_db = PrismaBackendDB() + random_id = random.randint(1, 1_000_000) + user_uuid = await frontend_db._create_user( + uuid.uuid4(), f"user{random_id}@airt.ai", f"user{random_id}" + ) + model_uuid = uuid.uuid4() + azure_oai_api_key = AzureOAIAPIKey(api_key="whatever", name="who cares?") + + # Tests + model = await backend_db.create_model( + user_uuid=user_uuid, + model_uuid=model_uuid, + type_name="secret", + model_name="AzureOAIAPIKey", + json_str=azure_oai_api_key.model_dump_json(), + ) + assert model["uuid"] == str(model_uuid) + assert model["user_uuid"] == str(user_uuid) + assert model["type_name"] == "secret" + assert model["model_name"] == "AzureOAIAPIKey" + assert model["json_str"] == azure_oai_api_key.model_dump() + + found_model = await backend_db.find_model(model_uuid) + assert found_model["uuid"] == str(model_uuid) + + many_model = await backend_db.find_many_model(user_uuid) + assert len(many_model) == 1 + assert many_model[0]["uuid"] == str(model_uuid) + + updated_model = await backend_db.update_model( + model_uuid=model_uuid, + user_uuid=user_uuid, + type_name="secret", + model_name="AzureOAIAPIKey2", + json_str=azure_oai_api_key.model_dump_json(), + ) + assert updated_model["uuid"] == str(model_uuid) + assert updated_model["model_name"] == "AzureOAIAPIKey2" + + deleted_model = await backend_db.delete_model(model_uuid) + assert deleted_model["uuid"] == str(model_uuid) + + async def test_auth_token_CRUD(self, monkeypatch: pytest.MonkeyPatch) -> None: # noqa: N802 + # Setup + frontend_db = PrismaFrontendDB() + backend_db = PrismaBackendDB() + random_id = random.randint(1, 1_000_000) + user_uuid = await frontend_db._create_user( + uuid.uuid4(), f"user{random_id}@airt.ai", f"user{random_id}" + ) + deployment_uuid = uuid.uuid4() + auth_token_uuid = uuid.uuid4() + + async def mock_find_model( + *args: Any, **kwargs: Any + ) -> Dict[str, Union[str, UUID]]: + return { + "user_uuid": user_uuid, + "uuid": deployment_uuid, + } + + monkeypatch.setattr( + fastagency.db.prisma.PrismaBackendDB, + "find_model", + mock_find_model, + ) + + # Tests + auth_token = await backend_db.create_auth_token( + auth_token_uuid=auth_token_uuid, + name="Test token", + user_uuid=user_uuid, + deployment_uuid=deployment_uuid, + hashed_auth_token="whatever", + expiry="99d", + expires_at=datetime.utcnow() + timedelta(days=99), + ) + assert auth_token["uuid"] == str(auth_token_uuid) + assert auth_token["name"] == "Test token" + + many_auth_token = await backend_db.find_many_auth_token( + user_uuid, deployment_uuid + ) + assert len(many_auth_token) == 1 + assert many_auth_token[0]["uuid"] == str(auth_token_uuid) + + deleted_auth_token = await backend_db.delete_auth_token( + auth_token_uuid, deployment_uuid, user_uuid + ) + assert deleted_auth_token["uuid"] == str(auth_token_uuid) + + async def test_model_exception(self) -> None: + backend_db = PrismaBackendDB() + model_uuid = uuid.uuid4() + user_uuid = uuid.uuid4() + with pytest.raises(KeyNotFoundError) as e: + await backend_db.find_model(model_uuid) + assert f"model_uuid {model_uuid} not found" == str(e.value) + + with pytest.raises(KeyNotFoundError) as e: + await backend_db.update_model( + model_uuid=model_uuid, + user_uuid=user_uuid, + type_name="secret", + model_name="AzureOAIAPIKey2", + json_str="[]", + ) + assert f"model_uuid {model_uuid} not found" == str(e.value) + + with pytest.raises(KeyNotFoundError) as e: + await backend_db.delete_model(model_uuid) + assert f"model_uuid {model_uuid} not found" == str(e.value) + + async def test_auth_token_exception(self) -> None: + backend_db = PrismaBackendDB() + auth_token_uuid = uuid.uuid4() + deployment_uuid = uuid.uuid4() + user_uuid = uuid.uuid4() + + with pytest.raises(KeyNotFoundError) as e: + await backend_db.delete_auth_token( + auth_token_uuid, deployment_uuid, user_uuid + ) + assert f"auth_token_uuid {auth_token_uuid} not found" == str(e.value) diff --git a/tests/models/llms/test_openai.py b/tests/models/llms/test_openai.py index 7dc6fa3e..e97ac322 100644 --- a/tests/models/llms/test_openai.py +++ b/tests/models/llms/test_openai.py @@ -118,23 +118,24 @@ def test_openai_schema(self) -> None: "default": "gpt-3.5-turbo", "description": "The model to use for the OpenAI API, e.g. 'gpt-3.5-turbo'", "enum": [ - "gpt-4-turbo-2024-04-09", "gpt-4-1106-preview", - "gpt-4-turbo", "gpt-4-turbo-preview", + "gpt-4o-mini", "gpt-4-0125-preview", - "gpt-4o-2024-05-13", - "gpt-3.5-turbo", - "gpt-3.5-turbo-instruct", - "gpt-3.5-turbo-instruct-0914", "gpt-4o-mini-2024-07-18", - "gpt-4o-mini", + "gpt-3.5-turbo", "gpt-3.5-turbo-16k", + "gpt-4-turbo-2024-04-09", "gpt-3.5-turbo-0125", + "gpt-4-turbo", "gpt-3.5-turbo-1106", - "gpt-4-0613", + "gpt-3.5-turbo-instruct-0914", + "gpt-3.5-turbo-instruct", "gpt-4o", + "gpt-4-0613", + "gpt-4o-2024-05-13", "gpt-4", + "gpt-4o-2024-08-06", ], "title": "Model", "type": "string",