From b4ed29043c5d1b52af91f8f7823d731b34a1b3e5 Mon Sep 17 00:00:00 2001 From: Simon Holesch Date: Sun, 11 Feb 2024 04:58:46 +0100 Subject: [PATCH] jsonrpc: Merge Server and Proxy into Channel Now bidirectional communication is possible. --- not_my_board/_jsonrpc.py | 417 ++++++++++++++++++++++----------------- 1 file changed, 231 insertions(+), 186 deletions(-) diff --git a/not_my_board/_jsonrpc.py b/not_my_board/_jsonrpc.py index b79349f..7835d88 100644 --- a/not_my_board/_jsonrpc.py +++ b/not_my_board/_jsonrpc.py @@ -1,11 +1,14 @@ #!/usr/bin/env python3 import asyncio +import dataclasses import functools +import itertools import json import logging import textwrap import traceback +from typing import Any, Optional, Union import not_my_board._util as util @@ -13,141 +16,140 @@ CODE_INTERNAL_ERROR = -32603 -CODE_PARSE_ERROR = -32700 CODE_INVALID_REQUEST = -32600 CODE_METHOD_NOT_FOUND = -32601 -class RemoteError(Exception): - def __init__(self, code, message, data): - if "traceback" in data: - details = textwrap.indent(data["traceback"], " " * 4).rstrip() - super().__init__(f"{message}\n{details}") - else: - super().__init__(message) - self.code = code - self.data = data - - -class Server: +class Channel: def __init__(self, send, receive_iter, api_obj=None): - super().__init__() self._send = send self._receive_iter = receive_iter self._api_obj = api_obj self._tasks = set() - self._tasks_by_id = {} + self._request_tasks_by_id = {} + self._id_generator = itertools.count(start=1) + self._pending = {} + + # TODO should be False before communicate_forever() is running + self._is_receiving = True def set_api_object(self, api_obj): self._api_obj = api_obj - async def serve_forever(self): + async def communicate_forever(self): try: async for raw_data in self._receive_iter: - task = asyncio.create_task(self._receive(raw_data)) - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) + try: + await self._receive(raw_data) + except Exception: + traceback.print_exc() finally: + self._is_receiving = False + for future in self._pending.values(): + if not future.done(): + future.set_exception(RuntimeError("Connection closed")) await util.cancel_tasks(self._tasks.copy()) - async def _receive_task(self, raw_data): - try: - await self._receive(raw_data) - except Exception: - traceback.print_exc() + # TODO remove aliases + io_loop = communicate_forever + serve_forever = communicate_forever - async def _receive(self, raw_data): - id_ = None - next_error = (CODE_PARSE_ERROR, "Parse Error") + async def __aenter__(self): + # TODO use util.background_task() + self._task = asyncio.create_task(self.communicate_forever()) + return self + + async def __aexit__(self, exc_type, exc, tb): + self._task.cancel() try: - id_, data = Request.parse_id(raw_data) + await self._task + except asyncio.CancelledError: + pass - next_error = (CODE_INVALID_REQUEST, "Invalid Request") - request = Request.from_data(data) + def __getattr__(self, method_name): + if method_name.startswith("_"): + raise AttributeError(f"invalid attribute '{method_name}'") + return functools.partial(self._call, method_name) - next_error = (CODE_METHOD_NOT_FOUND, "Method not found") - assert not request.method.startswith("_") + async def _receive(self, raw_data): + info = {"id": None, "is_request": False} + try: + message = _parse_message(raw_data, info) + except Exception as e: + if info["id"] is not None: + if info["is_request"]: + response = ErrorResponse.with_traceback( + info["id"], CODE_INVALID_REQUEST, "Invalid Request" + ) + await self._send(bytes(response)) + return + else: + future = self._pending.get(info["id"]) + if future and not future.done(): + future.set_exception(e) + return + + raise + + if isinstance(message, Request): + task = asyncio.create_task(self._handle_request(message)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + elif isinstance(message, Response): + await self._handle_response(message) + else: # ErrorResponse + await self._handle_error_response(message) + + async def _handle_request(self, request): + next_error = (CODE_METHOD_NOT_FOUND, "Method not found") + try: if request.method == "rpc.cancel": - method = self._cancel + method = self._cancel_local else: + if request.method.startswith("_"): + raise ProtocolError(f'method "{request.method}" not allowed') + method = getattr(self._api_obj, request.method) logger.info("Method call: %s", request.method) next_error = CODE_INTERNAL_ERROR, None - if id_ is not None: - self._tasks_by_id[id_] = asyncio.current_task() - result = await method(*request.args, **request.kwargs) - response = Response(result, id_) - await self._send(bytes(response)) + if request.id is not None: + try: + self._request_tasks_by_id[request.id] = asyncio.current_task() + result = await method(*request.args, **request.kwargs) + response = Response(request.id, result) + await self._send(bytes(response)) + finally: + del self._request_tasks_by_id[request.id] else: await method(*request.args, **request.kwargs) except Exception as e: - if id_ is not None: + if request.id is not None: code, message = next_error if message is None: message = str(e) - response = ErrorResponse.with_traceback(code, message, id_) + response = ErrorResponse.with_traceback(request.id, code, message) await self._send(bytes(response)) else: raise - finally: - if id_ in self._tasks_by_id: - del self._tasks_by_id[id_] - - async def _cancel(self, id_): - if id_ in self._tasks_by_id: - await util.cancel_tasks([self._tasks_by_id[id_]]) - - -class Proxy: - def __init__(self, send, receive_iter): - self._send = send - self._receive_iter = receive_iter - self._next_id = 1 - self._pending = {} - self._is_receiving = True - - async def __aenter__(self): - self._task = asyncio.create_task(self.io_loop()) - return self - - async def __aexit__(self, exc_type, exc, tb): - self._task.cancel() - try: - await self._task - except asyncio.CancelledError: - pass - async def io_loop(self): - try: - async for raw_data in self._receive_iter: - await self._receive(raw_data) - finally: - self._is_receiving = False - for future in self._pending.values(): - if not future.done(): - future.set_exception(RuntimeError("Connection closed")) + async def _cancel_local(self, id_): + if id_ in self._request_tasks_by_id: + await util.cancel_tasks([self._request_tasks_by_id[id_]]) - async def _receive(self, raw_data): - id_ = None - try: - id_, data = Response.parse_id(raw_data) - response = Response.from_data(data) + async def _handle_response(self, response): + future = self._pending.get(response.id) + if future and not future.done(): + future.set_result(response.result) - future = self._pending.get(id_) - if future and not future.done(): - future.set_result(response.result) - except Exception as e: - future = self._pending.get(id_) - if future and not future.done(): - future.set_exception(e) - else: - traceback.print_exc() + async def _handle_error_response(self, error_response): + exc = error_response.as_exception() - def __getattr__(self, method_name): - if method_name.startswith("_"): - raise AttributeError(f"invalid attribute '{method_name}'") - return functools.partial(self._call, method_name) + future = self._pending.get(error_response.id) + if future and not future.done(): + future.set_exception(exc) + else: + raise exc async def _call(self, method_name, *args, **kwargs): if not self._is_receiving: @@ -156,123 +158,166 @@ async def _call(self, method_name, *args, **kwargs): if kwargs.pop("_notification", False): id_ = None else: - id_ = self._next_id - self._next_id += 1 + id_ = next(self._id_generator) - assert not args or not kwargs, "use either args or kwargs" + if args and kwargs: + raise RuntimeError("Use either args or kwargs") - request = Request(method_name, args or kwargs, id_) + request = Request(id_, method_name, args or kwargs) logger.info("Calling: %s", request.method) if id_ is not None: - future = asyncio.get_running_loop().create_future() - self._pending[id_] = future - try: - await self._send(bytes(request)) - return await self._pending[id_] - except asyncio.CancelledError: - await self._cancel(id_, request.method) - raise - finally: - del self._pending[id_] + return await self._send_request(request) else: + # send notification await self._send(bytes(request)) - async def _cancel(self, to_cancel_id, to_cancel_name): - id_ = self._next_id - self._next_id += 1 - - request = Request("rpc.cancel", [to_cancel_id], id_) - logger.info("Canceling: %s", to_cancel_name) - + async def _send_request(self, request, send_cancellation=True): future = asyncio.get_running_loop().create_future() - self._pending[id_] = future + self._pending[request.id] = future try: await self._send(bytes(request)) - await self._pending[id_] - # don't request to cancel the cancellation if this task is canceled + return await self._pending[request.id] + except asyncio.CancelledError: + if send_cancellation: + logger.info("Canceling: %s", request.method) + await self._cancel_remote(request.id) + raise finally: - del self._pending[id_] + del self._pending[request.id] + async def _cancel_remote(self, to_cancel_id): + id_ = next(self._id_generator) + request = Request(id_, "rpc.cancel", [to_cancel_id]) -class Message: - _is_id_required = False - _body = {} + # don't request to cancel the cancellation if this task is canceled + await self._send_request(request, send_cancellation=False) - @classmethod - def parse_id(cls, raw_data): - data = json.loads(raw_data) - if cls._is_id_required or "id" in data: - assert isinstance(data["id"], (str, int)), '"id" must be a string or number' - return data.get("id"), data - def __bytes__(self): - return json.dumps( - { - "jsonrpc": "2.0", - **self._body, - } - ).encode() - - -class Request(Message): - def __init__(self, method, params, id_=None): - self.method = method - self.id = id_ - self._body = { - "method": method, - "params": params, - } +# TODO remove aliases +Server = Channel +Proxy = Channel - if id_: - self._body["id"] = id_ - if isinstance(params, list): - self.args = params - self.kwargs = {} - else: - self.args = [] - self.kwargs = params +def _parse_message(raw_data, info): + data = json.loads(raw_data) + id_ = data.get("id") + if id_ is not None: + if not isinstance(id_, (str, int)): + raise ProtocolError('"id" must be a string or number') + info["id"] = id_ + + # check if it is a Request + method = data.get("method") + if method is not None: + info["is_request"] = True + + if not isinstance(method, str): + raise ProtocolError('"method" must be a string') - @classmethod - def from_data(cls, data): - method = data["method"] - assert isinstance(method, str), '"method" must be a string' params = data.get("params", []) - assert isinstance(params, (list, dict)), '"params" must be a structured value' - return cls(method, params, data.get("id")) + if not isinstance(params, (list, dict)): + raise ProtocolError('"params" must be a structured value') + + return Request(id_, method, params) + + # must be a Response or ErrorResponse + if id_ is None: + raise ProtocolError('"id" is required') + + # check if it is an ErrorResponse + error = data.get("error") + if error: + code = error["code"] + if not isinstance(code, int): + raise ProtocolError('"error.code" must be an integer') + + message = error["message"] + if not isinstance(message, str): + raise ProtocolError('"error.message" must be a string') + + filtered_error = { + "code": code, + "message": message, + "data": error.get("data"), + } + return ErrorResponse(id_, filtered_error) + + # must be a Response + return Response(id_, data["result"]) + + +@dataclasses.dataclass +class _Message: + jsonrpc: str = dataclasses.field(default="2.0", init=False) + id: Optional[Union[int, str]] + + def __bytes__(self): + body = {} + for field in dataclasses.fields(self): + if field.name == "id" and self.id is None: + # skip optional id + continue + + body[field.name] = getattr(self, field.name) + + return json.dumps(body).encode() + + +@dataclasses.dataclass +class Request(_Message): + method: str + params: Union[list, dict] + + @property + def args(self): + if isinstance(self.params, list): + return self.params + return [] + + @property + def kwargs(self): + if isinstance(self.params, dict): + return self.params + return {} + +@dataclasses.dataclass +class Response(_Message): + result: Any -class Response(Message): - _is_id_required = True - def __init__(self, result, id_): - self._body = {"result": result, "id": id_} - self.result = result - self.id = id_ +@dataclasses.dataclass +class ErrorResponse(_Message): + error: dict @classmethod - def from_data(cls, data): - if "error" in data: - error = data["error"] - raise RemoteError(error["code"], error["message"], error["data"]) - return cls(data["result"], data["id"]) - - -class ErrorResponse(Response): - # pylint: disable=super-init-not-called - def __init__(self, code, message, id_, data=None): - self._body = { - "error": { - "code": code, - "message": message, + def with_traceback(cls, id_, code, message): + error = { + "code": code, + "message": message, + "data": { + "traceback": traceback.format_exc(), }, - "id": id_, } - if data is not None: - self._body["error"]["data"] = data + return cls(id_, error) - @classmethod - def with_traceback(cls, code, message, id_): - data = {"traceback": traceback.format_exc()} - return cls(code, message, id_, data) + def as_exception(self): + return RemoteError( + self.error["code"], self.error["message"], self.error["data"] + ) + + +class RemoteError(Exception): + def __init__(self, code, message, data): + if isinstance(data, dict) and "traceback" in data: + details = textwrap.indent(data["traceback"], " " * 4).rstrip() + super().__init__(f"{message}\n{details}") + else: + super().__init__(message) + self.code = code + self.data = data + + +class ProtocolError(Exception): + pass