
Commit

solve problem with connector builder server test reads, rework record yielding to be simpler, fix tests, formatting, mypy errors
brianjlai committed Dec 24, 2024
1 parent ef51643 commit 96e6cb1
Showing 7 changed files with 70 additions and 70 deletions.
31 changes: 23 additions & 8 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
@@ -194,10 +194,11 @@ def _group_streams(
# Some low-code sources use a combination of DeclarativeStream and regular Python streams. We can't inspect
# these legacy Python streams the way we do low-code streams to determine if they are concurrent compatible,
# so we need to treat them as synchronous
if (
isinstance(declarative_stream, DeclarativeStream)
and name_to_stream_mapping[declarative_stream.name]["retriever"]["type"]
if isinstance(declarative_stream, DeclarativeStream) and (
name_to_stream_mapping[declarative_stream.name]["retriever"]["type"]
== "SimpleRetriever"
or name_to_stream_mapping[declarative_stream.name]["retriever"]["type"]
== "AsyncRetriever"
):
incremental_sync_component_definition = name_to_stream_mapping[
declarative_stream.name
@@ -217,6 +218,11 @@ def _group_streams(
and not incremental_sync_component_definition
)

is_async_job_stream = (
name_to_stream_mapping[declarative_stream.name].get("retriever", {}).get("type")
== "AsyncRetriever"
)

if self._is_datetime_incremental_without_partition_routing(
declarative_stream, incremental_sync_component_definition
):
@@ -268,15 +274,24 @@ def _group_streams(
elif (
is_substream_without_incremental or is_without_partition_router_or_cursor
) and hasattr(declarative_stream.retriever, "stream_slicer"):
if is_async_job_stream:
async_retriever = declarative_stream.retriever

def async_retriever_factory_method() -> Retriever:
return async_retriever

retriever_factory = async_retriever_factory_method
else:
retriever_factory = self._retriever_factory(
name_to_stream_mapping[declarative_stream.name],
config,
{},
)
partition_generator = StreamSlicerPartitionGenerator(
DeclarativePartitionFactory(
declarative_stream.name,
declarative_stream.get_json_schema(),
self._retriever_factory(
name_to_stream_mapping[declarative_stream.name],
config,
{},
),
retriever_factory,
self.message_repository,
),
declarative_stream.retriever.stream_slicer,
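For context, a minimal Python sketch of the factory pattern the hunk above introduces: when a stream uses an AsyncRetriever, the partition generator reuses the single retriever instance already attached to the stream (captured in a closure), while SimpleRetriever streams continue to get a fresh retriever per partition. The helper name and the rationale comments are the editor's reading of the change, not CDK code.

from typing import Any, Callable

class Retriever:
    """Illustrative stand-in for the CDK's Retriever type."""

def choose_retriever_factory(
    declarative_stream: Any,
    is_async_job_stream: bool,
    fresh_retriever_factory: Callable[[], Retriever],
) -> Callable[[], Retriever]:
    if is_async_job_stream:
        # Reuse the retriever already attached to the stream; rebuilding it per
        # partition would discard the async job coordination it holds.
        async_retriever = declarative_stream.retriever

        def async_retriever_factory_method() -> Retriever:
            return async_retriever

        return async_retriever_factory_method
    # SimpleRetriever-based streams keep getting a fresh retriever per partition.
    return fresh_retriever_factory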
@@ -130,6 +130,7 @@ def next_page_token(
last_record: Optional[Record],
last_page_token_value: Optional[Any] = None,
) -> Optional[Mapping[str, Any]]:
print("At the DefaultPaginator")
next_page_token = self.pagination_strategy.next_page_token(
response=response,
last_page_size=last_page_size,
@@ -141,7 +142,7 @@ def next_page_token(
else:
return None

def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]:
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
token = next_page_token.get("next_page_token") if next_page_token else None
if token and self.page_token_option and isinstance(self.page_token_option, RequestPath):
# Replace url base to only return the path
@@ -213,6 +214,9 @@ class PaginatorTestReadDecorator(Paginator):
"""
In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of
pages that are queried throughout a read command.
WARNING: Unlike the rest of the low-code framework, this decorator is not currently thread-safe because it keeps
internal state to track the number of pages already read so that it can exit early during a test read.
"""

_PAGE_COUNT_BEFORE_FIRST_NEXT_CALL = 1
@@ -227,6 +231,7 @@ def __init__(self, decorated: Paginator, maximum_number_of_pages: int = 5) -> No
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL

def get_initial_token(self) -> Optional[Any]:
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
return self._decorated.get_initial_token()

def next_page_token(
Expand All @@ -236,6 +241,8 @@ def next_page_token(
last_record: Optional[Record],
last_page_token_value: Optional[Any] = None,
) -> Optional[Mapping[str, Any]]:
print("At the PaginatorTestReadDecorator")
print(f"page count = {self._page_count} and max pages = {self._maximum_number_of_pages}")
if self._page_count >= self._maximum_number_of_pages:
return None

@@ -244,7 +251,7 @@ def next_page_token(
response, last_page_size, last_record, last_page_token_value
)

def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]:
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
return self._decorated.path(next_page_token)

def get_request_params(
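A rough usage sketch of why resetting the page count in get_initial_token matters for connector builder test reads. The import path is assumed from the CDK's package layout, fetch_page is a hypothetical stand-in for the HTTP call, and the framing of the bug (later slices being cut off once the first slice hit the limit) is the editor's inference from the change rather than something stated in the commit.

from typing import Any, Callable, Mapping, Optional, Tuple

# Import path assumed from the CDK's package layout.
from airbyte_cdk.sources.declarative.requesters.paginators import PaginatorTestReadDecorator


def read_all_slices(
    inner_paginator: Any,
    slices: list,
    fetch_page: Callable[[Optional[Mapping[str, Any]]], Tuple[Any, int, Any]],
) -> None:
    # One decorator instance is shared across every slice of a test read.
    paginator = PaginatorTestReadDecorator(inner_paginator, maximum_number_of_pages=2)
    for _slice in slices:
        # Resetting the internal page counter here is the key change: without it,
        # later slices could be cut off after their first page.
        initial_token = paginator.get_initial_token()
        next_page_token: Optional[Mapping[str, Any]] = (
            {"next_page_token": initial_token} if initial_token else None
        )
        while True:
            response, last_page_size, last_record = fetch_page(next_page_token)
            next_page_token = paginator.next_page_token(
                response, last_page_size, last_record
            )
            if not next_page_token:
                break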
@@ -19,7 +19,7 @@ class NoPagination(Paginator):

parameters: InitVar[Mapping[str, Any]]

def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]:
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
return None

def get_request_params(
@@ -49,7 +49,7 @@ def next_page_token(
pass

@abstractmethod
def path(self, next_page_token: Mapping[str, Any]) -> Optional[str]:
def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
"""
Returns the URL path to hit to fetch the next page of records
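The Optional widening across these paginator classes reflects that path() can be called before any token exists. A minimal sketch of a conforming implementation; the class name and URL format are invented for illustration.

from typing import Any, Mapping, Optional


class CursorPathPaginator:
    """Illustrative paginator; only the path() signature mirrors the CDK interface."""

    def path(self, next_page_token: Optional[Mapping[str, Any]]) -> Optional[str]:
        # next_page_token may be None before the first page, so guard before
        # reading the token out of the mapping.
        token = next_page_token.get("next_page_token") if next_page_token else None
        return f"/items/{token}" if token else None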
74 changes: 25 additions & 49 deletions airbyte_cdk/sources/declarative/retrievers/simple_retriever.py
@@ -9,6 +9,7 @@
from typing import (
Any,
Callable,
Generator,
Iterable,
List,
Mapping,
@@ -266,30 +267,18 @@ def _parse_response(
records_schema: Mapping[str, Any],
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Iterable[Union[Record, LastResponseValue]]:
) -> Iterable[Record]:
if not response:
yield from []
return LastResponseValue(last_response=None, last_page_size=0, last_record=None)
else:
self._last_response = response
record_generator = self.record_selector.select_records(
yield from self.record_selector.select_records(
response=response,
stream_state=stream_state,
records_schema=records_schema,
stream_slice=stream_slice,
next_page_token=next_page_token,
)

last_page_size = 0
last_record = None
for record in record_generator:
last_page_size += 1
last_record = record
yield record
return LastResponseValue(
last_response=response, last_page_size=last_page_size, last_record=last_record
)

@property # type: ignore
def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
"""The stream's primary key"""
@@ -357,27 +346,24 @@ def _fetch_next_page(
# This logic is similar to _read_pages in the HttpStream class. When making changes here, consider making changes there as well.
def _read_pages(
self,
records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]],
records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]],
stream_state: Mapping[str, Any],
stream_slice: StreamSlice,
) -> Iterable[StreamData]:
) -> Iterable[Record]:
pagination_complete = False
initial_token = self._paginator.get_initial_token()
next_page_token = {"next_page_token": initial_token} if initial_token else None
next_page_token: Optional[Mapping[str, Any]] = (
{"next_page_token": initial_token} if initial_token else None
)
while not pagination_complete:
response = self._fetch_next_page(stream_state, stream_slice, next_page_token)

last_page_size = 0
last_record = None

# todo: There has to be a better way of yielding records and still emitting a final return value
try:
yield from records_generator_fn(response)
except StopIteration as e:
last_response_value = e.value
if isinstance(last_response_value, LastResponseValue):
last_page_size = last_response_value.last_page_size
last_record = last_response_value.last_record
last_record: Optional[Record] = None
for record in records_generator_fn(response):
last_page_size += 1
last_record = record
yield record

if not response:
pagination_complete = True
Expand All @@ -399,33 +385,28 @@ def _read_pages(

def _read_single_page(
self,
records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]],
records_generator_fn: Callable[[Optional[requests.Response]], Iterable[Record]],
stream_state: Mapping[str, Any],
stream_slice: StreamSlice,
) -> Iterable[StreamData]:
initial_token = stream_state.get("next_page_token")
if initial_token is None:
initial_token = self._paginator.get_initial_token()
next_page_token = {"next_page_token": initial_token} if initial_token else None
next_page_token: Optional[Mapping[str, Any]] = (
{"next_page_token": initial_token} if initial_token else None
)

response = self._fetch_next_page(stream_state, stream_slice, next_page_token)

last_page_size = 0
last_record = None

# todo: There has to be a better way of yielding records and still emitting a final return value
try:
record_generator = records_generator_fn(response)
while True:
yield next(record_generator)
except StopIteration as e:
last_response_value = e.value
if isinstance(last_response_value, LastResponseValue):
last_page_size = last_response_value.last_page_size
last_record = last_response_value.last_record
last_record: Optional[Record] = None
for record in records_generator_fn(response):
last_page_size += 1
last_record = record
yield record

if not response:
next_page_token: Mapping[str, Any] = {FULL_REFRESH_SYNC_COMPLETE_KEY: True}
next_page_token = {FULL_REFRESH_SYNC_COMPLETE_KEY: True}
else:
last_page_token_value = (
next_page_token.get("next_page_token") if next_page_token else None
@@ -563,18 +544,13 @@ def _parse_records(
stream_state: Mapping[str, Any],
records_schema: Mapping[str, Any],
stream_slice: Optional[StreamSlice],
) -> Iterable[Union[StreamData, LastResponseValue]]:
record_generator = self._parse_response(
) -> Iterable[Record]:
yield from self._parse_response(
response,
stream_slice=stream_slice,
stream_state=stream_state,
records_schema=records_schema,
)
try:
while True:
yield next(record_generator)
except StopIteration as e:
return e.value

def must_deduplicate_query_params(self) -> bool:
return True
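The record-yielding rework in simple_retriever.py reduces to the pattern below: the record generator no longer returns page totals through StopIteration.value; the reading loop counts records as it yields them and hands the totals to the paginator. This is a simplified sketch with stand-in types, not the CDK code itself.

from typing import Any, Iterable, Iterator, Optional

Record = dict  # stand-in for airbyte_cdk's Record type


def select_records(response: Any) -> Iterator[Record]:
    """Stand-in for record_selector.select_records: just yields parsed records."""
    yield from getattr(response, "records", [])


def read_pages(responses: Iterable[Any]) -> Iterable[Record]:
    for response in responses:
        last_page_size = 0
        last_record: Optional[Record] = None
        for record in select_records(response):
            last_page_size += 1
            last_record = record
            yield record
        # last_page_size and last_record now live in the reading loop and feed
        # paginator.next_page_token(...) to decide whether to request another page.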
@@ -651,13 +651,15 @@ def test_group_streams():
concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)

# 1 full refresh stream, 2 incremental streams, 1 substream w/o incremental, 1 list based substream w/o incremental
assert len(concurrent_streams) == 5
# 1 async job stream
assert len(concurrent_streams) == 6
(
concurrent_stream_0,
concurrent_stream_1,
concurrent_stream_2,
concurrent_stream_3,
concurrent_stream_4,
concurrent_stream_5,
) = concurrent_streams
assert isinstance(concurrent_stream_0, DefaultStream)
assert concurrent_stream_0.name == "party_members"
@@ -669,13 +671,13 @@ def test_group_streams():
assert concurrent_stream_3.name == "party_members_skills"
assert isinstance(concurrent_stream_4, DefaultStream)
assert concurrent_stream_4.name == "arcana_personas"
assert isinstance(concurrent_stream_5, DefaultStream)
assert concurrent_stream_5.name == "async_job_stream"

# 1 substream w/ incremental
assert len(synchronous_streams) == 2
assert len(synchronous_streams) == 1
assert isinstance(synchronous_streams[0], DeclarativeStream)
assert synchronous_streams[0].name == "palace_enemies"
assert isinstance(synchronous_streams[1], DeclarativeStream)
assert synchronous_streams[1].name == "async_job_stream"


@freezegun.freeze_time(time_to_freeze=datetime(2024, 9, 1, 0, 0, 0, 0, tzinfo=timezone.utc))
@@ -1456,10 +1458,10 @@ def test_streams_with_stream_state_interpolation_should_be_synchronous():
)
concurrent_streams, synchronous_streams = source._group_streams(config=_CONFIG)

# 1 full refresh stream, 2 with parent stream without incremental dependency
assert len(concurrent_streams) == 3
# 2 incremental stream with interpolation on state (locations and party_members), 1 incremental with parent stream (palace_enemies), 1 stream with async retriever
assert len(synchronous_streams) == 4
# 1 full refresh stream, 2 with parent stream without incremental dependency, 1 stream with async retriever
assert len(concurrent_streams) == 4
# 2 incremental stream with interpolation on state (locations and party_members), 1 incremental with parent stream (palace_enemies)
assert len(synchronous_streams) == 3


def test_given_partition_routing_and_incremental_sync_then_stream_is_not_concurrent():
@@ -1449,7 +1449,7 @@ def _create_page(response_body):
{"ABC": 0, "id": 1},
{"AED": 1, "id": 2},
],
[call({}, {})],
[call({}, {}, None)],
),
(
"test_read_with_pagination_no_partitions",
