@@ -42,6 +42,20 @@ def create_ping():
4242
4343 return payload , ping_header , created_at
4444
45+ def _recv_exactly (sock , n : int ) -> bytes :
46+ """Receive exactly n bytes from a socket or raise if connection breaks."""
47+ chunks = []
48+ received = 0
49+ while received < n :
50+ chunk = sock .recv (n - received )
51+ if not chunk :
52+ raise ConnectionError ("Socket connection broken" )
53+
54+ chunks .append (chunk )
55+ received += len (chunk )
56+
57+ return b"" .join (chunks )
58+
4559
4660class RawSocketTransport (ITransport ):
4761 def __init__ (self , sock : socket .socket ):
@@ -68,7 +82,7 @@ def connect(
6882
6983 sock .sendall (hs_request .to_bytes ())
7084
71- hs_response_bytes = sock . recv ( RAW_SOCKET_HEADER_LENGTH )
85+ hs_response_bytes = _recv_exactly ( sock , RAW_SOCKET_HEADER_LENGTH )
7286 hs_response = Handshake .from_bytes (hs_response_bytes )
7387
7488 if hs_request .protocol != hs_response .protocol :
@@ -77,20 +91,20 @@ def connect(
7791 return RawSocketTransport (sock )
7892
7993 def read (self ) -> str | bytes :
80- msg_header_bytes = self ._sock . recv ( RAW_SOCKET_HEADER_LENGTH )
94+ msg_header_bytes = _recv_exactly ( self ._sock , RAW_SOCKET_HEADER_LENGTH )
8195 msg_header = MessageHeader .from_bytes (msg_header_bytes )
8296
8397 if msg_header .kind == MSG_TYPE_WAMP :
84- return self ._sock . recv ( msg_header .length )
98+ return _recv_exactly ( self ._sock , msg_header .length )
8599 elif msg_header .kind == MSG_TYPE_PING :
86- ping_payload = self ._sock . recv ( msg_header .length )
100+ ping_payload = _recv_exactly ( self ._sock , msg_header .length )
87101 pong = MessageHeader (MSG_TYPE_PONG , msg_header .length )
88102 self ._sock .sendall (pong .to_bytes ())
89103 self ._sock .sendall (ping_payload )
90104
91105 return self .read ()
92106 elif msg_header .kind == MSG_TYPE_PONG :
93- pong_payload = self ._sock . recv ( msg_header .length )
107+ pong_payload = _recv_exactly ( self ._sock , msg_header .length )
94108 pending_ping = self ._pending_pings .pop (pong_payload , None )
95109 if pending_ping is not None :
96110 received_at = time .time () * 1000
@@ -157,7 +171,7 @@ async def connect(
157171 writer .write (hs_request .to_bytes ())
158172 await writer .drain ()
159173
160- hs_response_bytes = await reader .read (RAW_SOCKET_HEADER_LENGTH )
174+ hs_response_bytes = await reader .readexactly (RAW_SOCKET_HEADER_LENGTH )
161175 hs_response = Handshake .from_bytes (hs_response_bytes )
162176
163177 if hs_request .protocol != hs_response .protocol :
@@ -166,13 +180,13 @@ async def connect(
166180 return AsyncRawSocketTransport (reader , writer )
167181
168182 async def read (self ) -> str | bytes :
169- msg_header_bytes = await self ._reader .read (RAW_SOCKET_HEADER_LENGTH )
183+ msg_header_bytes = await self ._reader .readexactly (RAW_SOCKET_HEADER_LENGTH )
170184 msg_header = MessageHeader .from_bytes (msg_header_bytes )
171185
172186 if msg_header .kind == MSG_TYPE_WAMP :
173- return await self ._reader .read (msg_header .length )
187+ return await self ._reader .readexactly (msg_header .length )
174188 elif msg_header .kind == MSG_TYPE_PING :
175- ping_payload = await self ._reader .read (msg_header .length )
189+ ping_payload = await self ._reader .readexactly (msg_header .length )
176190 pong = MessageHeader (MSG_TYPE_PONG , msg_header .length )
177191 self ._writer .write (pong .to_bytes ())
178192 await self ._writer .drain ()
@@ -181,7 +195,7 @@ async def read(self) -> str | bytes:
181195
182196 return await self .read ()
183197 elif msg_header .kind == MSG_TYPE_PONG :
184- pong_payload = await self ._reader .read (msg_header .length )
198+ pong_payload = await self ._reader .readexactly (msg_header .length )
185199 pending_ping = self ._pending_pings .pop (pong_payload , None )
186200 if pending_ping is not None :
187201 received_at = time .time () * 1000
0 commit comments