From 2243dc642b1786dff50884254833408a0ec96074 Mon Sep 17 00:00:00 2001 From: NeonKirill Date: Thu, 5 Dec 2024 12:55:48 +0100 Subject: [PATCH] Eliminated "self joining" problem in consumer threads and added graceful restarts handling --- neon_mq_connector/connector.py | 80 +++++++++++-------- .../consumers/blocking_consumer.py | 29 +++---- .../consumers/select_consumer.py | 19 +++-- tests/test_connector.py | 2 +- 4 files changed, 72 insertions(+), 58 deletions(-) diff --git a/neon_mq_connector/connector.py b/neon_mq_connector/connector.py index 22228ec..2865285 100644 --- a/neon_mq_connector/connector.py +++ b/neon_mq_connector/connector.py @@ -415,7 +415,7 @@ def register_consumer(self, name: str, vhost: str, queue: str, Registers a consumer for the specified queue. The callback function will handle items in the queue. Any raised exceptions will be passed as arguments to on_error. - :param name: Human readable name of the consumer + :param name: Human-readable name of the consumer :param vhost: vhost to register on :param queue: MQ Queue to read messages from :param queue_reset: to delete queue if exists (defaults to False) @@ -423,7 +423,7 @@ def register_consumer(self, name: str, vhost: str, queue: str, :param exchange_reset: to delete exchange if exists (defaults to False) :param exchange_type: Type of MQ Exchange to use, documentation: https://www.rabbitmq.com/tutorials/amqp-concepts.html - :param callback: Method to passed queued messages to + :param callback: Callback method on received messages :param on_error: Optional method to handle any exceptions raised in message handling :param auto_ack: Boolean to enable ack of messages upon receipt @@ -439,17 +439,32 @@ def register_consumer(self, name: str, vhost: str, queue: str, if skip_on_existing: LOG.info(f'Consumer under index "{name}" already declared') return - self.stop_consumers(names=(name,), allow_restart=False) + self.stop_consumers(names=(name,)) self.consumer_properties.setdefault(name, {}) self.consumer_properties[name]['properties'] = \ - dict(connection_params=self.get_connection_params(vhost), - queue=queue, queue_reset=queue_reset, callback_func=callback, - exchange=exchange, exchange_reset=exchange_reset, - exchange_type=exchange_type, error_func=error_handler, - auto_ack=auto_ack, name=name, queue_exclusive=queue_exclusive, ) - self.consumer_properties[name]['restart_attempts'] = \ - int(restart_attempts) + dict( + name=name, + connection_params=self.get_connection_params(vhost), + queue=queue, + queue_reset=queue_reset, + callback_func=callback, + exchange=exchange, + exchange_reset=exchange_reset, + exchange_type=exchange_type, + error_func=error_handler, + auto_ack=auto_ack, + queue_exclusive=queue_exclusive, + ) + self.consumer_properties[name]['restart_attempts'] = int(restart_attempts) self.consumer_properties[name]['started'] = False + + if exchange_type == ExchangeType.fanout.value: + LOG.info(f'Subscriber exchange listener registered: ' + f'[name={name},exchange={exchange},vhost={vhost}]') + else: + LOG.info(f'Consumer queue listener registered: ' + f'[name={name},queue={queue},vhost={vhost}]') + self.consumers[name] = self.consumer_thread_cls(**self.consumer_properties[name]['properties']) @property @@ -459,17 +474,16 @@ def consumer_thread_cls(self) -> Type[ConsumerThreadInstance]: return BlockingConsumerThread def restart_consumer(self, name: str): - self.stop_consumers(names=(name,), allow_restart=True) + self.stop_consumers(names=(name,)) consumer_data = self.consumer_properties.get(name, {}) restart_attempts = consumer_data.get('restart_attempts', self.__max_consumer_restarts__) err_msg = '' - if not consumer_data.get('is_alive', True): - LOG.debug(f'Skipping joined consumer = "{name}"') - elif not consumer_data.get('properties'): + if not consumer_data.get('properties'): err_msg = 'creation properties not found' elif 0 < restart_attempts < consumer_data.get('num_restarted', 0): err_msg = 'num restarts exceeded' + self.consumers.pop(name, None) else: self.consumers[name] = self.consumer_thread_cls(**consumer_data['properties']) self.run_consumers(names=(name,)) @@ -481,19 +495,19 @@ def restart_consumer(self, name: str): def register_subscriber(self, name: str, vhost: str, callback: callable, on_error: Optional[callable] = None, - exchange: str = None, exchange_reset: bool = False, + exchange: str = None, + exchange_reset: bool = False, auto_ack: bool = True, skip_on_existing: bool = False, restart_attempts: int = __max_consumer_restarts__): """ Registers fanout exchange subscriber, wraps register_consumer() Any raised exceptions will be passed as arguments to on_error. - :param name: Human readable name of the consumer + :param name: Human-readable name of the consumer :param vhost: vhost to register on - :param exchange: MQ Exchange to bind to - :param exchange_reset: to delete exchange if exists - (defaults to False) - :param callback: Method to passed queued messages to + :param exchange: MQ Exchange for binding to + :param exchange_reset: delete exchange if exists (defaults to False) + :param callback: Callback method on received messages :param on_error: Optional method to handle any exceptions raised in message handling :param auto_ack: Boolean to enable ack of messages upon receipt @@ -503,10 +517,8 @@ def register_subscriber(self, name: str, vhost: str, (if < 0 - will restart infinitely times) """ # for fanout exchange queue does not matter unless its non-conflicting - # and is binded + # and is bounded subscriber_queue = f'subscriber_{exchange}_{uuid.uuid4().hex[:6]}' - LOG.info(f'Subscriber queue registered: {subscriber_queue} ' - f'[subscriber_name={name},exchange={exchange},vhost={vhost}]') return self.register_consumer(name=name, vhost=vhost, queue=subscriber_queue, callback=callback, queue_reset=False, @@ -521,7 +533,7 @@ def register_subscriber(self, name: str, vhost: str, def default_error_handler(thread: ConsumerThreadInstance, exception: Exception): LOG.error(f"{exception} occurred in {thread}") - def run_consumers(self, names: tuple = (), daemon=True): + def run_consumers(self, names: Optional[tuple] = None, daemon=True): """ Runs consumer threads based on the name if present (starts all of the declared consumers by default) @@ -529,26 +541,28 @@ def run_consumers(self, names: tuple = (), daemon=True): :param names: names of consumers to consider :param daemon: to kill consumer threads once main thread is over """ - if not names or len(names) == 0: + if not names: names = list(self.consumers) for name in names: - if isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS) and self.consumers[name].is_consumer_alive: + if (isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS) + and self.consumers[name].is_consumer_alive + and not self.consumers[name].is_consuming): self.consumers[name].daemon = daemon self.consumers[name].start() self.consumer_properties[name]['started'] = True - def stop_consumers(self, names: tuple = (), allow_restart: bool = True): + def stop_consumers(self, names: Optional[tuple] = None): """ Stops consumer threads based on the name if present (stops all of the declared consumers by default) """ - if not names or len(names) == 0: + if not names: names = list(self.consumers) for name in names: try: if isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS) and self.consumers[name].is_alive(): - self.consumers[name].join(timeout=self.__consumer_join_timeout__, allow_restart=allow_restart) - self.consumer_properties[name]['is_alive'] = self.consumers[name].is_consumer_alive + self.consumers[name].join(timeout=self.__consumer_join_timeout__) + time.sleep(self.__consumer_join_timeout__) self.consumer_properties[name]['started'] = False except Exception as e: raise ChildProcessError(e) @@ -628,10 +642,10 @@ def observe_consumers(self): # LOG.debug('Observers state observation') consumers_dict = copy.copy(self.consumers) for consumer_name, consumer_instance in consumers_dict.items(): - if self.consumer_properties[consumer_name]['started'] and \ + if (self.consumer_properties[consumer_name]['started'] and not (isinstance(consumer_instance, SUPPORTED_THREADED_CONSUMERS) and consumer_instance.is_alive() - and consumer_instance.is_consuming): + and consumer_instance.is_consumer_alive)): LOG.info(f'Consumer "{consumer_name}" is dead, restarting') self.restart_consumer(name=consumer_name) @@ -653,7 +667,7 @@ def stop_observer_thread(self): def stop(self): """Generic method for graceful instance stopping""" - self.stop_consumers(allow_restart=False) + self.stop_consumers() self.stop_sync_thread() self.stop_observer_thread() diff --git a/neon_mq_connector/consumers/blocking_consumer.py b/neon_mq_connector/consumers/blocking_consumer.py index b6e2d44..f6b6a75 100644 --- a/neon_mq_connector/consumers/blocking_consumer.py +++ b/neon_mq_connector/consumers/blocking_consumer.py @@ -114,26 +114,27 @@ def run(self): self._is_consuming = True self.channel.start_consuming() except Exception as e: - self._is_consuming = False + self._close_connection() if isinstance(e, pika.exceptions.ChannelClosed): LOG.info(f"Channel closed by broker: {self.callback_func}") elif isinstance(e, pika.exceptions.StreamLostError): LOG.info("Connection closed by broker") else: self.error_func(self, e) - self.join(allow_restart=True) - def join(self, timeout: Optional[float] = ..., allow_restart: bool = True) -> None: + def join(self, timeout: Optional[float] = None) -> None: """Terminating consumer channel""" if self._is_consumer_alive: - try: - self.channel.stop_consuming() - if self.connection.is_open: - self.connection.close() - except Exception as e: - LOG.error(e) - finally: - self._is_consuming = False - if not allow_restart: - self._is_consumer_alive = False - super(BlockingConsumerThread, self).join(timeout=timeout) + self._close_connection() + super(BlockingConsumerThread, self).join(timeout=timeout) + + def _close_connection(self): + try: + if self.connection and self.connection.is_open: + self.connection.close() + except pika.exceptions.StreamLostError: + pass + except Exception as e: + LOG.exception(f"Failed to close connection due to unexpected exception: {e}") + self._is_consuming = False + self._is_consumer_alive = False diff --git a/neon_mq_connector/consumers/select_consumer.py b/neon_mq_connector/consumers/select_consumer.py index d31b9ed..2109ef0 100644 --- a/neon_mq_connector/consumers/select_consumer.py +++ b/neon_mq_connector/consumers/select_consumer.py @@ -108,7 +108,7 @@ def on_connection_fail(self, _): self.connection_failed_attempts += 1 if self.connection_failed_attempts > self.max_connection_failed_attempts: LOG.error(f'Failed establish MQ connection after {self.max_connection_failed_attempts} attempts') - self.join(timeout=1) + self._close_connection() else: self.reconnect() @@ -173,7 +173,7 @@ def on_message(self, channel, method, properties, body): def on_close(self, _, e): LOG.error(f"Closing MQ connection due to exception: {e}") - self.join() + self.reconnect() @property def is_consumer_alive(self) -> bool: @@ -192,9 +192,9 @@ def run(self): self.connection.ioloop.start() except Exception as e: LOG.error(f"Failed to start io loop on consumer thread {self.name!r}: {e}") - self.join(allow_restart=True) + self._close_connection() - def _close_connection(self): + def _close_connection(self, mark_consumer_as_dead: bool = True): try: if self.connection and not (self.connection.is_closed or self.connection.is_closing): self.connection.ioloop.stop() @@ -202,17 +202,16 @@ def _close_connection(self): except Exception as e: LOG.error(f"Failed to close connection for Consumer {self.name!r}: {e}") self._is_consuming = False + if mark_consumer_as_dead: + self._is_consumer_alive = False def reconnect(self, wait_interval: int = 1): - self._close_connection() + self._close_connection(mark_consumer_as_dead=False) time.sleep(wait_interval) self.run() - def join(self, timeout: Optional[float] = None, allow_restart: bool = True) -> None: + def join(self, timeout: Optional[float] = None) -> None: """Terminating consumer channel""" if self.is_consumer_alive and self.is_consuming: - self._is_consuming = False - self._close_connection() - if not allow_restart: - self._is_consumer_alive = False + self._close_connection(mark_consumer_as_dead=True) super().join(timeout=timeout) diff --git a/tests/test_connector.py b/tests/test_connector.py index be7186d..a89e6ea 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -130,7 +130,7 @@ def test_not_null_service_id(self): self.assertIsNotNone(self.connector_instance.service_id) def tearDown(self): - self.connector_instance.stop_consumers(allow_restart=False) + self.connector_instance.stop_consumers() @pytest.mark.timeout(30) def test_mq_messaging(self):