From 94bb2d428b92bc57e20928797fa80361d3c4b6e5 Mon Sep 17 00:00:00 2001 From: "Alexie (Boyong) Madolid" Date: Thu, 31 Oct 2024 17:13:26 +0800 Subject: [PATCH] [WEBSOCKET]: Initial Implementation --- jac-cloud/jac_cloud/core/context.py | 18 ++- jac-cloud/jac_cloud/jaseci/__init__.py | 12 +- .../jac_cloud/jaseci/security/__init__.py | 22 ++- jac-cloud/jac_cloud/plugin/jaseci.py | 147 ++++++++++++++---- jac-cloud/jac_cloud/tests/simple_graph.jac | 12 ++ .../jac_cloud/tests/test_simple_graph.py | 6 +- .../jac_cloud/tests/test_simple_graph_mini.py | 6 +- 7 files changed, 174 insertions(+), 49 deletions(-) diff --git a/jac-cloud/jac_cloud/core/context.py b/jac-cloud/jac_cloud/core/context.py index 803254265a..099078fd66 100644 --- a/jac-cloud/jac_cloud/core/context.py +++ b/jac-cloud/jac_cloud/core/context.py @@ -7,9 +7,9 @@ from bson import ObjectId -from fastapi import Request +from fastapi import Request, WebSocket -from jaclang.runtimelib.context import ExecutionContext +from jaclang.runtimelib.context import EXECUTION_CONTEXT, ExecutionContext from .architype import ( AccessLevel, @@ -61,19 +61,21 @@ class JaseciContext(ExecutionContext): system_root: NodeAnchor root: NodeAnchor entry_node: NodeAnchor - base: ExecutionContext - request: Request + base: ExecutionContext | None + connection: Request | WebSocket def close(self) -> None: """Clean up context.""" self.mem.close() @staticmethod - def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext": # type: ignore[override] + def create( # type: ignore[override] + connection: Request | WebSocket, entry: NodeAnchor | None = None + ) -> "JaseciContext": """Create JacContext.""" ctx = JaseciContext() - ctx.base = ExecutionContext.get() - ctx.request = request + ctx.base = EXECUTION_CONTEXT.get(None) + ctx.connection = connection ctx.mem = MongoDB() ctx.reports = [] ctx.status = 200 @@ -94,7 +96,7 @@ def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext" ctx.system_root = system_root - if _root := getattr(request, "_root", None): + if _root := getattr(connection, "_root", None): ctx.root = _root ctx.mem.set(_root.id, _root) else: diff --git a/jac-cloud/jac_cloud/jaseci/__init__.py b/jac-cloud/jac_cloud/jaseci/__init__.py index b5b8cb78a1..50e821f668 100644 --- a/jac-cloud/jac_cloud/jaseci/__init__.py +++ b/jac-cloud/jac_cloud/jaseci/__init__.py @@ -62,9 +62,15 @@ async def lifespan(app: _FaststAPI) -> AsyncGenerator[None, _FaststAPI]: populate_yaml_specs(cls.__app__) from .routers import healthz_router, sso_router, user_router - from ..plugin.jaseci import walker_router - - for router in [healthz_router, sso_router, user_router, walker_router]: + from ..plugin.jaseci import walker_router, websocket_router + + for router in [ + healthz_router, + sso_router, + user_router, + walker_router, + websocket_router, + ]: cls.__app__.include_router(router) @cls.__app__.exception_handler(Exception) diff --git a/jac-cloud/jac_cloud/jaseci/security/__init__.py b/jac-cloud/jac_cloud/jaseci/security/__init__.py index 64bfa642dd..9fb9162bdf 100644 --- a/jac-cloud/jac_cloud/jaseci/security/__init__.py +++ b/jac-cloud/jac_cloud/jaseci/security/__init__.py @@ -5,7 +5,7 @@ from bson import ObjectId -from fastapi import Depends, Request +from fastapi import Depends, Request, WebSocket from fastapi.exceptions import HTTPException from fastapi.security import HTTPBearer @@ -104,4 +104,24 @@ def authenticate(request: Request) -> None: raise HTTPException(status_code=401) +def authenticate_websocket(websocket: WebSocket) -> bool: + """Authenticate websocket connection.""" + if ( + authorization := websocket.headers.get("Authorization") + ) and authorization.lower().startswith("bearer"): + token = authorization[7:] + decrypted = decrypt(token) + if ( + decrypted + and decrypted["expiration"] > utc_timestamp() + and TokenRedis.hget(f"{decrypted['id']}:{token}") + and (user := User.Collection.find_by_id(decrypted["id"])) + and (root := NodeAnchor.Collection.find_by_id(user.root_id)) + ): + websocket._user = user # type: ignore[attr-defined] + websocket._root = root # type: ignore[attr-defined] + return True + return False + + authenticator = [Depends(HTTPBearer()), Depends(authenticate)] diff --git a/jac-cloud/jac_cloud/plugin/jaseci.py b/jac-cloud/jac_cloud/plugin/jaseci.py index b428b6cb8f..0064d5b959 100644 --- a/jac-cloud/jac_cloud/plugin/jaseci.py +++ b/jac-cloud/jac_cloud/plugin/jaseci.py @@ -6,9 +6,12 @@ from functools import wraps from os import getenv from re import compile +from traceback import format_exception from types import NoneType from typing import Any, Callable, Type, TypeAlias, TypeVar, Union, cast, get_type_hints +from anyio import to_thread + from asyncer import syncify from fastapi import ( @@ -19,6 +22,7 @@ Request, Response, UploadFile, + WebSocket, ) from fastapi.responses import ORJSONResponse @@ -50,9 +54,9 @@ WalkerAnchor, WalkerArchitype, ) -from ..core.context import ContextResponse, ExecutionContext, JaseciContext +from ..core.context import ContextResponse, ExecutionContext, JaseciContext, PUBLIC_ROOT from ..jaseci import FastAPI -from ..jaseci.security import authenticator +from ..jaseci.security import authenticate_websocket, authenticator # from ..jaseci.utils import log_entry, log_exit @@ -68,6 +72,72 @@ } walker_router = APIRouter(prefix="/walker", tags=["walker"]) +websocket_router = APIRouter(prefix="/websocket", tags=["walker"]) +websocket_events: dict[str, dict[str, Any]] = {} + + +def notify(self: WebSocket, data: Any) -> None: # noqa: ANN401 + """Notify synchrounously.""" + syncify(self.send_json)(data) + + +WebSocket.notify = notify # type: ignore[attr-defined] + + +def websocket_synchronizer(websocket: WebSocket, data: dict[str, Any]) -> dict: + """Websocket event sychronizer.""" + if event := websocket_events.get(even_walker := data["walker"]): + if event["auth"] and websocket._root is PUBLIC_ROOT: # type: ignore[attr-defined] + return {"error": f"Event {even_walker} requires to be authenticated!"} + elif not event["auth"] and websocket._root is not PUBLIC_ROOT: # type: ignore[attr-defined] + return {"error": f"Event {even_walker} requires to be unauthenticated!"} + + walker: type = event["type"] + node: str | None = data.get("node") + try: + payload = event["model"](**data["context"]).__dict__ + except ValidationError: + raise + + jctx = JaseciContext.create(websocket, NodeAnchor.ref(node) if node else None) + + wlk: WalkerAnchor = walker(**payload).__jac__ + if Jac.check_read_access(jctx.entry_node): + Jac.spawn_call(wlk.architype, jctx.entry_node.architype) + jctx.close() + + if jctx.custom is not MISSING: + return jctx.custom + + return jctx.response(wlk.returns) + else: + jctx.close() + return { + "error": f"You don't have access on target entry{cast(Anchor, jctx.entry_node).ref_id}!" + } + else: + return {"error": "Invalid request! Please use valid walker event!"} + + +@websocket_router.websocket("") +async def websocket_endpoint(websocket: WebSocket) -> None: + """Websocket Endpoint.""" + if not websocket_events: + await websocket.close() + return + + if not authenticate_websocket(websocket): + websocket._root = PUBLIC_ROOT # type: ignore[attr-defined] + + await websocket.accept() + while True: + data: dict[str, Any] = await websocket.receive_json() + try: + await to_thread.run_sync(websocket_synchronizer, websocket, data) + except ValidationError as e: + await websocket.send_json({"error": e.errors()}) + except Exception as e: + await websocket.send_json({"error": format_exception(e)}) def get_specs(cls: type) -> Type["DefaultSpecs"] | None: @@ -108,6 +178,7 @@ def populate_apis(cls: Type[WalkerArchitype]) -> None: query: dict[str, Any] = {} body: dict[str, Any] = {} files: dict[str, Any] = {} + message: dict[str, Any] = {} if path: if not path.startswith("/"): @@ -133,9 +204,12 @@ def populate_apis(cls: Type[WalkerArchitype]) -> None: f_name = f.name f_type = hintings[f_name] if f_type in FILE_TYPES: - files[f_name] = gen_model_field(f_type, f, True) + message[f_name] = files[f_name] = gen_model_field( + f_type, f, True + ) else: consts = gen_model_field(f_type, f) + message[f_name] = consts if as_query == "*" or f_name in as_query: query[f_name] = consts @@ -189,7 +263,7 @@ def api_entry( if isinstance(body, BaseUploadFile) and body_model: body = loads(syncify(body.read)()) try: - body = body_model(**body).model_dump() + body = body_model(**body).__dict__ except ValidationError as e: return ORJSONResponse({"detail": e.errors()}) @@ -224,35 +298,46 @@ def api_root( for method in methods: method = method.lower() - walker_method = getattr(walker_router, method) - raw_types: list[Type] = [ - get_type_hints(jef.func).get("return", NoneType) - for jef in (*cls._jac_entry_funcs_, *cls._jac_exit_funcs_) - ] - - if raw_types: - if len(raw_types) > 1: - ret_types: TypeAlias = Union[*raw_types] # type: ignore[valid-type] - else: - ret_types = raw_types[0] # type: ignore[misc] - else: - ret_types = NoneType # type: ignore[misc] - - settings: dict[str, Any] = { - "tags": ["walker"], - "response_model": ContextResponse[ret_types] | Any, - } - if auth: - settings["dependencies"] = cast(list, authenticator) - - walker_method(url := f"/{cls.__name__}{path}", summary=url, **settings)( - api_root - ) - walker_method( - url := f"/{cls.__name__}/{{node}}{path}", summary=url, **settings - )(api_entry) + match method: + case "websocket": + websocket_events[cls.__name__] = { + "type": cls, + "model": create_model( + f"{cls.__name__.lower()}_message_model", **message + ), + "auth": auth, + } + case _: + raw_types: list[Type] = [ + get_type_hints(jef.func).get("return", NoneType) + for jef in (*cls._jac_entry_funcs_, *cls._jac_exit_funcs_) + ] + + if raw_types: + if len(raw_types) > 1: + ret_types: TypeAlias = Union[*raw_types] # type: ignore[valid-type] + else: + ret_types = raw_types[0] # type: ignore[misc] + else: + ret_types = NoneType # type: ignore[misc] + + settings: dict[str, Any] = { + "tags": ["walker"], + "response_model": ContextResponse[ret_types] | Any, + } + if auth: + settings["dependencies"] = cast(list, authenticator) + + walker_method( + url := f"/{cls.__name__}{path}", summary=url, **settings + )(api_root) + walker_method( + url := f"/{cls.__name__}/{{node}}{path}", + summary=url, + **settings, + )(api_entry) def specs( diff --git a/jac-cloud/jac_cloud/tests/simple_graph.jac b/jac-cloud/jac_cloud/tests/simple_graph.jac index 1a05bf5010..469f031211 100644 --- a/jac-cloud/jac_cloud/tests/simple_graph.jac +++ b/jac-cloud/jac_cloud/tests/simple_graph.jac @@ -271,4 +271,16 @@ walker custom_report { class __specs__ { has auth: bool = False; } +} + + +walker websocket { + has val: int; + can enter1 with `root entry { + Jac.get_context().connection.notify({"testing": 1}); + } + + class __specs__ { + has methods: list = ["websocket"]; + } } \ No newline at end of file diff --git a/jac-cloud/jac_cloud/tests/test_simple_graph.py b/jac-cloud/jac_cloud/tests/test_simple_graph.py index 80ef287769..f5a84bfb7e 100644 --- a/jac-cloud/jac_cloud/tests/test_simple_graph.py +++ b/jac-cloud/jac_cloud/tests/test_simple_graph.py @@ -612,19 +612,19 @@ def trigger_upload_file(self) -> None: "single": { "name": "simple_graph.jac", "content_type": "application/octet-stream", - "size": 6852, + "size": 7066, } }, "multiple": [ { "name": "simple_graph.jac", "content_type": "application/octet-stream", - "size": 6852, + "size": 7066, }, { "name": "simple_graph.jac", "content_type": "application/octet-stream", - "size": 6852, + "size": 7066, }, ], "singleOptional": None, diff --git a/jac-cloud/jac_cloud/tests/test_simple_graph_mini.py b/jac-cloud/jac_cloud/tests/test_simple_graph_mini.py index a489793779..05774d5493 100644 --- a/jac-cloud/jac_cloud/tests/test_simple_graph_mini.py +++ b/jac-cloud/jac_cloud/tests/test_simple_graph_mini.py @@ -390,19 +390,19 @@ def trigger_upload_file(self) -> None: "single": { "name": "simple_graph.jac", "content_type": "application/octet-stream", - "size": 6852, + "size": 7066, } }, "multiple": [ { "name": "simple_graph.jac", "content_type": "application/octet-stream", - "size": 6852, + "size": 7066, }, { "name": "simple_graph.jac", "content_type": "application/octet-stream", - "size": 6852, + "size": 7066, }, ], "singleOptional": None,