From 0fc7dd7ee1bf4b065ff4919d7c4fecd6668e317c Mon Sep 17 00:00:00 2001 From: Endogen Date: Thu, 7 Nov 2024 03:01:42 +0100 Subject: [PATCH] Fix all the nasty stuff --- src/abci/server.py | 55 ++++++++++++++++++++-------------- src/abci/utils.py | 23 +++++++++----- src/xian/methods/query.py | 42 ++++++++++++++++++++++---- src/xian/services/simulator.py | 21 ++++++++----- 4 files changed, 97 insertions(+), 44 deletions(-) diff --git a/src/abci/server.py b/src/abci/server.py index c2196ed0..dba91173 100644 --- a/src/abci/server.py +++ b/src/abci/server.py @@ -4,6 +4,7 @@ import asyncio import signal import platform +import io import os from .utils import * @@ -168,37 +169,45 @@ async def _start(self) -> None: await self.server.serve_forever() async def _handler( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: - data = BytesIO() - last_pos = 0 - while True: - if last_pos == data.tell(): - data = BytesIO() - last_pos = 0 - + # Read data from the reader bits = await reader.read(MaxReadInBytes) if len(bits) == 0: logger.error(" ... tendermint closed connection") - # break to the _stop if the connection stops - break + break # Exit the loop if the connection is closed + # Append new data to the buffer + data.seek(0, io.SEEK_END) data.write(bits) - data.seek(last_pos) - - # Tendermint prefixes each serialized protobuf message - # with varint encoded length. We use the 'data' buffer to - # keep track of where we are in the byte stream and progress - # based on the length encoding - for message in read_messages(data, Request): - req_type = message.WhichOneof("value") - response = await self.protocol.process(req_type, message) - writer.write(response) - last_pos = data.tell() - - # Any connection fails and we shut the whole thing down + data.seek(0) # Reset position to start of buffer + + # Attempt to parse and process messages + while True: + start_pos = data.tell() + messages = list(read_messages(data, Request)) + if not messages: + # No complete messages available, reset position and wait + data.seek(start_pos) + break # Exit the parsing loop to read more data + + for message in messages: + req_type = message.WhichOneof("value") + response = await self.protocol.process(req_type, message) + writer.write(response) + await writer.drain() + # Update start position after processing + start_pos = data.tell() + + # Remove processed data from the buffer + remaining_data = data.read() + data = BytesIO() + data.write(remaining_data) + data.seek(0) # Reset position to start of buffer + + # Shut down if the connection is closed await _stop() diff --git a/src/abci/utils.py b/src/abci/utils.py index 0119eac8..9434e83b 100644 --- a/src/abci/utils.py +++ b/src/abci/utils.py @@ -90,19 +90,26 @@ def write_message(message: Message) -> bytes: return buffer.getvalue() -def read_messages(reader: BytesIO, message: Message) -> Message: +def read_messages(reader: BytesIO, message_class): """ - Return an interator over the messages found in the byte stream + Return an iterator over the messages found in the byte stream. """ while True: + start_pos = reader.tell() try: length = decode_varint(reader) except EOFError: - return + # Not enough data to read the length, reset and wait for more data + reader.seek(start_pos) + break # Exit the loop to wait for more data + data = reader.read(length) if len(data) < length: - print(f"Expected {length} bytes, but got only {len(data)}. End of stream or data corruption.") - return - m = message() - m.ParseFromString(data) - yield m + # Not enough data to read the full message, reset and wait + reader.seek(start_pos) + break # Exit the loop to wait for more data + + # Parse the message + msg = message_class() + msg.ParseFromString(data) + yield msg diff --git a/src/xian/methods/query.py b/src/xian/methods/query.py index a712ad17..f25d9040 100644 --- a/src/xian/methods/query.py +++ b/src/xian/methods/query.py @@ -167,9 +167,24 @@ async def query(self, req) -> ResponseQuery: message_length = struct.pack('>I', len(byte_data)) connection.sendall(message_length + byte_data) recv_length = connection.recv(4) - length = struct.unpack('>I', recv_length)[0] - recv = connection.recv(length) - result = recv.decode() + + if len(recv_length) < 4: + # Handle error or incomplete length prefix + raise ValueError("Incomplete length prefix received") + else: + length = struct.unpack('>I', recv_length)[0] + recv = b'' + while len(recv) < length: + packet = connection.recv(length - len(recv)) + if not packet: + # Connection closed or error + raise ConnectionError("Connection closed before receiving all data") + recv += packet + if len(recv) == length: + result = recv.decode('utf-8') + else: + # Handle incomplete data error + raise ValueError("Did not receive all expected data") # TODO: Deprecated - Remove after wallet and tools are reworked to use 'simulate_tx' # http://localhost:26657/abci_query?path="/calculate_stamps/" @@ -182,9 +197,24 @@ async def query(self, req) -> ResponseQuery: message_length = struct.pack('>I', len(byte_data)) connection.sendall(message_length + byte_data) recv_length = connection.recv(4) - length = struct.unpack('>I', recv_length)[0] - recv = connection.recv(length) - result = recv.decode() + + if len(recv_length) < 4: + # Handle error or incomplete length prefix + raise ValueError("Incomplete length prefix received") + else: + length = struct.unpack('>I', recv_length)[0] + recv = b'' + while len(recv) < length: + packet = connection.recv(length - len(recv)) + if not packet: + # Connection closed or error + raise ConnectionError("Connection closed before receiving all data") + recv += packet + if len(recv) == length: + result = recv.decode('utf-8') + else: + # Handle incomplete data error + raise ValueError("Did not receive all expected data") else: error = f'Unknown query path: {path_parts[0]}' diff --git a/src/xian/services/simulator.py b/src/xian/services/simulator.py index 92f639ff..4d0020a4 100644 --- a/src/xian/services/simulator.py +++ b/src/xian/services/simulator.py @@ -34,26 +34,34 @@ def listen(self): connection, client_address = self.socket.accept() print("Client connected") try: - # Accept a connection while True: try: # Read message length (4 bytes) raw_msglen = connection.recv(4) if not raw_msglen: break + if len(raw_msglen) < 4: + # Handle incomplete length prefix + raise ValueError("Incomplete length prefix received") msglen = struct.unpack('>I', raw_msglen)[0] # Read the message data - data = connection.recv(msglen) + data = b'' + while len(data) < msglen: + packet = connection.recv(msglen - len(data)) + if not packet: + # No more data from client, client closed connection + print("Client disconnected") + break + data += packet + if not data: - # No more data from client, client closed connection print("Client disconnected") break - print(f"Received: {data.decode()}") + # Parse the JSON payload directly from bytes + payload = json.loads(data) - payload = data.decode() - payload = json.loads(payload) try: response = self.execute(payload) response = json.dumps(response) @@ -67,7 +75,6 @@ def listen(self): print("Client disconnected") break finally: - # Clean up the connection print("Client disconnected") connection.close()