diff --git a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py index 8bf01509..1950ec81 100644 --- a/airbyte_cdk/sources/declarative/concurrent_declarative_source.py +++ b/airbyte_cdk/sources/declarative/concurrent_declarative_source.py @@ -310,7 +310,9 @@ def _group_streams( declarative_stream=declarative_stream ) and hasattr(declarative_stream.retriever, "stream_slicer") - and isinstance(declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor) + and isinstance( + declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor + ) ): stream_state = state_manager.get_stream_state( stream_name=declarative_stream.name, namespace=declarative_stream.namespace @@ -318,16 +320,15 @@ def _group_streams( partition_router = declarative_stream.retriever.stream_slicer._partition_router cursor = self._constructor.create_concurrent_cursor_from_perpartition_cursor( - state_manager=state_manager, - model_type=DatetimeBasedCursorModel, - component_definition=incremental_sync_component_definition, - stream_name=declarative_stream.name, - stream_namespace=declarative_stream.namespace, - config=config or {}, - stream_state=stream_state, - partition_router=partition_router, - ) - + state_manager=state_manager, + model_type=DatetimeBasedCursorModel, + component_definition=incremental_sync_component_definition, + stream_name=declarative_stream.name, + stream_namespace=declarative_stream.namespace, + config=config or {}, + stream_state=stream_state, + partition_router=partition_router, + ) partition_generator = StreamSlicerPartitionGenerator( DeclarativePartitionFactory( diff --git a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py index dc89e5dd..7091931a 100644 --- a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py @@ -104,7 +104,9 @@ def state(self) -> MutableMapping[str, Any]: return state def close_partition(self, partition: Partition) -> None: - self._cursor_per_partition[self._to_partition_key(partition._stream_slice.partition)].close_partition_without_emit(partition=partition) + self._cursor_per_partition[ + self._to_partition_key(partition._stream_slice.partition) + ].close_partition_without_emit(partition=partition) def ensure_at_least_one_state_emitted(self) -> None: """ @@ -124,7 +126,6 @@ def _emit_state_message(self) -> None: ) self._message_repository.emit_message(state_message) - def stream_slices(self) -> Iterable[StreamSlice]: slices = self._partition_router.stream_slices() for partition in slices: @@ -217,7 +218,9 @@ def _set_initial_state(self, stream_state: StreamState) -> None: self._partition_router.set_initial_state(stream_state) def observe(self, record: Record) -> None: - self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)].observe(record) + self._cursor_per_partition[ + self._to_partition_key(record.associated_slice.partition) + ].observe(record) def _to_partition_key(self, partition: Mapping[str, Any]) -> str: return self._partition_serializer.to_partition_key(partition) diff --git a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py index d7322709..1529e90e 100644 --- a/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py @@ -303,7 +303,10 @@ def get_request_body_json( raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: - if self._to_partition_key(record.associated_slice.partition) not in self._cursor_per_partition: + if ( + self._to_partition_key(record.associated_slice.partition) + not in self._cursor_per_partition + ): partition_state = ( self._state_to_migrate_from if self._state_to_migrate_from @@ -311,7 +314,9 @@ def should_be_synced(self, record: Record) -> bool: ) cursor = self._create_cursor(partition_state) - self._cursor_per_partition[self._to_partition_key(record.associated_slice.partition)] = cursor + self._cursor_per_partition[ + self._to_partition_key(record.associated_slice.partition) + ] = cursor return self._get_cursor(record).should_be_synced( self._convert_record_to_cursor_record(record) ) diff --git a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py index e2663c92..c8b25b07 100644 --- a/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py +++ b/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py @@ -930,7 +930,7 @@ def create_concurrent_cursor_from_perpartition_cursor( config: Config, stream_state: MutableMapping[str, Any], partition_router, - **kwargs: Any, + **kwargs: Any, ) -> ConcurrentPerPartitionCursor: component_type = component_definition.get("type") if component_definition.get("type") != model_type.__name__: @@ -966,15 +966,15 @@ def create_concurrent_cursor_from_perpartition_cursor( # Return the concurrent cursor and state converter return ConcurrentPerPartitionCursor( - cursor_factory=cursor_factory, - partition_router=partition_router, - stream_name=stream_name, - stream_namespace=stream_namespace, - stream_state=stream_state, - message_repository=self._message_repository, # type: ignore - connector_state_manager=state_manager, - cursor_field=cursor_field, - ) + cursor_factory=cursor_factory, + partition_router=partition_router, + stream_name=stream_name, + stream_namespace=stream_namespace, + stream_state=stream_state, + message_repository=self._message_repository, # type: ignore + connector_state_manager=state_manager, + cursor_field=cursor_field, + ) @staticmethod def create_constant_backoff_strategy( @@ -1258,15 +1258,15 @@ def create_declarative_stream( raise ValueError( "Unsupported Slicer is used. PerPartitionWithGlobalCursor should be used here instead" ) - cursor = combined_slicers if isinstance( - combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor) - ) else self._create_component_from_model( - model=model.incremental_sync, config=config + cursor = ( + combined_slicers + if isinstance( + combined_slicers, (PerPartitionWithGlobalCursor, GlobalSubstreamCursor) + ) + else self._create_component_from_model(model=model.incremental_sync, config=config) ) - client_side_incremental_sync = { - "cursor": cursor - } + client_side_incremental_sync = {"cursor": cursor} if model.incremental_sync and isinstance(model.incremental_sync, DatetimeBasedCursorModel): cursor_model = model.incremental_sync diff --git a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py index 09ed2bc8..31f6377f 100644 --- a/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py +++ b/airbyte_cdk/sources/declarative/stream_slicers/declarative_partition_generator.py @@ -38,6 +38,7 @@ def create(self, stream_slice: StreamSlice) -> Partition: stream_slice, ) + class DeclarativePartition(Partition): def __init__( self, diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py index bbbcfdc2..a093fb5c 100644 --- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py +++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py @@ -263,9 +263,7 @@ def _run_read( source = ConcurrentDeclarativeSource( source_config=manifest, config=config, catalog=catalog, state=state ) - messages = list( - source.read(logger=source.logger, config=config, catalog=catalog, state=[]) - ) + messages = list(source.read(logger=source.logger, config=config, catalog=catalog, state=[])) return messages @@ -514,7 +512,9 @@ def test_incremental_parent_state_no_incremental_dependency( output = _run_read(manifest, config, _stream_name, initial_state) output_data = [message.record.data for message in output if message.record] - assert set(tuple(sorted(d.items())) for d in output_data) == set(tuple(sorted(d.items())) for d in expected_records) + assert set(tuple(sorted(d.items())) for d in output_data) == set( + tuple(sorted(d.items())) for d in expected_records + ) final_state = [ orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output