diff --git a/quixstreams/state/recovery.py b/quixstreams/state/recovery.py
index b79c30188..bf26ddf7c 100644
--- a/quixstreams/state/recovery.py
+++ b/quixstreams/state/recovery.py
@@ -62,6 +62,8 @@ def __init__(
         self._committed_offsets = committed_offsets
         self._recovery_consume_position: Optional[int] = None
         self._initial_offset: Optional[int] = None
+        self._invalid_offset_count = 0  # Track consecutive invalid offset attempts
+        self._last_valid_position_time: Optional[float] = None
 
     def __repr__(self):
         return (
@@ -128,6 +130,35 @@ def recovery_consume_position(self) -> Optional[int]:
     def had_recovery_changes(self) -> bool:
         return self._initial_offset != self.offset
 
+    def increment_invalid_offset_count(self) -> int:
+        """
+        Increment the counter for consecutive invalid offset attempts.
+        Returns the new count.
+        """
+        self._invalid_offset_count += 1
+        return self._invalid_offset_count
+
+    def reset_invalid_offset_count(self):
+        """
+        Reset the invalid offset counter when a valid position is obtained.
+        """
+        self._invalid_offset_count = 0
+        self._last_valid_position_time = time.monotonic()
+
+    @property
+    def invalid_offset_count(self) -> int:
+        """
+        Get the number of consecutive invalid offset attempts.
+        """
+        return self._invalid_offset_count
+
+    @property
+    def last_valid_position_time(self) -> Optional[float]:
+        """
+        Get the `time.monotonic()` reading of the last valid position, if any.
+        """
+        return self._last_valid_position_time
+
     def recover_from_changelog_message(
         self, changelog_message: SuccessfulConfluentKafkaMessageProto
     ):
@@ -315,12 +346,16 @@ class RecoveryManager:
     Recovery is attempted from the `Application` after any new partition assignment.
     """
 
+    # Maximum number of consecutive invalid offset attempts before failing loudly.
+    # Counted once per recovery status update that finds no valid position.
+    MAX_INVALID_OFFSET_ATTEMPTS = 60
+
     def __init__(self, consumer: BaseConsumer, topic_manager: TopicManager):
         self._running = False
         self._consumer = consumer
         self._topic_manager = topic_manager
         self._recovery_partitions: Dict[int, Dict[str, RecoveryPartition]] = {}
         self._last_progress_logged_time = time.monotonic()
 
     @property
     def partitions(self) -> Dict[int, Dict[str, RecoveryPartition]]:
@@ -536,6 +571,14 @@ def _update_recovery_status(self):
         rp_revokes = []
         for rp in dict_values(self._recovery_partitions):
             position = self._get_changelog_offset(rp)
+            if position is None:
+                # Skip status update if position is not yet valid (e.g., during rebalancing)
+                # Will retry on next poll cycle
+                logger.debug(
+                    f"Skipping recovery status update for {rp}: position not available"
+                )
+                continue
+
             rp.set_recovery_consume_position(position)
             if rp.finished_recovery_check:
                 rp_revokes.append(rp)
@@ -561,25 +604,118 @@ def _recovery_loop(self) -> None:
                 rp = self._recovery_partitions[msg.partition()][msg.topic()]
                 rp.recover_from_changelog_message(changelog_message=msg)
 
+    def _get_position(self, rp: RecoveryPartition) -> ConfluentPartition:
+        """
+        Query the consumer for the current position of a changelog partition.
+
+        The position is queried on every call instead of being cached: it
+        advances as messages are consumed, so a time-based cache can hand back
+        a stale offset (or a stale invalid position) on the next poll cycle.
+
+        :param rp: RecoveryPartition to get position for
+        :return: ConfluentPartition with offset and error information
+        """
+        return self._consumer.position(
+            [ConfluentPartition(rp.changelog_name, rp.partition_num)]
+        )[0]
+
     def _log_recovery_progress(self) -> None:
         """
         Periodically log the recovery progress of all RecoveryPartitions.
         """
         if self._last_progress_logged_time < time.monotonic() - 10:
             for rp in dict_values(self._recovery_partitions):
-                last_consumed_offset = self._get_changelog_offset(rp) - 1
-                logger.info(
-                    f"Recovery progress for {rp}: {last_consumed_offset} / {rp.changelog_highwater}"
-                )
+                # Only report here; the invalid-offset counter is updated by
+                # `_get_changelog_offset()` during recovery status updates.
+                position_tp = self._get_position(rp)
+
+                if position_tp.error:
+                    count = rp.invalid_offset_count
+                    log_level = logger.warning if count > 30 else logger.info
+                    log_level(
+                        f"Recovery progress for {rp}: position unavailable "
+                        f"(error: {position_tp.error}, attempts: {count})"
+                    )
+                elif position_tp.offset < 0:
+                    count = rp.invalid_offset_count
+                    log_level = logger.warning if count > 30 else logger.info
+                    log_level(
+                        f"Recovery progress for {rp}: position not yet established "
+                        f"(offset: {position_tp.offset}, attempts: {count})"
+                    )
+                else:
+                    last_consumed_offset = position_tp.offset - 1
+                    logger.info(
+                        f"Recovery progress for {rp}: {last_consumed_offset} / {rp.changelog_highwater}"
+                    )
             self._last_progress_logged_time = time.monotonic()
 
-    def _get_changelog_offset(self, rp: RecoveryPartition) -> int:
+    def _get_changelog_offset(self, rp: RecoveryPartition) -> Optional[int]:
         """
         Get the current offset of the changelog partition.
+
+        Returns None if the position is not yet established (e.g., during rebalancing)
+        or if there's an error querying the position.
+
+        Tracks consecutive invalid offset attempts and raises an exception if the
+        threshold is exceeded.
+
+        :return: The current offset, or None if position is invalid/unavailable
+        :raises RuntimeError: If position remains invalid beyond MAX_INVALID_OFFSET_ATTEMPTS
         """
-        return self._consumer.position(
-            [ConfluentPartition(rp.changelog_name, rp.partition_num)]
-        )[0].offset
+        position_tp = self._get_position(rp)
+
+        # Check for Kafka errors (e.g., during rebalancing)
+        if position_tp.error:
+            count = rp.increment_invalid_offset_count()
+            logger.debug(
+                f"Cannot get position for {rp} due to Kafka error: {position_tp.error}. "
+                f"This is expected during rebalancing (attempt {count}/{self.MAX_INVALID_OFFSET_ATTEMPTS})."
+            )
+            self._check_invalid_offset_threshold(rp, f"error: {position_tp.error}")
+            return None
+
+        # Check for special Kafka offset values (OFFSET_INVALID=-1001, OFFSET_STORED=-1000, etc.)
+        offset = position_tp.offset
+        if offset < 0:
+            count = rp.increment_invalid_offset_count()
+            logger.debug(
+                f"Position not yet established for {rp}: offset={offset}. "
+                f"This is expected during rebalancing (attempt {count}/{self.MAX_INVALID_OFFSET_ATTEMPTS})."
+            )
+            self._check_invalid_offset_threshold(rp, f"offset={offset}")
+            return None
+
+        # Valid offset obtained - reset the counter
+        rp.reset_invalid_offset_count()
+        return offset
+
+    def _check_invalid_offset_threshold(self, rp: RecoveryPartition, reason: str):
+        """
+        Check if the invalid offset count exceeds the threshold and fail loudly if so.
+
+        :param rp: RecoveryPartition being checked
+        :param reason: Description of why the offset is invalid
+        :raises RuntimeError: If threshold is exceeded
+        """
+        if rp.invalid_offset_count <= self.MAX_INVALID_OFFSET_ATTEMPTS:
+            return
+
+        # `last_valid_position_time` is a `time.monotonic()` reading, which is
+        # only meaningful as a difference, so report the elapsed time instead.
+        last_valid_time = rp.last_valid_position_time
+        if last_valid_time is None:
+            last_valid = "never"
+        else:
+            last_valid = f"{time.monotonic() - last_valid_time:.1f}s ago"
+        error_msg = (
+            f"Recovery stuck for {rp}: position has been invalid for "
+            f"{rp.invalid_offset_count} consecutive attempts ({reason}). "
+            f"This indicates a serious issue with the Kafka consumer or broker. "
+            f"Last valid position was obtained: {last_valid}."
+        )
+        logger.error(error_msg)
+        raise RuntimeError(error_msg)
 
     def stop_recovery(self):
         self._running = False
diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
index 70a3ecead..257b801bb 100644
--- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
+++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
@@ -2,7 +2,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from confluent_kafka import TopicPartition
+from confluent_kafka import OFFSET_INVALID, TopicPartition
 from confluent_kafka import TopicPartition as ConfluentPartition
 
 from quixstreams.kafka import Consumer
@@ -350,6 +350,115 @@ def test_do_recovery_no_partitions_assigned(self, recovery_manager_factory):
         # Check that consumer.poll() is not called
         assert not consumer.poll.called
 
+    def test_do_recovery_handles_invalid_offset_during_rebalance(
+        self,
+        recovery_manager_factory,
+        topic_manager_factory,
+    ):
+        """
+        Test that RecoveryManager handles OFFSET_INVALID gracefully when
+        consumer.position() returns invalid offset during rebalancing.
+
+        This reproduces GitHub issue #1067 where recovery gets stuck in infinite
+        loop when partition stays assigned through rebalance but position becomes
+        temporarily invalid.
+        """
+        topic_name = str(uuid.uuid4())
+        store_name = "default"
+        lowwater, highwater = 0, 10
+
+        # Setup topics
+        topic_manager = topic_manager_factory()
+        data_topic = topic_manager.topic(topic_name)
+        changelog_topic = topic_manager.changelog_topic(
+            stream_id=topic_name,
+            store_name=store_name,
+            config=data_topic.broker_config,
+        )
+
+        data_tp = TopicPartition(topic=data_topic.name, partition=0)
+        changelog_tp = TopicPartition(topic=changelog_topic.name, partition=0)
+        assignment = [data_tp, changelog_tp]
+
+        # Create changelog message for recovery
+        # Message at offset (highwater - 1) means after consuming it,
+        # position will be at highwater, completing recovery
+        changelog_message = ConfluentKafkaMessageStub(
+            topic=changelog_topic.name,
+            partition=0,
+            offset=highwater - 1,
+            key=b"key",
+            value=b"value",
+            headers=[(CHANGELOG_CF_MESSAGE_HEADER, b"default")],
+        )
+
+        # Create mocked consumer
+        consumer = MagicMock(spec_set=Consumer)
+        consumer.assignment.return_value = assignment
+
+        # Simulate rebalancing scenario:
+        # 1. First poll returns None → OFFSET_INVALID detected and skipped
+        # 2. Second poll returns message → processes it
+        # 3. Third poll returns None → position check shows recovery complete
+        consumer.poll.side_effect = [None, changelog_message, None]
+
+        # Simulate consumer.position() behavior during rebalance:
+        # 1. First call returns OFFSET_INVALID (mid-rebalance - gets skipped)
+        # 2. Subsequent calls return highwater (recovery complete after message)
+        position_call_count = 0
+
+        def position_side_effect(partitions):
+            nonlocal position_call_count
+            position_call_count += 1
+            if position_call_count == 1:
+                # Mid-rebalance: return OFFSET_INVALID (will be skipped by fix)
+                return [ConfluentPartition(changelog_topic.name, 0, OFFSET_INVALID)]
+            else:
+                # After OFFSET_INVALID resolved, position is at highwater
+                return [ConfluentPartition(changelog_topic.name, 0, highwater)]
+
+        consumer.position.side_effect = position_side_effect
+
+        # Mock store partition
+        store_partition = MagicMock(spec_set=StorePartition)
+        # Stored offset is one before the message we'll recover
+        store_partition.get_changelog_offset.return_value = highwater - 2
+
+        # Setup recovery manager
+        recovery_manager = recovery_manager_factory(
+            consumer=consumer, topic_manager=topic_manager
+        )
+
+        consumer.get_watermark_offsets.return_value = (lowwater, highwater)
+        recovery_manager.assign_partition(
+            topic=topic_name,
+            partition=0,
+            committed_offsets={topic_name: -1001},
+            store_partitions={store_name: store_partition},
+        )
+
+        # Trigger recovery - should complete successfully despite OFFSET_INVALID
+        recovery_manager.do_recovery()
+
+        # Verify recovery completed successfully
+        assert (
+            not recovery_manager.partitions
+        ), "Recovery should complete and unassign all partitions"
+        assert consumer.poll.call_count == 3, (
+            "Should poll three times: "
+            "1) None (OFFSET_INVALID detected and skipped), "
+            "2) message consumed, "
+            "3) None (position==highwater, completes recovery)"
+        )
+        assert position_call_count == 2, (
+            "Should call position twice: "
+            "first returns OFFSET_INVALID (skipped), "
+            "second returns highwater (recovery complete)"
+        )
+
+        # Verify the changelog message was processed
+        store_partition.recover_from_changelog_message.assert_called_once()
+
 
 @pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
 class TestRecoveryManagerRecover: