Skip to content

Commit

Permalink
[WEBSOCKET]: Initial Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
amadolid committed Nov 4, 2024
1 parent de3c755 commit 94bb2d4
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 49 deletions.
18 changes: 10 additions & 8 deletions jac-cloud/jac_cloud/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from bson import ObjectId

from fastapi import Request
from fastapi import Request, WebSocket

from jaclang.runtimelib.context import ExecutionContext
from jaclang.runtimelib.context import EXECUTION_CONTEXT, ExecutionContext

from .architype import (
AccessLevel,
Expand Down Expand Up @@ -61,19 +61,21 @@ class JaseciContext(ExecutionContext):
system_root: NodeAnchor
root: NodeAnchor
entry_node: NodeAnchor
base: ExecutionContext
request: Request
base: ExecutionContext | None
connection: Request | WebSocket

def close(self) -> None:
"""Clean up context."""
self.mem.close()

@staticmethod
def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext": # type: ignore[override]
def create( # type: ignore[override]
connection: Request | WebSocket, entry: NodeAnchor | None = None
) -> "JaseciContext":
"""Create JacContext."""
ctx = JaseciContext()
ctx.base = ExecutionContext.get()
ctx.request = request
ctx.base = EXECUTION_CONTEXT.get(None)
ctx.connection = connection
ctx.mem = MongoDB()
ctx.reports = []
ctx.status = 200
Expand All @@ -94,7 +96,7 @@ def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext"

ctx.system_root = system_root

if _root := getattr(request, "_root", None):
if _root := getattr(connection, "_root", None):
ctx.root = _root
ctx.mem.set(_root.id, _root)
else:
Expand Down
12 changes: 9 additions & 3 deletions jac-cloud/jac_cloud/jaseci/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,15 @@ async def lifespan(app: _FaststAPI) -> AsyncGenerator[None, _FaststAPI]:
populate_yaml_specs(cls.__app__)

from .routers import healthz_router, sso_router, user_router
from ..plugin.jaseci import walker_router

for router in [healthz_router, sso_router, user_router, walker_router]:
from ..plugin.jaseci import walker_router, websocket_router

for router in [
healthz_router,
sso_router,
user_router,
walker_router,
websocket_router,
]:
cls.__app__.include_router(router)

@cls.__app__.exception_handler(Exception)
Expand Down
22 changes: 21 additions & 1 deletion jac-cloud/jac_cloud/jaseci/security/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bson import ObjectId

from fastapi import Depends, Request
from fastapi import Depends, Request, WebSocket
from fastapi.exceptions import HTTPException
from fastapi.security import HTTPBearer

Expand Down Expand Up @@ -104,4 +104,24 @@ def authenticate(request: Request) -> None:
raise HTTPException(status_code=401)


def authenticate_websocket(websocket: WebSocket) -> bool:
"""Authenticate websocket connection."""
if (
authorization := websocket.headers.get("Authorization")
) and authorization.lower().startswith("bearer"):
token = authorization[7:]
decrypted = decrypt(token)
if (
decrypted
and decrypted["expiration"] > utc_timestamp()
and TokenRedis.hget(f"{decrypted['id']}:{token}")
and (user := User.Collection.find_by_id(decrypted["id"]))
and (root := NodeAnchor.Collection.find_by_id(user.root_id))
):
websocket._user = user # type: ignore[attr-defined]
websocket._root = root # type: ignore[attr-defined]
return True
return False


authenticator = [Depends(HTTPBearer()), Depends(authenticate)]
147 changes: 116 additions & 31 deletions jac-cloud/jac_cloud/plugin/jaseci.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
from functools import wraps
from os import getenv
from re import compile
from traceback import format_exception
from types import NoneType
from typing import Any, Callable, Type, TypeAlias, TypeVar, Union, cast, get_type_hints

from anyio import to_thread

from asyncer import syncify

from fastapi import (
Expand All @@ -19,6 +22,7 @@
Request,
Response,
UploadFile,
WebSocket,
)
from fastapi.responses import ORJSONResponse

Expand Down Expand Up @@ -50,9 +54,9 @@
WalkerAnchor,
WalkerArchitype,
)
from ..core.context import ContextResponse, ExecutionContext, JaseciContext
from ..core.context import ContextResponse, ExecutionContext, JaseciContext, PUBLIC_ROOT
from ..jaseci import FastAPI
from ..jaseci.security import authenticator
from ..jaseci.security import authenticate_websocket, authenticator

# from ..jaseci.utils import log_entry, log_exit

Expand All @@ -68,6 +72,72 @@
}

walker_router = APIRouter(prefix="/walker", tags=["walker"])
websocket_router = APIRouter(prefix="/websocket", tags=["walker"])
websocket_events: dict[str, dict[str, Any]] = {}


def notify(self: WebSocket, data: Any) -> None: # noqa: ANN401
"""Notify synchrounously."""
syncify(self.send_json)(data)


WebSocket.notify = notify # type: ignore[attr-defined]


def websocket_synchronizer(websocket: WebSocket, data: dict[str, Any]) -> dict:
"""Websocket event sychronizer."""
if event := websocket_events.get(even_walker := data["walker"]):
if event["auth"] and websocket._root is PUBLIC_ROOT: # type: ignore[attr-defined]
return {"error": f"Event {even_walker} requires to be authenticated!"}
elif not event["auth"] and websocket._root is not PUBLIC_ROOT: # type: ignore[attr-defined]
return {"error": f"Event {even_walker} requires to be unauthenticated!"}

walker: type = event["type"]
node: str | None = data.get("node")
try:
payload = event["model"](**data["context"]).__dict__
except ValidationError:
raise

jctx = JaseciContext.create(websocket, NodeAnchor.ref(node) if node else None)

wlk: WalkerAnchor = walker(**payload).__jac__
if Jac.check_read_access(jctx.entry_node):
Jac.spawn_call(wlk.architype, jctx.entry_node.architype)
jctx.close()

if jctx.custom is not MISSING:
return jctx.custom

return jctx.response(wlk.returns)
else:
jctx.close()
return {
"error": f"You don't have access on target entry{cast(Anchor, jctx.entry_node).ref_id}!"
}
else:
return {"error": "Invalid request! Please use valid walker event!"}


@websocket_router.websocket("")
async def websocket_endpoint(websocket: WebSocket) -> None:
"""Websocket Endpoint."""
if not websocket_events:
await websocket.close()
return

if not authenticate_websocket(websocket):
websocket._root = PUBLIC_ROOT # type: ignore[attr-defined]

await websocket.accept()
while True:
data: dict[str, Any] = await websocket.receive_json()
try:
await to_thread.run_sync(websocket_synchronizer, websocket, data)
except ValidationError as e:
await websocket.send_json({"error": e.errors()})
except Exception as e:
await websocket.send_json({"error": format_exception(e)})


def get_specs(cls: type) -> Type["DefaultSpecs"] | None:
Expand Down Expand Up @@ -108,6 +178,7 @@ def populate_apis(cls: Type[WalkerArchitype]) -> None:
query: dict[str, Any] = {}
body: dict[str, Any] = {}
files: dict[str, Any] = {}
message: dict[str, Any] = {}

if path:
if not path.startswith("/"):
Expand All @@ -133,9 +204,12 @@ def populate_apis(cls: Type[WalkerArchitype]) -> None:
f_name = f.name
f_type = hintings[f_name]
if f_type in FILE_TYPES:
files[f_name] = gen_model_field(f_type, f, True)
message[f_name] = files[f_name] = gen_model_field(
f_type, f, True
)
else:
consts = gen_model_field(f_type, f)
message[f_name] = consts

if as_query == "*" or f_name in as_query:
query[f_name] = consts
Expand Down Expand Up @@ -189,7 +263,7 @@ def api_entry(
if isinstance(body, BaseUploadFile) and body_model:
body = loads(syncify(body.read)())
try:
body = body_model(**body).model_dump()
body = body_model(**body).__dict__
except ValidationError as e:
return ORJSONResponse({"detail": e.errors()})

Expand Down Expand Up @@ -224,35 +298,46 @@ def api_root(

for method in methods:
method = method.lower()

walker_method = getattr(walker_router, method)

raw_types: list[Type] = [
get_type_hints(jef.func).get("return", NoneType)
for jef in (*cls._jac_entry_funcs_, *cls._jac_exit_funcs_)
]

if raw_types:
if len(raw_types) > 1:
ret_types: TypeAlias = Union[*raw_types] # type: ignore[valid-type]
else:
ret_types = raw_types[0] # type: ignore[misc]
else:
ret_types = NoneType # type: ignore[misc]

settings: dict[str, Any] = {
"tags": ["walker"],
"response_model": ContextResponse[ret_types] | Any,
}
if auth:
settings["dependencies"] = cast(list, authenticator)

walker_method(url := f"/{cls.__name__}{path}", summary=url, **settings)(
api_root
)
walker_method(
url := f"/{cls.__name__}/{{node}}{path}", summary=url, **settings
)(api_entry)
match method:
case "websocket":
websocket_events[cls.__name__] = {
"type": cls,
"model": create_model(
f"{cls.__name__.lower()}_message_model", **message
),
"auth": auth,
}
case _:
raw_types: list[Type] = [
get_type_hints(jef.func).get("return", NoneType)
for jef in (*cls._jac_entry_funcs_, *cls._jac_exit_funcs_)
]

if raw_types:
if len(raw_types) > 1:
ret_types: TypeAlias = Union[*raw_types] # type: ignore[valid-type]
else:
ret_types = raw_types[0] # type: ignore[misc]
else:
ret_types = NoneType # type: ignore[misc]

settings: dict[str, Any] = {
"tags": ["walker"],
"response_model": ContextResponse[ret_types] | Any,
}
if auth:
settings["dependencies"] = cast(list, authenticator)

walker_method(
url := f"/{cls.__name__}{path}", summary=url, **settings
)(api_root)
walker_method(
url := f"/{cls.__name__}/{{node}}{path}",
summary=url,
**settings,
)(api_entry)


def specs(
Expand Down
12 changes: 12 additions & 0 deletions jac-cloud/jac_cloud/tests/simple_graph.jac
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,16 @@ walker custom_report {
class __specs__ {
has auth: bool = False;
}
}


walker websocket {
has val: int;
can enter1 with `root entry {
Jac.get_context().connection.notify({"testing": 1});
}

class __specs__ {
has methods: list = ["websocket"];
}
}
6 changes: 3 additions & 3 deletions jac-cloud/jac_cloud/tests/test_simple_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,19 +612,19 @@ def trigger_upload_file(self) -> None:
"single": {
"name": "simple_graph.jac",
"content_type": "application/octet-stream",
"size": 6852,
"size": 7066,
}
},
"multiple": [
{
"name": "simple_graph.jac",
"content_type": "application/octet-stream",
"size": 6852,
"size": 7066,
},
{
"name": "simple_graph.jac",
"content_type": "application/octet-stream",
"size": 6852,
"size": 7066,
},
],
"singleOptional": None,
Expand Down
6 changes: 3 additions & 3 deletions jac-cloud/jac_cloud/tests/test_simple_graph_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,19 +390,19 @@ def trigger_upload_file(self) -> None:
"single": {
"name": "simple_graph.jac",
"content_type": "application/octet-stream",
"size": 6852,
"size": 7066,
}
},
"multiple": [
{
"name": "simple_graph.jac",
"content_type": "application/octet-stream",
"size": 6852,
"size": 7066,
},
{
"name": "simple_graph.jac",
"content_type": "application/octet-stream",
"size": 6852,
"size": 7066,
},
],
"singleOptional": None,
Expand Down

0 comments on commit 94bb2d4

Please sign in to comment.