Skip to content

Commit c9c0b30

Browse files
authored
[MINOR]: Dynamic return types (#1292)
1 parent a7fd9f9 commit c9c0b30

File tree

4 files changed

+361
-93
lines changed

4 files changed

+361
-93
lines changed

jac-cloud/jac_cloud/core/context.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Core constructs for Jac Language."""
22

33
from contextvars import ContextVar
4-
from dataclasses import is_dataclass
4+
from dataclasses import dataclass, is_dataclass
55
from os import getenv
6-
from typing import Any, NotRequired, TypedDict, cast
6+
from typing import Any, Generic, TypeVar, cast
77

88
from bson import ObjectId
99

@@ -33,13 +33,24 @@
3333
SUPER_ROOT = NodeAnchor.ref(f"n::{SUPER_ROOT_ID}")
3434
PUBLIC_ROOT = NodeAnchor.ref(f"n::{PUBLIC_ROOT_ID}")
3535

36+
RT = TypeVar("RT")
3637

37-
class ContextResponse(TypedDict):
38+
39+
@dataclass
40+
class ContextResponse(Generic[RT]):
3841
"""Default Context Response."""
3942

4043
status: int
41-
reports: NotRequired[list[Any]]
42-
returns: NotRequired[list[Any]]
44+
reports: list[Any] | None = None
45+
returns: list[RT] | None = None
46+
47+
def __serialize__(self) -> dict[str, Any]:
48+
"""Serialize response."""
49+
return {
50+
key: value
51+
for key, value in self.__dict__.items()
52+
if value is not None and not key.startswith("_")
53+
}
4354

4455

4556
class JaseciContext(ExecutionContext):
@@ -135,20 +146,20 @@ def get_root() -> Root: # type: ignore[override]
135146

136147
def response(self, returns: list[Any]) -> ORJSONResponse:
137148
"""Return serialized version of reports."""
138-
resp: ContextResponse = {"status": self.status, "returns": returns}
149+
resp = ContextResponse[Any](status=self.status)
139150

140151
if self.reports:
141152
for key, val in enumerate(self.reports):
142153
self.clean_response(key, val, self.reports)
143-
resp["reports"] = self.reports
154+
resp.reports = self.reports
144155

145156
for key, val in enumerate(returns):
146157
self.clean_response(key, val, returns)
147158

148-
if not SHOW_ENDPOINT_RETURNS:
149-
resp.pop("returns")
159+
if SHOW_ENDPOINT_RETURNS:
160+
resp.returns = returns
150161

151-
return ORJSONResponse(resp, status_code=self.status)
162+
return ORJSONResponse(resp.__serialize__(), status_code=self.status)
152163

153164
def clean_response(
154165
self, key: str | int, val: Any, obj: list | dict # noqa: ANN401

jac-cloud/jac_cloud/plugin/jaseci.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from functools import wraps
77
from os import getenv
88
from re import compile
9-
from typing import Any, Callable, Type, TypeVar, cast, get_type_hints
9+
from types import NoneType
10+
from typing import Any, Callable, Type, TypeAlias, TypeVar, Union, cast, get_type_hints
1011

1112
from asyncer import syncify
1213

@@ -182,9 +183,22 @@ def api_root(
182183

183184
walker_method = getattr(walker_router, method)
184185

186+
raw_types: list[Type] = [
187+
get_type_hints(jef.func).get("return", NoneType)
188+
for jef in (*cls._jac_entry_funcs_, *cls._jac_exit_funcs_)
189+
]
190+
191+
if raw_types:
192+
if len(raw_types) > 1:
193+
ret_types: TypeAlias = Union[*raw_types] # type: ignore[valid-type]
194+
else:
195+
ret_types = raw_types[0] # type: ignore[misc]
196+
else:
197+
ret_types = NoneType # type: ignore[misc]
198+
185199
settings: dict[str, Any] = {
186200
"tags": ["walker"],
187-
"response_model": ContextResponse,
201+
"response_model": ContextResponse[ret_types],
188202
}
189203
if auth:
190204
settings["dependencies"] = cast(list, authenticator)

0 commit comments

Comments
 (0)