Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix all the nasty stuff #319

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions src/abci/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import signal
import platform
import io
import os

from .utils import *
Expand Down Expand Up @@ -168,37 +169,45 @@ async def _start(self) -> None:
await self.server.serve_forever()

async def _handler(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:

data = BytesIO()
last_pos = 0

while True:
if last_pos == data.tell():
data = BytesIO()
last_pos = 0

# Read data from the reader
bits = await reader.read(MaxReadInBytes)
if len(bits) == 0:
logger.error(" ... tendermint closed connection")
# break to the _stop if the connection stops
break
break # Exit the loop if the connection is closed

# Append new data to the buffer
data.seek(0, io.SEEK_END)
data.write(bits)
data.seek(last_pos)

# Tendermint prefixes each serialized protobuf message
# with varint encoded length. We use the 'data' buffer to
# keep track of where we are in the byte stream and progress
# based on the length encoding
for message in read_messages(data, Request):
req_type = message.WhichOneof("value")
response = await self.protocol.process(req_type, message)
writer.write(response)
last_pos = data.tell()

# Any connection fails and we shut the whole thing down
data.seek(0) # Reset position to start of buffer

# Attempt to parse and process messages
while True:
start_pos = data.tell()
messages = list(read_messages(data, Request))
if not messages:
# No complete messages available, reset position and wait
data.seek(start_pos)
break # Exit the parsing loop to read more data

for message in messages:
req_type = message.WhichOneof("value")
response = await self.protocol.process(req_type, message)
writer.write(response)
await writer.drain()
# Update start position after processing
start_pos = data.tell()

# Remove processed data from the buffer
remaining_data = data.read()
data = BytesIO()
data.write(remaining_data)
data.seek(0) # Reset position to start of buffer

# Shut down if the connection is closed
await _stop()


Expand Down
23 changes: 15 additions & 8 deletions src/abci/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,26 @@ def write_message(message: Message) -> bytes:
return buffer.getvalue()


def read_messages(reader: BytesIO, message: Message) -> Message:
def read_messages(reader: BytesIO, message_class):
"""
Return an interator over the messages found in the byte stream
Return an iterator over the messages found in the byte stream.
"""
while True:
start_pos = reader.tell()
try:
length = decode_varint(reader)
except EOFError:
return
# Not enough data to read the length, reset and wait for more data
reader.seek(start_pos)
break # Exit the loop to wait for more data

data = reader.read(length)
if len(data) < length:
print(f"Expected {length} bytes, but got only {len(data)}. End of stream or data corruption.")
return
m = message()
m.ParseFromString(data)
yield m
# Not enough data to read the full message, reset and wait
reader.seek(start_pos)
break # Exit the loop to wait for more data

# Parse the message
msg = message_class()
msg.ParseFromString(data)
yield msg
42 changes: 36 additions & 6 deletions src/xian/methods/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,24 @@ async def query(self, req) -> ResponseQuery:
message_length = struct.pack('>I', len(byte_data))
connection.sendall(message_length + byte_data)
recv_length = connection.recv(4)
length = struct.unpack('>I', recv_length)[0]
recv = connection.recv(length)
result = recv.decode()

if len(recv_length) < 4:
# Handle error or incomplete length prefix
raise ValueError("Incomplete length prefix received")
else:
length = struct.unpack('>I', recv_length)[0]
recv = b''
while len(recv) < length:
packet = connection.recv(length - len(recv))
if not packet:
# Connection closed or error
raise ConnectionError("Connection closed before receiving all data")
recv += packet
if len(recv) == length:
result = recv.decode('utf-8')
else:
# Handle incomplete data error
raise ValueError("Did not receive all expected data")

# TODO: Deprecated - Remove after wallet and tools are reworked to use 'simulate_tx'
# http://localhost:26657/abci_query?path="/calculate_stamps/<encoded_payload>"
Expand All @@ -182,9 +197,24 @@ async def query(self, req) -> ResponseQuery:
message_length = struct.pack('>I', len(byte_data))
connection.sendall(message_length + byte_data)
recv_length = connection.recv(4)
length = struct.unpack('>I', recv_length)[0]
recv = connection.recv(length)
result = recv.decode()

if len(recv_length) < 4:
# Handle error or incomplete length prefix
raise ValueError("Incomplete length prefix received")
else:
length = struct.unpack('>I', recv_length)[0]
recv = b''
while len(recv) < length:
packet = connection.recv(length - len(recv))
if not packet:
# Connection closed or error
raise ConnectionError("Connection closed before receiving all data")
recv += packet
if len(recv) == length:
result = recv.decode('utf-8')
else:
# Handle incomplete data error
raise ValueError("Did not receive all expected data")

else:
error = f'Unknown query path: {path_parts[0]}'
Expand Down
21 changes: 14 additions & 7 deletions src/xian/services/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,34 @@ def listen(self):
connection, client_address = self.socket.accept()
print("Client connected")
try:
# Accept a connection
while True:
try:
# Read message length (4 bytes)
raw_msglen = connection.recv(4)
if not raw_msglen:
break
if len(raw_msglen) < 4:
# Handle incomplete length prefix
raise ValueError("Incomplete length prefix received")
msglen = struct.unpack('>I', raw_msglen)[0]

# Read the message data
data = connection.recv(msglen)
data = b''
while len(data) < msglen:
packet = connection.recv(msglen - len(data))
if not packet:
# No more data from client, client closed connection
print("Client disconnected")
break
data += packet

if not data:
# No more data from client, client closed connection
print("Client disconnected")
break

print(f"Received: {data.decode()}")
# Parse the JSON payload directly from bytes
payload = json.loads(data)

payload = data.decode()
payload = json.loads(payload)
try:
response = self.execute(payload)
response = json.dumps(response)
Expand All @@ -67,7 +75,6 @@ def listen(self):
print("Client disconnected")
break
finally:
# Clean up the connection
print("Client disconnected")
connection.close()

Expand Down
Loading