From 13604c7e59b53e7b9a0260210d8827ed4f6d7e51 Mon Sep 17 00:00:00 2001 From: Daniel McKnight <34697904+NeonDaniel@users.noreply.github.com> Date: Thu, 2 Jan 2025 13:15:15 -0800 Subject: [PATCH] Improve connection close and error handling (#107) * Update `SelectConsumerThread` to pass exceptions to `self.error_func` to match `BlockingConsumerThread` Handle channel/connection closed exceptions explicitly in `SelectConsumerThread` * Update `BlockingConsumerThread` to ensure connection is closed exactly once Update tests to check for expected error callbacks * Ignore `StreamLostError`s during consumer shutdown --- .../consumers/blocking_consumer.py | 21 ++++++++++++------- .../consumers/select_consumer.py | 8 +++++++ tests/test_consumers.py | 3 +++ 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/neon_mq_connector/consumers/blocking_consumer.py b/neon_mq_connector/consumers/blocking_consumer.py index 60f0a85..a2aaa37 100644 --- a/neon_mq_connector/consumers/blocking_consumer.py +++ b/neon_mq_connector/consumers/blocking_consumer.py @@ -108,14 +108,19 @@ def run(self): self._create_connection() self._consumer_started.set() self.channel.start_consuming() - except Exception as e: - self._close_connection() - if isinstance(e, pika.exceptions.ChannelClosed): - LOG.info(f"Channel closed by broker: {self.callback_func}") - elif isinstance(e, pika.exceptions.StreamLostError): - LOG.info("Connection closed by broker") - else: + except (pika.exceptions.ChannelClosed, + pika.exceptions.ConnectionClosed) as e: + LOG.info(f"Closed {e.reply_code}: {e.reply_text}") + if self._is_consumer_alive: + self._close_connection() + self.error_func(self, e) + except pika.exceptions.StreamLostError as e: + if self._is_consumer_alive: self.error_func(self, e) + except Exception as e: + if self._is_consumer_alive: + self._close_connection() + self.error_func(self, e) def _create_connection(self): self.connection = pika.BlockingConnection(self.connection_params) @@ -145,6 +150,7 @@ def join(self, timeout: Optional[float] = None) -> None: super(BlockingConsumerThread, self).join(timeout=timeout) def _close_connection(self): + self._is_consumer_alive = False try: if self.connection and self.connection.is_open: self.connection.close() @@ -153,4 +159,3 @@ def _close_connection(self): except Exception as e: LOG.exception(f"Failed to close connection due to unexpected exception: {e}") self._consumer_started.clear() - self._is_consumer_alive = False diff --git a/neon_mq_connector/consumers/select_consumer.py b/neon_mq_connector/consumers/select_consumer.py index 2dad745..c18f782 100644 --- a/neon_mq_connector/consumers/select_consumer.py +++ b/neon_mq_connector/consumers/select_consumer.py @@ -206,9 +206,17 @@ def run(self): super(SelectConsumerThread, self).run() self.connection: pika.SelectConnection = self.create_connection() self.connection.ioloop.start() + except (pika.exceptions.ChannelClosed, + pika.exceptions.ConnectionClosed) as e: + LOG.info(f"Closed {e.reply_code}: {e.reply_text}") + if not self._stopping: + # Connection was unexpectedly closed + self._close_connection() + self.error_func(self, e) except Exception as e: LOG.error(f"Failed to start io loop on consumer thread {self.name!r}: {e}") self._close_connection() + self.error_func(self, e) def _close_connection(self, mark_consumer_as_dead: bool = True): try: diff --git a/tests/test_consumers.py b/tests/test_consumers.py index 05a7f78..c034e47 100644 --- a/tests/test_consumers.py +++ b/tests/test_consumers.py @@ -81,6 +81,7 @@ def test_blocking_consumer_thread(self): self.assertFalse(test_thread.is_consuming) self.assertTrue(test_thread.channel.is_closed) self.assertFalse(test_thread.is_consumer_alive) + test_thread.error_func.assert_not_called() # Invalid thread connection connection_params.port = 80 @@ -90,6 +91,7 @@ def test_blocking_consumer_thread(self): test_thread._consumer_started.wait(5) self.assertFalse(test_thread.is_consuming) self.assertIsNone(test_thread.channel) + test_thread.error_func.assert_called_once() test_thread.join(30) self.assertFalse(test_thread.is_consuming) @@ -147,6 +149,7 @@ def test_select_consumer_thread(self): self.assertFalse(test_thread.is_consumer_alive) self.assertTrue(test_thread.channel.is_closed) test_thread.on_close.assert_called_once() + error.assert_not_called() # Invalid thread connection connection_params.port = 80