Skip to content

Commit 4e93a6c

Browse files
authored
Merge pull request #194 from xconnio/rs-payload-fix
rawsocket: ensure to read full payload
2 parents 6bc7d35 + 9e34634 commit 4e93a6c

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

xconn/transports.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@ def create_ping():
4242

4343
return payload, ping_header, created_at
4444

45+
def _recv_exactly(sock, n: int) -> bytes:
46+
"""Receive exactly n bytes from a socket or raise if connection breaks."""
47+
chunks = []
48+
received = 0
49+
while received < n:
50+
chunk = sock.recv(n - received)
51+
if not chunk:
52+
raise ConnectionError("Socket connection broken")
53+
54+
chunks.append(chunk)
55+
received += len(chunk)
56+
57+
return b"".join(chunks)
58+
4559

4660
class RawSocketTransport(ITransport):
4761
def __init__(self, sock: socket.socket):
@@ -68,7 +82,7 @@ def connect(
6882

6983
sock.sendall(hs_request.to_bytes())
7084

71-
hs_response_bytes = sock.recv(RAW_SOCKET_HEADER_LENGTH)
85+
hs_response_bytes = _recv_exactly(sock, RAW_SOCKET_HEADER_LENGTH)
7286
hs_response = Handshake.from_bytes(hs_response_bytes)
7387

7488
if hs_request.protocol != hs_response.protocol:
@@ -77,20 +91,20 @@ def connect(
7791
return RawSocketTransport(sock)
7892

7993
def read(self) -> str | bytes:
80-
msg_header_bytes = self._sock.recv(RAW_SOCKET_HEADER_LENGTH)
94+
msg_header_bytes = _recv_exactly(self._sock, RAW_SOCKET_HEADER_LENGTH)
8195
msg_header = MessageHeader.from_bytes(msg_header_bytes)
8296

8397
if msg_header.kind == MSG_TYPE_WAMP:
84-
return self._sock.recv(msg_header.length)
98+
return _recv_exactly(self._sock, msg_header.length)
8599
elif msg_header.kind == MSG_TYPE_PING:
86-
ping_payload = self._sock.recv(msg_header.length)
100+
ping_payload = _recv_exactly(self._sock, msg_header.length)
87101
pong = MessageHeader(MSG_TYPE_PONG, msg_header.length)
88102
self._sock.sendall(pong.to_bytes())
89103
self._sock.sendall(ping_payload)
90104

91105
return self.read()
92106
elif msg_header.kind == MSG_TYPE_PONG:
93-
pong_payload = self._sock.recv(msg_header.length)
107+
pong_payload = _recv_exactly(self._sock, msg_header.length)
94108
pending_ping = self._pending_pings.pop(pong_payload, None)
95109
if pending_ping is not None:
96110
received_at = time.time() * 1000
@@ -157,7 +171,7 @@ async def connect(
157171
writer.write(hs_request.to_bytes())
158172
await writer.drain()
159173

160-
hs_response_bytes = await reader.read(RAW_SOCKET_HEADER_LENGTH)
174+
hs_response_bytes = await reader.readexactly(RAW_SOCKET_HEADER_LENGTH)
161175
hs_response = Handshake.from_bytes(hs_response_bytes)
162176

163177
if hs_request.protocol != hs_response.protocol:
@@ -166,13 +180,13 @@ async def connect(
166180
return AsyncRawSocketTransport(reader, writer)
167181

168182
async def read(self) -> str | bytes:
169-
msg_header_bytes = await self._reader.read(RAW_SOCKET_HEADER_LENGTH)
183+
msg_header_bytes = await self._reader.readexactly(RAW_SOCKET_HEADER_LENGTH)
170184
msg_header = MessageHeader.from_bytes(msg_header_bytes)
171185

172186
if msg_header.kind == MSG_TYPE_WAMP:
173-
return await self._reader.read(msg_header.length)
187+
return await self._reader.readexactly(msg_header.length)
174188
elif msg_header.kind == MSG_TYPE_PING:
175-
ping_payload = await self._reader.read(msg_header.length)
189+
ping_payload = await self._reader.readexactly(msg_header.length)
176190
pong = MessageHeader(MSG_TYPE_PONG, msg_header.length)
177191
self._writer.write(pong.to_bytes())
178192
await self._writer.drain()
@@ -181,7 +195,7 @@ async def read(self) -> str | bytes:
181195

182196
return await self.read()
183197
elif msg_header.kind == MSG_TYPE_PONG:
184-
pong_payload = await self._reader.read(msg_header.length)
198+
pong_payload = await self._reader.readexactly(msg_header.length)
185199
pending_ping = self._pending_pings.pop(pong_payload, None)
186200
if pending_ping is not None:
187201
received_at = time.time() * 1000

0 commit comments

Comments
 (0)