From cc8ec474fe0536e5d46eba6e2a7dc3c0a0d92693 Mon Sep 17 00:00:00 2001 From: Artem Inzhyyants <36314070+artem1205@users.noreply.github.com> Date: Thu, 9 Jan 2025 16:46:58 +0100 Subject: [PATCH] fix: fallback to `json` if `orjson` cannot serialize value (#210) Signed-off-by: Artem Inzhyyants Co-authored-by: maxi297 --- airbyte_cdk/entrypoint.py | 14 +++++++++++++- unit_tests/test_entrypoint.py | 23 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/airbyte_cdk/entrypoint.py b/airbyte_cdk/entrypoint.py index b37c814f..a5052a57 100644 --- a/airbyte_cdk/entrypoint.py +++ b/airbyte_cdk/entrypoint.py @@ -5,6 +5,7 @@ import argparse import importlib import ipaddress +import json import logging import os.path import socket @@ -46,6 +47,7 @@ VALID_URL_SCHEMES = ["https"] CLOUD_DEPLOYMENT_MODE = "cloud" +_HAS_LOGGED_FOR_SERIALIZATION_ERROR = False class AirbyteEntrypoint(object): @@ -291,7 +293,17 @@ def set_up_secret_filter(config: TConfig, connection_specification: Mapping[str, @staticmethod def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str: - return orjson.dumps(AirbyteMessageSerializer.dump(airbyte_message)).decode() + global _HAS_LOGGED_FOR_SERIALIZATION_ERROR + serialized_message = AirbyteMessageSerializer.dump(airbyte_message) + try: + return orjson.dumps(serialized_message).decode() + except Exception as exception: + if not _HAS_LOGGED_FOR_SERIALIZATION_ERROR: + logger.warning( + f"There was an error during the serialization of an AirbyteMessage: `{exception}`. This might impact the sync performances." + ) + _HAS_LOGGED_FOR_SERIALIZATION_ERROR = True + return json.dumps(serialized_message) @classmethod def extract_state(cls, args: List[str]) -> Optional[Any]: diff --git a/unit_tests/test_entrypoint.py b/unit_tests/test_entrypoint.py index 2e50c11f..e906e8b3 100644 --- a/unit_tests/test_entrypoint.py +++ b/unit_tests/test_entrypoint.py @@ -768,3 +768,26 @@ def test_handle_record_counts( assert isinstance( actual_message.state.sourceStats.recordCount, float ), "recordCount value should be expressed as a float" + + +def test_given_serialization_error_using_orjson_then_fallback_on_json( + entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock +): + parsed_args = Namespace( + command="read", config="config_path", state="statepath", catalog="catalogpath" + ) + record = AirbyteMessage( + record=AirbyteRecordMessage( + stream="stream", data={"data": 7046723166326052303072}, emitted_at=1 + ), + type=Type.RECORD, + ) + mocker.patch.object(MockSource, "read_state", return_value={}) + mocker.patch.object(MockSource, "read_catalog", return_value={}) + mocker.patch.object(MockSource, "read", return_value=[record, record]) + + messages = list(entrypoint.run(parsed_args)) + + # There will be multiple messages here because the fixture `entrypoint` sets a control message. We only care about records here + record_messages = list(filter(lambda message: "RECORD" in message, messages)) + assert len(record_messages) == 2