Skip to content

Commit

Permalink
Add prisma protocol class
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaranvpl committed Jul 24, 2024
1 parent 8595211 commit 4dc0193
Show file tree
Hide file tree
Showing 18 changed files with 243 additions and 144 deletions.
11 changes: 5 additions & 6 deletions docs/docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ search:
- [hash_auth_token](api/fastagency/auth_token/auth/hash_auth_token.md)
- [parse_expiry](api/fastagency/auth_token/auth/parse_expiry.md)
- [verify_auth_token](api/fastagency/auth_token/auth/verify_auth_token.md)
- db
- helpers
- [find_model_using_raw](api/fastagency/db/helpers/find_model_using_raw.md)
- [get_db_connection](api/fastagency/db/helpers/get_db_connection.md)
- [get_user](api/fastagency/db/helpers/get_user.md)
- [get_wasp_db_url](api/fastagency/db/helpers/get_wasp_db_url.md)
- faststream_app
- [ping_handler](api/fastagency/faststream_app/ping_handler.md)
- helpers
Expand Down Expand Up @@ -135,6 +129,11 @@ search:
- fastapi_code_generator_helpers
- [ArgumentWithDescription](api/fastagency/openapi/fastapi_code_generator_helpers/ArgumentWithDescription.md)
- [patch_get_parameter_type](api/fastagency/openapi/fastapi_code_generator_helpers/patch_get_parameter_type.md)
- protocols
- base
- [BaseProtocol](api/fastagency/protocols/base/BaseProtocol.md)
- prisma
- [PrismaProtocol](api/fastagency/protocols/prisma/PrismaProtocol.md)
- saas_app_generator
- [InvalidFlyTokenError](api/fastagency/saas_app_generator/InvalidFlyTokenError.md)
- [InvalidGHTokenError](api/fastagency/saas_app_generator/InvalidGHTokenError.md)
Expand Down
11 changes: 0 additions & 11 deletions docs/docs/en/api/fastagency/db/helpers/get_db_connection.md

This file was deleted.

11 changes: 0 additions & 11 deletions docs/docs/en/api/fastagency/db/helpers/get_wasp_db_url.md

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ search:
boost: 0.5
---

::: fastagency.db.helpers.get_user
::: fastagency.protocols.base.BaseProtocol
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ search:
boost: 0.5
---

::: fastagency.db.helpers.find_model_using_raw
::: fastagency.protocols.prisma.PrismaProtocol
32 changes: 18 additions & 14 deletions fastagency/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pydantic import BaseModel, TypeAdapter, ValidationError

from .auth_token.auth import DeploymentAuthToken, create_deployment_auth_token
from .db.helpers import find_model_using_raw, get_db_connection, get_user
from .helpers import (
add_model_to_user,
create_model,
Expand All @@ -21,6 +20,9 @@
from .models.registry import Registry, Schemas
from .models.toolboxes.toolbox import Toolbox

# from .db.helpers import find_model_using_raw, get_db_connection, get_user
from .protocols.prisma import PrismaProtocol

logging.basicConfig(level=logging.INFO)

app = FastAPI()
Expand Down Expand Up @@ -78,7 +80,7 @@ async def validate_secret_model(
) -> Dict[str, Any]:
type: str = "secret"

found_model = await find_model_using_raw(model_uuid=model_uuid)
found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid)
if "api_key" in found_model["json_str"]:
model["api_key"] = found_model["json_str"]["api_key"]
try:
Expand Down Expand Up @@ -136,7 +138,7 @@ async def add_model(


async def create_toolbox_for_new_user(user_uuid: Union[str, UUID]) -> Dict[str, Any]:
await get_user(user_uuid=user_uuid) # type: ignore[arg-type]
await PrismaProtocol().get_user(user_uuid=user_uuid) # type: ignore[arg-type]

domain = environ.get("DOMAIN", "localhost")
toolbox_openapi_url = (
Expand Down Expand Up @@ -181,8 +183,8 @@ async def update_model(
registry = Registry.get_default()
validated_model = registry.validate(type_name, model_name, model)

async with get_db_connection() as db:
found_model = await find_model_using_raw(model_uuid=model_uuid)
async with PrismaProtocol().get_db_connection() as db:
found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid)

await db.model.update(
where={"uuid": found_model["uuid"]}, # type: ignore[arg-type]
Expand All @@ -201,8 +203,8 @@ async def update_model(
async def models_delete(
user_uuid: str, type_name: str, model_uuid: str
) -> Dict[str, Any]:
async with get_db_connection() as db:
found_model = await find_model_using_raw(model_uuid=model_uuid)
async with PrismaProtocol().get_db_connection() as db:
found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid)
model = await db.model.delete(
where={"uuid": found_model["uuid"]} # type: ignore[arg-type]
)
Expand Down Expand Up @@ -340,7 +342,9 @@ async def chat(request: ChatRequest) -> Dict[str, Any]:

@app.post("/deployment/{deployment_uuid}/chat")
async def deployment_chat(deployment_uuid: str) -> Dict[str, Any]:
found_model = await find_model_using_raw(model_uuid=deployment_uuid)
found_model = await PrismaProtocol().find_model_using_raw(
model_uuid=deployment_uuid
)
team_name = found_model["json_str"]["name"]
team_uuid = found_model["json_str"]["team"]["uuid"]

Expand Down Expand Up @@ -381,15 +385,15 @@ class DeploymentAuthTokenInfo(BaseModel):
async def get_all_deployment_auth_tokens(
user_uuid: str, deployment_uuid: str
) -> List[DeploymentAuthTokenInfo]:
user = await get_user(user_uuid=user_uuid)
deployment = await find_model_using_raw(model_uuid=deployment_uuid)
user = await PrismaProtocol().get_user(user_uuid=user_uuid)
deployment = await PrismaProtocol().find_model_using_raw(model_uuid=deployment_uuid)

if user["uuid"] != deployment["user_uuid"]:
raise HTTPException( # pragma: no cover
status_code=403, detail="User does not have access to this deployment"
)

async with get_db_connection() as db:
async with PrismaProtocol().get_db_connection() as db:
auth_tokens = await db.authtoken.find_many(
where={"deployment_uuid": deployment_uuid, "user_uuid": user_uuid},
)
Expand All @@ -407,15 +411,15 @@ async def delete_deployment_auth_token(
deployment_uuid: str,
auth_token_uuid: str,
) -> DeploymentAuthTokenInfo:
user = await get_user(user_uuid=user_uuid)
deployment = await find_model_using_raw(model_uuid=deployment_uuid)
user = await PrismaProtocol().get_user(user_uuid=user_uuid)
deployment = await PrismaProtocol().find_model_using_raw(model_uuid=deployment_uuid)

if user["uuid"] != deployment["user_uuid"]:
raise HTTPException( # pragma: no cover
status_code=403, detail="User does not have access to this deployment"
)

async with get_db_connection() as db:
async with PrismaProtocol().get_db_connection() as db:
auth_token = await db.authtoken.delete(
where={ # type: ignore[typeddict-unknown-key]
"uuid": auth_token_uuid,
Expand Down
9 changes: 5 additions & 4 deletions fastagency/auth_token/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from fastapi import HTTPException
from pydantic import BaseModel

from fastagency.db.helpers import find_model_using_raw, get_db_connection, get_user
# from fastagency.db.helpers import find_model_using_raw, get_db_connection, get_user
from ..protocols.prisma import PrismaProtocol


def generate_auth_token(length: int = 32) -> str:
Expand Down Expand Up @@ -80,8 +81,8 @@ async def create_deployment_auth_token(
name: str = "Default deployment token",
expiry: str = "99999d",
) -> DeploymentAuthToken:
user = await get_user(user_uuid=user_uuid)
deployment = await find_model_using_raw(model_uuid=deployment_uuid)
user = await PrismaProtocol().get_user(user_uuid=user_uuid)
deployment = await PrismaProtocol().find_model_using_raw(model_uuid=deployment_uuid)

if user["uuid"] != deployment["user_uuid"]:
raise HTTPException(
Expand All @@ -92,7 +93,7 @@ async def create_deployment_auth_token(
auth_token = generate_auth_token()
hashed_token = hash_auth_token(auth_token)

async with get_db_connection() as db:
async with PrismaProtocol().get_db_connection() as db:
await db.authtoken.create( # type: ignore[attr-defined]
data={
"uuid": str(uuid.uuid4()),
Expand Down
126 changes: 63 additions & 63 deletions fastagency/db/helpers.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,63 @@
from contextlib import asynccontextmanager
from os import environ
from typing import Any, AsyncGenerator, Dict, Optional, Union
from uuid import UUID

from fastapi import HTTPException
from prisma import Prisma # type: ignore[attr-defined]


@asynccontextmanager
async def get_db_connection(
db_url: Optional[str] = None,
) -> AsyncGenerator[Prisma, None]:
if not db_url:
db_url = environ.get("PY_DATABASE_URL", None)
if not db_url:
raise ValueError(
"No database URL provided nor set as environment variable 'PY_DATABASE_URL'"
) # pragma: no cover
if "connect_timeout" not in db_url:
db_url += "?connect_timeout=60"
db = Prisma(datasource={"url": db_url})
await db.connect()
try:
yield db
finally:
await db.disconnect()


async def get_wasp_db_url() -> str:
wasp_db_url: str = environ.get("DATABASE_URL") # type: ignore[assignment]
if "connect_timeout" not in wasp_db_url:
wasp_db_url += "?connect_timeout=60"
return wasp_db_url


async def find_model_using_raw(model_uuid: Union[str, UUID]) -> Dict[str, Any]:
if isinstance(model_uuid, UUID):
model_uuid = str(model_uuid)

async with get_db_connection() as db:
model: Optional[Dict[str, Any]] = await db.query_first(
'SELECT * from "Model" where uuid=' # nosec: [B608]
+ f"'{model_uuid}'"
)

if not model:
raise HTTPException(
status_code=404, detail="Something went wrong. Please try again later."
)
return model


async def get_user(user_uuid: Union[int, str]) -> Any:
wasp_db_url = await get_wasp_db_url()
async with get_db_connection(db_url=wasp_db_url) as db:
select_query = 'SELECT * from "User" where uuid=' + f"'{user_uuid}'" # nosec: [B608]
user = await db.query_first(
select_query # nosec: [B608]
)
if not user:
raise HTTPException(status_code=404, detail=f"user_uuid {user_uuid} not found")
return user
# from contextlib import asynccontextmanager
# from os import environ
# from typing import Any, AsyncGenerator, Dict, Optional, Union
# from uuid import UUID

# from fastapi import HTTPException
# from prisma import Prisma # type: ignore[attr-defined]


# @asynccontextmanager
# async def get_db_connection(
# db_url: Optional[str] = None,
# ) -> AsyncGenerator[Prisma, None]:
# if not db_url:
# db_url = environ.get("PY_DATABASE_URL", None)
# if not db_url:
# raise ValueError(
# "No database URL provided nor set as environment variable 'PY_DATABASE_URL'"
# ) # pragma: no cover
# if "connect_timeout" not in db_url:
# db_url += "?connect_timeout=60"
# db = Prisma(datasource={"url": db_url})
# await db.connect()
# try:
# yield db
# finally:
# await db.disconnect()


# async def get_wasp_db_url() -> str:
# wasp_db_url: str = environ.get("DATABASE_URL") # type: ignore[assignment]
# if "connect_timeout" not in wasp_db_url:
# wasp_db_url += "?connect_timeout=60"
# return wasp_db_url


# async def find_model_using_raw(model_uuid: Union[str, UUID]) -> Dict[str, Any]:
# if isinstance(model_uuid, UUID):
# model_uuid = str(model_uuid)

# async with get_db_connection() as db:
# model: Optional[Dict[str, Any]] = await db.query_first(
# 'SELECT * from "Model" where uuid=' # nosec: [B608]
# + f"'{model_uuid}'"
# )

# if not model:
# raise HTTPException(
# status_code=404, detail="Something went wrong. Please try again later."
# )
# return model


# async def get_user(user_uuid: Union[int, str]) -> Any:
# wasp_db_url = await get_wasp_db_url()
# async with get_db_connection(db_url=wasp_db_url) as db:
# select_query = 'SELECT * from "User" where uuid=' + f"'{user_uuid}'" # nosec: [B608]
# user = await db.query_first(
# select_query # nosec: [B608]
# )
# if not user:
# raise HTTPException(status_code=404, detail=f"user_uuid {user_uuid} not found")
# return user
25 changes: 13 additions & 12 deletions fastagency/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
)

from .auth_token.auth import create_deployment_auth_token

# from fastagency.app import add_model
from .db.helpers import find_model_using_raw, get_db_connection, get_user
from .models.base import Model, ObjectReference
from .models.registry import Registry

# from fastagency.app import add_model
# from .db.helpers import find_model_using_raw, get_db_connection, get_user
from .protocols.prisma import PrismaProtocol

T = TypeVar("T", bound=Model)


async def get_model_by_uuid(model_uuid: Union[str, UUID]) -> Model:
model_dict = await find_model_using_raw(model_uuid=model_uuid)
model_dict = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid)

registry = Registry.get_default()
model = registry.validate(
Expand All @@ -43,11 +44,11 @@ async def validate_tokens_and_create_gh_repo(
model: Dict[str, Any],
model_uuid: str,
) -> SaasAppGenerator:
async with get_db_connection():
found_gh_token = await find_model_using_raw(
async with PrismaProtocol().get_db_connection():
found_gh_token = await PrismaProtocol().find_model_using_raw(
model_uuid=model["gh_token"]["uuid"]
)
found_fly_token = await find_model_using_raw(
found_fly_token = await PrismaProtocol().find_model_using_raw(
model_uuid=model["fly_token"]["uuid"]
)

Expand Down Expand Up @@ -81,8 +82,8 @@ async def deploy_saas_app(

await asyncify(saas_app.execute)()

async with get_db_connection() as db:
found_model = await find_model_using_raw(model_uuid=model_uuid)
async with PrismaProtocol().get_db_connection() as db:
found_model = await PrismaProtocol().find_model_using_raw(model_uuid=model_uuid)
found_model["json_str"]["app_deploy_status"] = "completed"

await db.model.update(
Expand Down Expand Up @@ -125,8 +126,8 @@ async def add_model_to_user(
updated_validated_model_dict["gh_repo_url"] = saas_app.gh_repo_url
validated_model_json = json.dumps(updated_validated_model_dict)

await get_user(user_uuid=user_uuid)
async with get_db_connection() as db:
await PrismaProtocol().get_user(user_uuid=user_uuid)
async with PrismaProtocol().get_db_connection() as db:
await db.model.create(
data={
"uuid": model_uuid,
Expand Down Expand Up @@ -209,7 +210,7 @@ async def get_all_models_for_user(
if type_name:
filters["type_name"] = type_name

async with get_db_connection() as db:
async with PrismaProtocol().get_db_connection() as db:
models = await db.model.find_many(where=filters) # type: ignore[arg-type]

return models # type: ignore[no-any-return]
Expand Down
Loading

0 comments on commit 4dc0193

Please sign in to comment.