Skip to content

Commit

Permalink
Merge branch 'main' into improv_ref
Browse files Browse the repository at this point in the history
  • Loading branch information
kugesan1105 committed Oct 2, 2024
2 parents 13e57dd + c9c0b30 commit 413c3ea
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 95 deletions.
31 changes: 21 additions & 10 deletions jac-cloud/jac_cloud/core/context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Core constructs for Jac Language."""

from contextvars import ContextVar
from dataclasses import is_dataclass
from dataclasses import dataclass, is_dataclass
from os import getenv
from typing import Any, NotRequired, TypedDict, cast
from typing import Any, Generic, TypeVar, cast

from bson import ObjectId

Expand Down Expand Up @@ -33,13 +33,24 @@
SUPER_ROOT = NodeAnchor.ref(f"n::{SUPER_ROOT_ID}")
PUBLIC_ROOT = NodeAnchor.ref(f"n::{PUBLIC_ROOT_ID}")

RT = TypeVar("RT")

class ContextResponse(TypedDict):

@dataclass
class ContextResponse(Generic[RT]):
"""Default Context Response."""

status: int
reports: NotRequired[list[Any]]
returns: NotRequired[list[Any]]
reports: list[Any] | None = None
returns: list[RT] | None = None

def __serialize__(self) -> dict[str, Any]:
"""Serialize response."""
return {
key: value
for key, value in self.__dict__.items()
if value is not None and not key.startswith("_")
}


class JaseciContext(ExecutionContext):
Expand Down Expand Up @@ -135,20 +146,20 @@ def get_root() -> Root: # type: ignore[override]

def response(self, returns: list[Any]) -> ORJSONResponse:
"""Return serialized version of reports."""
resp: ContextResponse = {"status": self.status, "returns": returns}
resp = ContextResponse[Any](status=self.status)

if self.reports:
for key, val in enumerate(self.reports):
self.clean_response(key, val, self.reports)
resp["reports"] = self.reports
resp.reports = self.reports

for key, val in enumerate(returns):
self.clean_response(key, val, returns)

if not SHOW_ENDPOINT_RETURNS:
resp.pop("returns")
if SHOW_ENDPOINT_RETURNS:
resp.returns = returns

return ORJSONResponse(resp, status_code=self.status)
return ORJSONResponse(resp.__serialize__(), status_code=self.status)

def clean_response(
self, key: str | int, val: Any, obj: list | dict # noqa: ANN401
Expand Down
18 changes: 16 additions & 2 deletions jac-cloud/jac_cloud/plugin/jaseci.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from functools import wraps
from os import getenv
from re import compile
from typing import Any, Callable, Type, TypeVar, cast, get_type_hints
from types import NoneType
from typing import Any, Callable, Type, TypeAlias, TypeVar, Union, cast, get_type_hints

from asyncer import syncify

Expand Down Expand Up @@ -182,9 +183,22 @@ def api_root(

walker_method = getattr(walker_router, method)

raw_types: list[Type] = [
get_type_hints(jef.func).get("return", NoneType)
for jef in (*cls._jac_entry_funcs_, *cls._jac_exit_funcs_)
]

if raw_types:
if len(raw_types) > 1:
ret_types: TypeAlias = Union[*raw_types] # type: ignore[valid-type]
else:
ret_types = raw_types[0] # type: ignore[misc]
else:
ret_types = NoneType # type: ignore[misc]

settings: dict[str, Any] = {
"tags": ["walker"],
"response_model": ContextResponse,
"response_model": ContextResponse[ret_types],
}
if auth:
settings["dependencies"] = cast(list, authenticator)
Expand Down
Loading

0 comments on commit 413c3ea

Please sign in to comment.