diff --git a/tests/processor/test_base_processor.py b/tests/processor/test_base_processor.py index dd889c4..7294e3a 100644 --- a/tests/processor/test_base_processor.py +++ b/tests/processor/test_base_processor.py @@ -18,7 +18,7 @@ def test_initialiseBaseProcessor(): mock_task = mock.Mock() mock_task.application_id = 'test_id' mock_task_id = TaskId('test_group', 0) - mock_context = ProcessorContext(mock_task_id, mock_task, None, None, {}) + mock_context = ProcessorContext(mock_task_id, mock_task, None, {}, None) bp = wks_processor.BaseProcessor() bp.initialise('my-name', mock_context) diff --git a/tests/processor/test_sink_processor.py b/tests/processor/test_sink_processor.py index 3e50335..b21a8c7 100644 --- a/tests/processor/test_sink_processor.py +++ b/tests/processor/test_sink_processor.py @@ -26,7 +26,7 @@ def test_sinkProcessorProcess(): mock_task = mock.Mock() mock_task.application_id = 'test_id' mock_task_id = TaskId('test_group', 0) - processor_context = wks_processor.ProcessorContext(mock_task_id, mock_task, None, None, {}) + processor_context = wks_processor.ProcessorContext(mock_task_id, mock_task, None, {}, None) processor_context.record_collector = mock.MagicMock() sink = wks_processor.SinkProcessor('topic1') diff --git a/tests/state/test_change_logging.py b/tests/state/test_change_logging.py new file mode 100644 index 0000000..a20581e --- /dev/null +++ b/tests/state/test_change_logging.py @@ -0,0 +1,100 @@ +from collections import deque +from typing import Iterator, Tuple + +import pytest + +from winton_kafka_streams.processor.serialization.serdes import IntegerSerde, StringSerde +from winton_kafka_streams.state.in_memory.in_memory_state_store import InMemoryStateStore +from winton_kafka_streams.state.logging.change_logging_state_store import ChangeLoggingStateStore +from winton_kafka_streams.state.logging.store_change_logger import StoreChangeLogger + + +class MockChangeLogger(StoreChangeLogger): + def __init__(self): + super(MockChangeLogger, 
self).__init__() + self.change_log = deque() + + def log_change(self, key: bytes, value: bytes) -> None: + self.change_log.append((key, value)) + + def __iter__(self) -> Iterator[Tuple[bytes, bytes]]: + return self.change_log.__iter__() + + +def _get_store(): + inner_store = InMemoryStateStore('teststore', StringSerde(), IntegerSerde(), False) + store = ChangeLoggingStateStore('teststore', StringSerde(), IntegerSerde(), False, inner_store) + store._get_change_logger = lambda context: MockChangeLogger() + store.initialize(None, None) + return store + + +def test_change_store_is_dict(): + store = _get_store() + kv_store = store.get_key_value_store() + + kv_store['a'] = 1 + assert kv_store['a'] == 1 + + kv_store['a'] = 2 + assert kv_store['a'] == 2 + + del kv_store['a'] + assert kv_store.get('a') is None + with pytest.raises(KeyError): + _ = kv_store['a'] + + +def test_change_log_is_written_to(): + store = _get_store() + kv_store = store.get_key_value_store() + + kv_store['a'] = 12 + assert len(store.change_logger.change_log) == 1 + assert store.change_logger.change_log[0] == (b'a', b'\x0c\0\0\0') + + del kv_store['a'] + assert len(store.change_logger.change_log) == 2 + assert store.change_logger.change_log[1] == (b'a', b'') + + +def test_can_replay_log(): + store = _get_store() + kv_store = store.get_key_value_store() + + kv_store['a'] = 12 + kv_store['b'] = 123 + del kv_store['a'] + + keys = [] + values = [] + + for k, v in store.change_logger: + keys.append(k) + values.append(v) + + assert keys == [b'a', b'b', b'a'] + assert values == [b'\x0c\0\0\0', b'\x7b\0\0\0', b''] + + +def test_rebuild_state_from_log(): + store = _get_store() + kv_store = store.get_key_value_store() + + kv_store['a'] = 12 + kv_store['b'] = 123 + del kv_store['a'] + kv_store['c'] = 1234 + + log = store.change_logger + + # reattach previous changelog and run initialize() + store = _get_store() + kv_store = store.get_key_value_store() + store._get_change_logger = lambda context: log + 
store.initialize(None, None) + + with pytest.raises(KeyError): + _ = kv_store['a'] + assert kv_store['b'] == 123 + assert kv_store['c'] == 1234 diff --git a/tests/state/test_in_memory_key_value_store.py b/tests/state/test_in_memory_key_value_store.py index d34af0e..9f31430 100644 --- a/tests/state/test_in_memory_key_value_store.py +++ b/tests/state/test_in_memory_key_value_store.py @@ -4,7 +4,7 @@ from winton_kafka_streams.state.in_memory.in_memory_state_store import InMemoryStateStore -def test_inMemoryKeyValueStore(): +def test_in_memory_key_value_store(): store = InMemoryStateStore('teststore', BytesSerde(), BytesSerde(), False) kv_store = store.get_key_value_store() diff --git a/winton_kafka_streams/processor/_context.py b/winton_kafka_streams/processor/_context.py index bc3aa0a..8a77122 100644 --- a/winton_kafka_streams/processor/_context.py +++ b/winton_kafka_streams/processor/_context.py @@ -28,10 +28,9 @@ class Context: """ - def __init__(self, _state_record_collector, _state_stores): + def __init__(self, _state_stores): self.current_node = None self.current_record = None - self.state_record_collector = _state_record_collector self._state_stores = _state_stores def send(self, topic, key, obj): diff --git a/winton_kafka_streams/processor/_stream_task.py b/winton_kafka_streams/processor/_stream_task.py index 2e1f102..8605710 100644 --- a/winton_kafka_streams/processor/_stream_task.py +++ b/winton_kafka_streams/processor/_stream_task.py @@ -4,7 +4,6 @@ from confluent_kafka import TopicPartition from confluent_kafka.cimpl import KafkaException, KafkaError -from winton_kafka_streams.processor.serialization.serdes import BytesSerde from ..errors._kafka_error_codes import _get_invalid_producer_epoch_code from ._punctuation_queue import PunctuationQueue from ._record_collector import RecordCollector @@ -69,11 +68,9 @@ def __init__(self, _task_id, _application_id, _partitions, _topology_builder, _c self.value_serde.configure(self.config, False) self.record_collector 
= RecordCollector(self.producer, self.key_serde, self.value_serde) - self.state_record_collector = RecordCollector(self.producer, BytesSerde(), BytesSerde()) self.queue = queue.Queue() - self.context = ProcessorContext(self.task_id, self, self.record_collector, - self.state_record_collector, self.state_stores) + self.context = ProcessorContext(self.task_id, self, self.record_collector, self.state_stores, self.config) self.punctuation_queue = PunctuationQueue(self.punctuate) # TODO: use the configured timestamp extractor. diff --git a/winton_kafka_streams/processor/processor_context.py b/winton_kafka_streams/processor/processor_context.py index a537289..9eaf561 100644 --- a/winton_kafka_streams/processor/processor_context.py +++ b/winton_kafka_streams/processor/processor_context.py @@ -16,14 +16,15 @@ class ProcessorContext(_context.Context): values to downstream processors. """ - def __init__(self, _task_id, _task, _record_collector, _state_record_collector, _state_stores): + def __init__(self, _task_id, _task, _record_collector, _state_stores, _config): - super().__init__(_state_record_collector, _state_stores) + super().__init__(_state_stores) self.application_id = _task.application_id self.task_id = _task_id self.task = _task self.record_collector = _record_collector + self.config = _config def commit(self): """ diff --git a/winton_kafka_streams/state/logging/change_logging_state_store.py b/winton_kafka_streams/state/logging/change_logging_state_store.py index f422094..f37024e 100644 --- a/winton_kafka_streams/state/logging/change_logging_state_store.py +++ b/winton_kafka_streams/state/logging/change_logging_state_store.py @@ -3,7 +3,7 @@ from winton_kafka_streams.processor.serialization import Serde from ..key_value_state_store import KeyValueStateStore from ..state_store import StateStore -from .store_change_logger import StoreChangeLogger +from .store_change_logger import StoreChangeLogger, StoreChangeLoggerImpl KT = TypeVar('KT') # Key type. 
VT = TypeVar('VT') # Value type. @@ -14,12 +14,21 @@ def __init__(self, name: str, key_serde: Serde[KT], value_serde: Serde[VT], log inner_state_store: StateStore[KT, VT]) -> None: super().__init__(name, key_serde, value_serde, logging_enabled) self.inner_state_store = inner_state_store - self.change_logger = None + self.change_logger: StoreChangeLogger = None + + def _get_change_logger(self, context) -> StoreChangeLogger: + return StoreChangeLoggerImpl(self.inner_state_store.name, context) def initialize(self, context, root): self.inner_state_store.initialize(context, root) - self.change_logger = StoreChangeLogger(self.inner_state_store.name, context) - # TODO rebuild state into inner here + self.change_logger = self._get_change_logger(context) + for k, v in self.change_logger: + deserialized_key = self.deserialize_key(k) + inner_kv_store = self.inner_state_store.get_key_value_store() + if v == b'': + del inner_kv_store[deserialized_key] + else: + inner_kv_store[deserialized_key] = self.deserialize_value(v) def get_key_value_store(self) -> KeyValueStateStore[KT, VT]: parent = self diff --git a/winton_kafka_streams/state/logging/store_change_logger.py b/winton_kafka_streams/state/logging/store_change_logger.py index 29b5730..fafbbd0 100644 --- a/winton_kafka_streams/state/logging/store_change_logger.py +++ b/winton_kafka_streams/state/logging/store_change_logger.py @@ -1,10 +1,48 @@ -class StoreChangeLogger: +from abc import abstractmethod +from typing import Iterator, Iterable, Tuple + +from confluent_kafka.cimpl import TopicPartition, OFFSET_BEGINNING, KafkaError + +from winton_kafka_streams.processor.serialization.serdes import BytesSerde +from winton_kafka_streams.kafka_client_supplier import KafkaClientSupplier +from winton_kafka_streams.processor._record_collector import RecordCollector + + +class StoreChangeLogger(Iterable[Tuple[bytes, bytes]]): + @abstractmethod + def log_change(self, key: bytes, value: bytes) -> None: + pass + + @abstractmethod + def 
__iter__(self) -> Iterator[Tuple[bytes, bytes]]: + pass + + +class StoreChangeLoggerImpl(StoreChangeLogger): def __init__(self, store_name, context) -> None: self.topic = f'{context.application_id}-{store_name}-changelog' self.context = context self.partition = context.task_id.partition - self.record_collector = context.state_record_collector + self.client_supplier = KafkaClientSupplier(self.context.config) + self.record_collector = RecordCollector(self.client_supplier.producer(), BytesSerde(), BytesSerde()) def log_change(self, key: bytes, value: bytes) -> None: if self.record_collector: self.record_collector.send(self.topic, key, value, self.context.timestamp, partition=self.partition) + + def __iter__(self) -> Iterator[Tuple[bytes, bytes]]: + consumer = self.client_supplier.consumer() + partition = TopicPartition(self.topic, self.partition, OFFSET_BEGINNING) + consumer.assign([partition]) + + class TopicIterator(Iterator[Tuple[bytes, bytes]]): + def __next__(self) -> Tuple[bytes, bytes]: + msg = consumer.poll(1.0) + if msg is None: + raise StopIteration() + if msg.error(): + if msg.error().code() == KafkaError._PARTITION_EOF: + raise StopIteration() + return msg.key(), msg.value() + + return TopicIterator()