Skip to content

Commit 0b4be95

Browse files
committed
rawsocket: ensure to read full payload
1 parent 6bc7d35 commit 0b4be95

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

xconn/transports.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@ def create_ping():
4242

4343
return payload, ping_header, created_at
4444

45+
def _recv_exactly(sock, n: int) -> bytes:
46+
"""Receive exactly n bytes from a socket or raise if connection breaks."""
47+
chunks = []
48+
received = 0
49+
while received < n:
50+
chunk = sock.recv(n - received)
51+
if not chunk:
52+
raise ConnectionError("Socket connection broken")
53+
54+
chunks.append(chunk)
55+
received += len(chunk)
56+
57+
return b"".join(chunks)
58+
4559

4660
class RawSocketTransport(ITransport):
4761
def __init__(self, sock: socket.socket):
@@ -68,7 +82,7 @@ def connect(
6882

6983
sock.sendall(hs_request.to_bytes())
7084

71-
hs_response_bytes = sock.recv(RAW_SOCKET_HEADER_LENGTH)
85+
hs_response_bytes = _recv_exactly(sock, RAW_SOCKET_HEADER_LENGTH)
7286
hs_response = Handshake.from_bytes(hs_response_bytes)
7387

7488
if hs_request.protocol != hs_response.protocol:
@@ -77,20 +91,20 @@ def connect(
7791
return RawSocketTransport(sock)
7892

7993
def read(self) -> str | bytes:
80-
msg_header_bytes = self._sock.recv(RAW_SOCKET_HEADER_LENGTH)
94+
msg_header_bytes = _recv_exactly(self._sock, RAW_SOCKET_HEADER_LENGTH)
8195
msg_header = MessageHeader.from_bytes(msg_header_bytes)
8296

8397
if msg_header.kind == MSG_TYPE_WAMP:
84-
return self._sock.recv(msg_header.length)
98+
return _recv_exactly(self._sock, msg_header.length)
8599
elif msg_header.kind == MSG_TYPE_PING:
86-
ping_payload = self._sock.recv(msg_header.length)
100+
ping_payload = _recv_exactly(self._sock, msg_header.length)
87101
pong = MessageHeader(MSG_TYPE_PONG, msg_header.length)
88102
self._sock.sendall(pong.to_bytes())
89103
self._sock.sendall(ping_payload)
90104

91105
return self.read()
92106
elif msg_header.kind == MSG_TYPE_PONG:
93-
pong_payload = self._sock.recv(msg_header.length)
107+
pong_payload = _recv_exactly(self._sock, msg_header.length)
94108
pending_ping = self._pending_pings.pop(pong_payload, None)
95109
if pending_ping is not None:
96110
received_at = time.time() * 1000

0 commit comments

Comments
 (0)