|
1 | 1 | """Core constructs for Jac Language."""
|
2 | 2 |
|
3 | 3 | from contextvars import ContextVar
|
4 |
| -from dataclasses import is_dataclass |
| 4 | +from dataclasses import dataclass, is_dataclass |
5 | 5 | from os import getenv
|
6 |
| -from typing import Any, NotRequired, TypedDict, cast |
| 6 | +from typing import Any, Generic, TypeVar, cast |
7 | 7 |
|
8 | 8 | from bson import ObjectId
|
9 | 9 |
|
|
33 | 33 | SUPER_ROOT = NodeAnchor.ref(f"n::{SUPER_ROOT_ID}")
|
34 | 34 | PUBLIC_ROOT = NodeAnchor.ref(f"n::{PUBLIC_ROOT_ID}")
|
35 | 35 |
|
| 36 | +RT = TypeVar("RT") |
36 | 37 |
|
37 |
| -class ContextResponse(TypedDict): |
| 38 | + |
| 39 | +@dataclass |
| 40 | +class ContextResponse(Generic[RT]): |
38 | 41 | """Default Context Response."""
|
39 | 42 |
|
40 | 43 | status: int
|
41 |
| - reports: NotRequired[list[Any]] |
42 |
| - returns: NotRequired[list[Any]] |
| 44 | + reports: list[Any] | None = None |
| 45 | + returns: list[RT] | None = None |
| 46 | + |
| 47 | + def __serialize__(self) -> dict[str, Any]: |
| 48 | + """Serialize response.""" |
| 49 | + return { |
| 50 | + key: value |
| 51 | + for key, value in self.__dict__.items() |
| 52 | + if value is not None and not key.startswith("_") |
| 53 | + } |
43 | 54 |
|
44 | 55 |
|
45 | 56 | class JaseciContext(ExecutionContext):
|
@@ -135,20 +146,20 @@ def get_root() -> Root: # type: ignore[override]
|
135 | 146 |
|
136 | 147 | def response(self, returns: list[Any]) -> ORJSONResponse:
|
137 | 148 | """Return serialized version of reports."""
|
138 |
| - resp: ContextResponse = {"status": self.status, "returns": returns} |
| 149 | + resp = ContextResponse[Any](status=self.status) |
139 | 150 |
|
140 | 151 | if self.reports:
|
141 | 152 | for key, val in enumerate(self.reports):
|
142 | 153 | self.clean_response(key, val, self.reports)
|
143 |
| - resp["reports"] = self.reports |
| 154 | + resp.reports = self.reports |
144 | 155 |
|
145 | 156 | for key, val in enumerate(returns):
|
146 | 157 | self.clean_response(key, val, returns)
|
147 | 158 |
|
148 |
| - if not SHOW_ENDPOINT_RETURNS: |
149 |
| - resp.pop("returns") |
| 159 | + if SHOW_ENDPOINT_RETURNS: |
| 160 | + resp.returns = returns |
150 | 161 |
|
151 |
| - return ORJSONResponse(resp, status_code=self.status) |
| 162 | + return ORJSONResponse(resp.__serialize__(), status_code=self.status) |
152 | 163 |
|
153 | 164 | def clean_response(
|
154 | 165 | self, key: str | int, val: Any, obj: list | dict # noqa: ANN401
|
|
0 commit comments