Skip to content

Commit

Permalink
Eliminated "self joining" problem in consumer threads and added grace…
Browse files Browse the repository at this point in the history
…ful restarts handling
  • Loading branch information
kirgrim committed Dec 5, 2024
1 parent a39737e commit 2243dc6
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 58 deletions.
80 changes: 47 additions & 33 deletions neon_mq_connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,15 +415,15 @@ def register_consumer(self, name: str, vhost: str, queue: str,
Registers a consumer for the specified queue.
The callback function will handle items in the queue.
Any raised exceptions will be passed as arguments to on_error.
:param name: Human readable name of the consumer
:param name: Human-readable name of the consumer
:param vhost: vhost to register on
:param queue: MQ Queue to read messages from
:param queue_reset: to delete queue if exists (defaults to False)
:param exchange: MQ Exchange to bind to
:param exchange_reset: to delete exchange if exists (defaults to False)
:param exchange_type: Type of MQ Exchange to use, documentation:
https://www.rabbitmq.com/tutorials/amqp-concepts.html
:param callback: Method to passed queued messages to
:param callback: Callback method on received messages
:param on_error: Optional method to handle any exceptions
raised in message handling
:param auto_ack: Boolean to enable ack of messages upon receipt
Expand All @@ -439,17 +439,32 @@ def register_consumer(self, name: str, vhost: str, queue: str,
if skip_on_existing:
LOG.info(f'Consumer under index "{name}" already declared')
return
self.stop_consumers(names=(name,), allow_restart=False)
self.stop_consumers(names=(name,))
self.consumer_properties.setdefault(name, {})
self.consumer_properties[name]['properties'] = \
dict(connection_params=self.get_connection_params(vhost),
queue=queue, queue_reset=queue_reset, callback_func=callback,
exchange=exchange, exchange_reset=exchange_reset,
exchange_type=exchange_type, error_func=error_handler,
auto_ack=auto_ack, name=name, queue_exclusive=queue_exclusive, )
self.consumer_properties[name]['restart_attempts'] = \
int(restart_attempts)
dict(
name=name,
connection_params=self.get_connection_params(vhost),
queue=queue,
queue_reset=queue_reset,
callback_func=callback,
exchange=exchange,
exchange_reset=exchange_reset,
exchange_type=exchange_type,
error_func=error_handler,
auto_ack=auto_ack,
queue_exclusive=queue_exclusive,
)
self.consumer_properties[name]['restart_attempts'] = int(restart_attempts)
self.consumer_properties[name]['started'] = False

if exchange_type == ExchangeType.fanout.value:
LOG.info(f'Subscriber exchange listener registered: '
f'[name={name},exchange={exchange},vhost={vhost}]')
else:
LOG.info(f'Consumer queue listener registered: '
f'[name={name},queue={queue},vhost={vhost}]')

self.consumers[name] = self.consumer_thread_cls(**self.consumer_properties[name]['properties'])

@property
Expand All @@ -459,17 +474,16 @@ def consumer_thread_cls(self) -> Type[ConsumerThreadInstance]:
return BlockingConsumerThread

def restart_consumer(self, name: str):
self.stop_consumers(names=(name,), allow_restart=True)
self.stop_consumers(names=(name,))
consumer_data = self.consumer_properties.get(name, {})
restart_attempts = consumer_data.get('restart_attempts',
self.__max_consumer_restarts__)
err_msg = ''
if not consumer_data.get('is_alive', True):
LOG.debug(f'Skipping joined consumer = "{name}"')
elif not consumer_data.get('properties'):
if not consumer_data.get('properties'):
err_msg = 'creation properties not found'
elif 0 < restart_attempts < consumer_data.get('num_restarted', 0):
err_msg = 'num restarts exceeded'
self.consumers.pop(name, None)
else:
self.consumers[name] = self.consumer_thread_cls(**consumer_data['properties'])
self.run_consumers(names=(name,))
Expand All @@ -481,19 +495,19 @@ def restart_consumer(self, name: str):
def register_subscriber(self, name: str, vhost: str,
callback: callable,
on_error: Optional[callable] = None,
exchange: str = None, exchange_reset: bool = False,
exchange: str = None,
exchange_reset: bool = False,
auto_ack: bool = True,
skip_on_existing: bool = False,
restart_attempts: int = __max_consumer_restarts__):
"""
Registers fanout exchange subscriber, wraps register_consumer()
Any raised exceptions will be passed as arguments to on_error.
:param name: Human readable name of the consumer
:param name: Human-readable name of the consumer
:param vhost: vhost to register on
:param exchange: MQ Exchange to bind to
:param exchange_reset: to delete exchange if exists
(defaults to False)
:param callback: Method to passed queued messages to
:param exchange: MQ Exchange for binding to
:param exchange_reset: delete exchange if exists (defaults to False)
:param callback: Callback method on received messages
:param on_error: Optional method to handle any exceptions raised
in message handling
:param auto_ack: Boolean to enable ack of messages upon receipt
Expand All @@ -503,10 +517,8 @@ def register_subscriber(self, name: str, vhost: str,
(if < 0, the consumer is restarted an unlimited number of times)
"""
# for a fanout exchange the queue name does not matter, as long as it is non-conflicting
# and is binded
# and is bound to the exchange
subscriber_queue = f'subscriber_{exchange}_{uuid.uuid4().hex[:6]}'
LOG.info(f'Subscriber queue registered: {subscriber_queue} '
f'[subscriber_name={name},exchange={exchange},vhost={vhost}]')
return self.register_consumer(name=name, vhost=vhost,
queue=subscriber_queue,
callback=callback, queue_reset=False,
Expand All @@ -521,34 +533,36 @@ def register_subscriber(self, name: str, vhost: str,
def default_error_handler(thread: ConsumerThreadInstance, exception: Exception):
LOG.error(f"{exception} occurred in {thread}")

def run_consumers(self, names: tuple = (), daemon=True):
def run_consumers(self, names: Optional[tuple] = None, daemon=True):
"""
Runs consumer threads based on the name if present
(starts all of the declared consumers by default)
:param names: names of consumers to consider
:param daemon: to kill consumer threads once main thread is over
"""
if not names or len(names) == 0:
if not names:
names = list(self.consumers)
for name in names:
if isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS) and self.consumers[name].is_consumer_alive:
if (isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS)
and self.consumers[name].is_consumer_alive
and not self.consumers[name].is_consuming):
self.consumers[name].daemon = daemon
self.consumers[name].start()
self.consumer_properties[name]['started'] = True

def stop_consumers(self, names: tuple = (), allow_restart: bool = True):
def stop_consumers(self, names: Optional[tuple] = None):
"""
Stops consumer threads based on the name if present
(stops all of the declared consumers by default)
"""
if not names or len(names) == 0:
if not names:
names = list(self.consumers)
for name in names:
try:
if isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS) and self.consumers[name].is_alive():
self.consumers[name].join(timeout=self.__consumer_join_timeout__, allow_restart=allow_restart)
self.consumer_properties[name]['is_alive'] = self.consumers[name].is_consumer_alive
self.consumers[name].join(timeout=self.__consumer_join_timeout__)
time.sleep(self.__consumer_join_timeout__)
self.consumer_properties[name]['started'] = False
except Exception as e:
raise ChildProcessError(e)
Expand Down Expand Up @@ -628,10 +642,10 @@ def observe_consumers(self):
# LOG.debug('Observers state observation')
consumers_dict = copy.copy(self.consumers)
for consumer_name, consumer_instance in consumers_dict.items():
if self.consumer_properties[consumer_name]['started'] and \
if (self.consumer_properties[consumer_name]['started'] and
not (isinstance(consumer_instance, SUPPORTED_THREADED_CONSUMERS)
and consumer_instance.is_alive()
and consumer_instance.is_consuming):
and consumer_instance.is_consumer_alive)):
LOG.info(f'Consumer "{consumer_name}" is dead, restarting')
self.restart_consumer(name=consumer_name)

Expand All @@ -653,7 +667,7 @@ def stop_observer_thread(self):

def stop(self):
"""Generic method for graceful instance stopping"""
self.stop_consumers(allow_restart=False)
self.stop_consumers()
self.stop_sync_thread()
self.stop_observer_thread()

Expand Down
29 changes: 15 additions & 14 deletions neon_mq_connector/consumers/blocking_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,27 @@ def run(self):
self._is_consuming = True
self.channel.start_consuming()
except Exception as e:
self._is_consuming = False
self._close_connection()
if isinstance(e, pika.exceptions.ChannelClosed):
LOG.info(f"Channel closed by broker: {self.callback_func}")
elif isinstance(e, pika.exceptions.StreamLostError):
LOG.info("Connection closed by broker")
else:
self.error_func(self, e)
self.join(allow_restart=True)

def join(self, timeout: Optional[float] = ..., allow_restart: bool = True) -> None:
def join(self, timeout: Optional[float] = None) -> None:
"""Terminating consumer channel"""
if self._is_consumer_alive:
try:
self.channel.stop_consuming()
if self.connection.is_open:
self.connection.close()
except Exception as e:
LOG.error(e)
finally:
self._is_consuming = False
if not allow_restart:
self._is_consumer_alive = False
super(BlockingConsumerThread, self).join(timeout=timeout)
self._close_connection()
super(BlockingConsumerThread, self).join(timeout=timeout)

# Close this consumer's MQ connection if one exists and is open, then clear
# the consuming/alive flags.
# NOTE(review): indentation is lost in this rendered diff, so the nesting of
# the two flag assignments at the end cannot be confirmed from this view;
# presumably they sit at function level and run on every call - verify
# against the real file.
def _close_connection(self):
try:
# Only attempt a close when a connection object exists and reports open.
if self.connection and self.connection.is_open:
self.connection.close()
except pika.exceptions.StreamLostError:
# Deliberately ignored; presumably the stream was already dropped by the
# broker, leaving nothing to close - TODO confirm.
pass
except Exception as e:
# Any other failure is logged with a traceback and not re-raised.
LOG.exception(f"Failed to close connection due to unexpected exception: {e}")
# Mark the consumer as neither consuming nor alive.
self._is_consuming = False
self._is_consumer_alive = False
19 changes: 9 additions & 10 deletions neon_mq_connector/consumers/select_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def on_connection_fail(self, _):
self.connection_failed_attempts += 1
if self.connection_failed_attempts > self.max_connection_failed_attempts:
LOG.error(f'Failed establish MQ connection after {self.max_connection_failed_attempts} attempts')
self.join(timeout=1)
self._close_connection()
else:
self.reconnect()

Expand Down Expand Up @@ -173,7 +173,7 @@ def on_message(self, channel, method, properties, body):

def on_close(self, _, e):
LOG.error(f"Closing MQ connection due to exception: {e}")
self.join()
self.reconnect()

@property
def is_consumer_alive(self) -> bool:
Expand All @@ -192,27 +192,26 @@ def run(self):
self.connection.ioloop.start()
except Exception as e:
LOG.error(f"Failed to start io loop on consumer thread {self.name!r}: {e}")
self.join(allow_restart=True)
self._close_connection()

def _close_connection(self):
def _close_connection(self, mark_consumer_as_dead: bool = True):
try:
if self.connection and not (self.connection.is_closed or self.connection.is_closing):
self.connection.ioloop.stop()
self.connection.close()
except Exception as e:
LOG.error(f"Failed to close connection for Consumer {self.name!r}: {e}")
self._is_consuming = False
if mark_consumer_as_dead:
self._is_consumer_alive = False

def reconnect(self, wait_interval: int = 1):
self._close_connection()
self._close_connection(mark_consumer_as_dead=False)
time.sleep(wait_interval)
self.run()

def join(self, timeout: Optional[float] = None, allow_restart: bool = True) -> None:
def join(self, timeout: Optional[float] = None) -> None:
"""Terminating consumer channel"""
if self.is_consumer_alive and self.is_consuming:
self._is_consuming = False
self._close_connection()
if not allow_restart:
self._is_consumer_alive = False
self._close_connection(mark_consumer_as_dead=True)
super().join(timeout=timeout)
2 changes: 1 addition & 1 deletion tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_not_null_service_id(self):
self.assertIsNotNone(self.connector_instance.service_id)

def tearDown(self):
self.connector_instance.stop_consumers(allow_restart=False)
self.connector_instance.stop_consumers()

@pytest.mark.timeout(30)
def test_mq_messaging(self):
Expand Down

0 comments on commit 2243dc6

Please sign in to comment.