Skip to content

Commit

Permalink
fix: check writer is closing in AIOKafkaConnection.send
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Aug 20, 2024
1 parent 01c60cd commit e82bd33
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 6 deletions.
6 changes: 6 additions & 0 deletions aiokafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,12 @@ def send(self, request, expect_response=True):
f"No connection to broker at {self._host}:{self._port}"
)

if self._writer.is_closing():
self.close(reason=CloseReason.CONNECTION_BROKEN)
raise Errors.KafkaConnectionError(
f"Connection at {self._host}:{self._port} is closing"
)

correlation_id = self._next_correlation_id()
header = request.build_request_header(
correlation_id=correlation_id, client_id=self._client_id
Expand Down
1 change: 1 addition & 0 deletions requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ Pygments==2.15.0
gssapi==1.8.3
async-timeout==4.0.1
cramjam==2.8.0
uvloop==0.19.0
92 changes: 86 additions & 6 deletions tests/test_conn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
import gc
import socket
import struct
from typing import Any
import sys
from typing import Any, AsyncIterable, Iterable, Tuple
from unittest import mock

import pytest
import pytest_asyncio

from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn
from aiokafka.errors import (
Expand Down Expand Up @@ -144,7 +147,7 @@ async def test_send_to_closed(self):
with self.assertRaises(KafkaConnectionError):
await conn.send(request)

conn._writer = mock.MagicMock()
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
conn._writer.write.side_effect = OSError("mocked writer is closed")

with self.assertRaises(KafkaConnectionError):
Expand Down Expand Up @@ -173,7 +176,7 @@ async def second_resp(*args: Any, **kw: Any):
return resp

reader.readexactly.side_effect = [first_resp(), second_resp()]
writer = mock.MagicMock()
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))

conn._reader = reader
conn._writer = writer
Expand Down Expand Up @@ -208,7 +211,7 @@ async def second_resp(*args: Any, **kw: Any):
return resp

reader.readexactly.side_effect = [first_resp(), second_resp()]
writer = mock.MagicMock()
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))

conn._reader = reader
conn._writer = writer
Expand Down Expand Up @@ -237,7 +240,7 @@ async def invoke_osserror(*a, **kw):
# setup reader
reader = mock.MagicMock()
reader.readexactly.return_value = invoke_osserror()
writer = mock.MagicMock()
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))

conn._reader = reader
conn._writer = writer
Expand Down Expand Up @@ -394,7 +397,7 @@ async def test__send_sasl_token(self):
# setup connection with mocked transport and protocol
conn = AIOKafkaConnection(host="", port=9999)
conn.close = mock.MagicMock()
conn._writer = mock.MagicMock()
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
out_buffer = []
conn._writer.write = mock.Mock(side_effect=out_buffer.append)
conn._reader = mock.MagicMock()
Expand Down Expand Up @@ -424,3 +427,80 @@ async def test__send_sasl_token(self):
conn._send_sasl_token(b"Super data")
# We don't need to close 2ce
self.assertEqual(conn.close.call_count, 1)


class TestClosedSocket:
@pytest.mark.skipif(sys.platform == "win32")
@pytest.fixture(
params=(
pytest.param("asyncio", id="asyncio"),
pytest.param("uvloop", id="uvloop"),
),
)
def event_loop(
self, request: pytest.FixtureRequest
) -> Iterable[asyncio.AbstractEventLoop]:
if request.param == "asyncio":
policy = asyncio.DefaultEventLoopPolicy()
elif request.param == "uvloop":
import uvloop

policy = uvloop.EventLoopPolicy()
else:
raise ValueError(f"loop {request.param} is not supported")

Check warning on line 450 in tests/test_conn.py

View check run for this annotation

Codecov / codecov/patch

tests/test_conn.py#L450

Added line #L450 was not covered by tests

loop: asyncio.AbstractEventLoop = policy.new_event_loop()
yield loop
loop.close()

@pytest.fixture()
def server(self, unused_tcp_port: int) -> Iterable[Tuple[str, int, socket.socket]]:
host = "localhost"
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((host, unused_tcp_port))
sock.listen(8)
sock.setblocking(False)

yield host, unused_tcp_port, sock

sock.close()

@pytest_asyncio.fixture()
async def conn(
self, server: Tuple[str, int, socket.socket]
) -> AsyncIterable[AIOKafkaConnection]:
host, port, _ = server

conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000)
conn._create_reader_task = mock.Mock()

yield conn

fut = conn.close()
if fut:
await fut

Check notice

Code scanning / CodeQL

Statement has no effect Note test

This statement has no effect.

@pytest.mark.asyncio
async def test_send_to_closed_socket(
self, server: Tuple[str, int, socket.socket], conn: AIOKafkaConnection
) -> None:
host, port, sock = server

request = MetadataRequest([])

with pytest.raises(
KafkaConnectionError,
match=f"KafkaConnectionError: No connection to broker at {host}:{port}",
):
await conn.send(request)

await conn.connect()

sock.close()
await asyncio.sleep(0.1)

with pytest.raises(
KafkaConnectionError,
match=f"KafkaConnectionError: Connection at {host}:{port} is closing",
):
await conn.send(request)

0 comments on commit e82bd33

Please sign in to comment.