Skip to content

Commit

Permalink
Add prisma protocol class (#6)
Browse files Browse the repository at this point in the history
* Add prisma protocol class

* Move protocol files to db directory

* Update docs

* Remove unnecessary commented out lines

* Split protocol

* Remove unnecessary code and rearrange

* Update docs

* Add methods for getting prisma actions objects

* Use protocol

* Optimize code

* Update docs

* Update BaseFrontendProtocol

* Rename prisma backend db class

* Delete unnecessary class

* Remove unnecessary methods

* Use default find model

* Remove get_authtoken_connection method

* Remove get_model_connection method

* Fix mypy issues

* Fix mypy issue

* Rename protocol class

* Add an empty line between methods in protocol

* Use sync get default and set default methods

* Refactor code

* Add a internal method to create user

* Add tests for prisma backend classes

* Add inmemory protocol implementation

* Update prisma test

* Update docs

* Use inmemory db in tests

* wip

* wip

* Use class defaultdb

* Remove unnecessary type adapter

* Move lifespan

* Reuse from_db in agents/base.py

* Use Union[str, UUID]

* Raise common exceptions and handle it in fastapi app

* Update docs

* Use exception args

* Update together model string

* Update openai models list

* Update openai schema in test

* Update openai tests

---------

Co-authored-by: Davor Runje <davor@airt.ai>
  • Loading branch information
kumaranvpl and davorrunje committed Aug 13, 2024
1 parent af8010d commit 7422ac5
Show file tree
Hide file tree
Showing 35 changed files with 1,184 additions and 256 deletions.
21 changes: 16 additions & 5 deletions docs/docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ search:
- [get_all_models](api/fastagency/app/get_all_models.md)
- [get_azure_llm_client](api/fastagency/app/get_azure_llm_client.md)
- [get_models_schemas](api/fastagency/app/get_models_schemas.md)
- [handle_keynotfounderror_middleware](api/fastagency/app/handle_keynotfounderror_middleware.md)
- [mask](api/fastagency/app/mask.md)
- [models_delete](api/fastagency/app/models_delete.md)
- [setup_user](api/fastagency/app/setup_user.md)
Expand All @@ -36,11 +37,21 @@ search:
- [parse_expiry](api/fastagency/auth_token/auth/parse_expiry.md)
- [verify_auth_token](api/fastagency/auth_token/auth/verify_auth_token.md)
- db
- helpers
- [find_model_using_raw](api/fastagency/db/helpers/find_model_using_raw.md)
- [get_db_connection](api/fastagency/db/helpers/get_db_connection.md)
- [get_user](api/fastagency/db/helpers/get_user.md)
- [get_wasp_db_url](api/fastagency/db/helpers/get_wasp_db_url.md)
- base
- [BackendDBProtocol](api/fastagency/db/base/BackendDBProtocol.md)
- [DefaultDB](api/fastagency/db/base/DefaultDB.md)
- [FrontendDBProtocol](api/fastagency/db/base/FrontendDBProtocol.md)
- [KeyExistsError](api/fastagency/db/base/KeyExistsError.md)
- [KeyNotFoundError](api/fastagency/db/base/KeyNotFoundError.md)
- inmemory
- [InMemoryBackendDB](api/fastagency/db/inmemory/InMemoryBackendDB.md)
- [InMemoryFrontendDB](api/fastagency/db/inmemory/InMemoryFrontendDB.md)
- prisma
- [PrismaBackendDB](api/fastagency/db/prisma/PrismaBackendDB.md)
- [PrismaBaseDB](api/fastagency/db/prisma/PrismaBaseDB.md)
- [PrismaFrontendDB](api/fastagency/db/prisma/PrismaFrontendDB.md)
- [fastapi_lifespan](api/fastagency/db/prisma/fastapi_lifespan.md)
- [faststream_lifespan](api/fastagency/db/prisma/faststream_lifespan.md)
- faststream_app
- [ping_handler](api/fastagency/faststream_app/ping_handler.md)
- helpers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.app.handle_keynotfounderror_middleware
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ search:
boost: 0.5
---

::: fastagency.db.helpers.get_wasp_db_url
::: fastagency.db.base.BackendDBProtocol
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ search:
boost: 0.5
---

::: fastagency.db.helpers.get_user
::: fastagency.db.base.DefaultDB
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ search:
boost: 0.5
---

::: fastagency.db.helpers.get_db_connection
::: fastagency.db.base.FrontendDBProtocol
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ search:
boost: 0.5
---

::: fastagency.db.helpers.find_model_using_raw
::: fastagency.db.base.KeyExistsError
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/base/KeyNotFoundError.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.base.KeyNotFoundError
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/inmemory/InMemoryBackendDB.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.inmemory.InMemoryBackendDB
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/inmemory/InMemoryFrontendDB.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.inmemory.InMemoryFrontendDB
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/prisma/PrismaBackendDB.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.prisma.PrismaBackendDB
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/prisma/PrismaBaseDB.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.prisma.PrismaBaseDB
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/prisma/PrismaFrontendDB.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.prisma.PrismaFrontendDB
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/prisma/fastapi_lifespan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.prisma.fastapi_lifespan
11 changes: 11 additions & 0 deletions docs/docs/en/api/fastagency/db/prisma/faststream_lifespan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: fastagency.db.prisma.faststream_lifespan
114 changes: 63 additions & 51 deletions fastagency/app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
import json
import logging
from os import environ
from typing import Annotated, Any, Dict, List, Optional, Tuple, Union
from typing import (
Annotated,
Any,
Callable,
Coroutine,
Dict,
List,
Optional,
Tuple,
Union,
)
from uuid import UUID

import httpx
import yaml
from fastapi import BackgroundTasks, Body, FastAPI, HTTPException, Path
from fastapi.requests import Request
from fastapi.responses import JSONResponse, Response
from openai import AsyncAzureOpenAI
from prisma.models import Model
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic import BaseModel, ValidationError

from .auth_token.auth import DeploymentAuthToken, create_deployment_auth_token
from .db.helpers import find_model_using_raw, get_db_connection, get_user
from .db.base import DefaultDB, KeyNotFoundError
from .db.prisma import fastapi_lifespan
from .helpers import (
add_model_to_user,
create_model,
Expand All @@ -23,7 +35,18 @@

logging.basicConfig(level=logging.INFO)

app = FastAPI()

app = FastAPI(lifespan=fastapi_lifespan)


@app.middleware("http")
async def handle_keynotfounderror_middleware(
request: Request, call_next: Callable[[Request], Coroutine[Any, Any, Response]]
) -> Response:
try:
return await call_next(request)
except KeyNotFoundError as e:
return JSONResponse(status_code=404, content={"detail": e.args[0]})


@app.get("/models/schemas")
Expand Down Expand Up @@ -78,7 +101,7 @@ async def validate_secret_model(
) -> Dict[str, Any]:
type: str = "secret"

found_model = await find_model_using_raw(model_uuid=model_uuid)
found_model = await DefaultDB.backend().find_model(model_uuid=model_uuid)
if "api_key" in found_model["json_str"]:
model["api_key"] = found_model["json_str"]["api_key"]
try:
Expand All @@ -100,10 +123,9 @@ async def get_all_models(
user_uuid: str,
type_name: Optional[str] = None,
) -> List[Any]:
models = await get_all_models_for_user(user_uuid=user_uuid, type_name=type_name)

ta = TypeAdapter(List[Model])
ret_val_without_mask = ta.dump_python(models, serialize_as_any=True) # type: ignore[call-arg]
ret_val_without_mask = await get_all_models_for_user(
user_uuid=user_uuid, type_name=type_name
)

ret_val = []
for model in ret_val_without_mask:
Expand Down Expand Up @@ -136,7 +158,7 @@ async def add_model(


async def create_toolbox_for_new_user(user_uuid: Union[str, UUID]) -> Dict[str, Any]:
await get_user(user_uuid=user_uuid) # type: ignore[arg-type]
await DefaultDB.frontend().get_user(user_uuid=user_uuid) # type: ignore[arg-type]

domain = environ.get("DOMAIN", "localhost")
toolbox_openapi_url = (
Expand Down Expand Up @@ -181,18 +203,14 @@ async def update_model(
registry = Registry.get_default()
validated_model = registry.validate(type_name, model_name, model)

async with get_db_connection() as db:
found_model = await find_model_using_raw(model_uuid=model_uuid)

await db.model.update(
where={"uuid": found_model["uuid"]}, # type: ignore[arg-type]
data={ # type: ignore[typeddict-unknown-key]
"type_name": type_name,
"model_name": model_name,
"json_str": validated_model.model_dump_json(), # type: ignore[typeddict-item]
"user_uuid": user_uuid,
},
)
found_model = await DefaultDB.backend().find_model(model_uuid=model_uuid)
await DefaultDB.backend().update_model(
model_uuid=found_model["uuid"],
user_uuid=user_uuid,
type_name=type_name,
model_name=model_name,
json_str=validated_model.model_dump_json(),
)

return validated_model.model_dump()

Expand All @@ -201,13 +219,9 @@ async def update_model(
async def models_delete(
user_uuid: str, type_name: str, model_uuid: str
) -> Dict[str, Any]:
async with get_db_connection() as db:
found_model = await find_model_using_raw(model_uuid=model_uuid)
model = await db.model.delete(
where={"uuid": found_model["uuid"]} # type: ignore[arg-type]
)

return model.json_str # type: ignore
found_model = await DefaultDB.backend().find_model(model_uuid=model_uuid)
model = await DefaultDB.backend().delete_model(model_uuid=found_model["uuid"])
return model["json_str"] # type: ignore


def get_azure_llm_client() -> Tuple[AsyncAzureOpenAI, str]:
Expand Down Expand Up @@ -340,7 +354,7 @@ async def chat(request: ChatRequest) -> Dict[str, Any]:

@app.post("/deployment/{deployment_uuid}/chat")
async def deployment_chat(deployment_uuid: str) -> Dict[str, Any]:
found_model = await find_model_using_raw(model_uuid=deployment_uuid)
found_model = await DefaultDB.backend().find_model(model_uuid=deployment_uuid)
team_name = found_model["json_str"]["name"]
team_uuid = found_model["json_str"]["team"]["uuid"]

Expand Down Expand Up @@ -381,21 +395,22 @@ class DeploymentAuthTokenInfo(BaseModel):
async def get_all_deployment_auth_tokens(
user_uuid: str, deployment_uuid: str
) -> List[DeploymentAuthTokenInfo]:
user = await get_user(user_uuid=user_uuid)
deployment = await find_model_using_raw(model_uuid=deployment_uuid)
user = await DefaultDB.frontend().get_user(user_uuid=user_uuid)
deployment = await DefaultDB.backend().find_model(model_uuid=deployment_uuid)

if user["uuid"] != deployment["user_uuid"]:
raise HTTPException( # pragma: no cover
status_code=403, detail="User does not have access to this deployment"
)

async with get_db_connection() as db:
auth_tokens = await db.authtoken.find_many(
where={"deployment_uuid": deployment_uuid, "user_uuid": user_uuid},
)
auth_tokens = await DefaultDB.backend().find_many_auth_token(
user_uuid=user_uuid, deployment_uuid=deployment_uuid
)
return [
DeploymentAuthTokenInfo(
uuid=auth_token.uuid, name=auth_token.name, expiry=auth_token.expiry
uuid=auth_token["uuid"],
name=auth_token["name"],
expiry=auth_token["expiry"],
)
for auth_token in auth_tokens
]
Expand All @@ -407,24 +422,21 @@ async def delete_deployment_auth_token(
deployment_uuid: str,
auth_token_uuid: str,
) -> DeploymentAuthTokenInfo:
user = await get_user(user_uuid=user_uuid)
deployment = await find_model_using_raw(model_uuid=deployment_uuid)
user = await DefaultDB.frontend().get_user(user_uuid=user_uuid)
deployment = await DefaultDB.backend().find_model(model_uuid=deployment_uuid)

if user["uuid"] != deployment["user_uuid"]:
raise HTTPException( # pragma: no cover
status_code=403, detail="User does not have access to this deployment"
)

async with get_db_connection() as db:
auth_token = await db.authtoken.delete(
where={ # type: ignore[typeddict-unknown-key]
"uuid": auth_token_uuid,
"deployment_uuid": deployment_uuid,
"user_uuid": user_uuid,
},
)
auth_token = await DefaultDB.backend().delete_auth_token(
auth_token_uuid=auth_token_uuid,
deployment_uuid=deployment_uuid,
user_uuid=user_uuid,
)
return DeploymentAuthTokenInfo(
uuid=auth_token.uuid, # type: ignore[union-attr]
name=auth_token.name, # type: ignore[union-attr]
expiry=auth_token.expiry, # type: ignore[union-attr]
uuid=auth_token["uuid"], # type: ignore[union-attr]
name=auth_token["name"], # type: ignore[union-attr]
expiry=auth_token["expiry"], # type: ignore[union-attr]
)
Loading

0 comments on commit 7422ac5

Please sign in to comment.