Skip to content

Commit 9943028

Browse files
Create S3 client for smart_open from session (#4886)
* Create S3 client for smart_open from session
* Fix the tests by determining which method to use via the endpoint_url
* Replace rows on conflict
1 parent f54612d commit 9943028

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

catalog/dags/data_augmentation/rekognition/add_rekognition_labels.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,12 @@ def _insert_tags(tags_buffer: types.TagsBuffer, postgres_conn_id: str):
5555
postgres_conn_id=postgres_conn_id,
5656
default_statement_timeout=constants.INSERT_TIMEOUT,
5757
)
58-
postgres.insert_rows(constants.TEMP_TABLE_NAME, tags_buffer, executemany=True)
58+
postgres.insert_rows(
59+
constants.TEMP_TABLE_NAME,
60+
tags_buffer,
61+
executemany=True,
62+
replace=True,
63+
)
5964

6065

6166
@task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
@@ -76,7 +81,15 @@ def parse_and_insert_labels(
7681
deserialize_json=True,
7782
)
7883

79-
s3_client = S3Hook(aws_conn_id=AWS_CONN_ID).get_client_type("s3")
84+
# If an endpoint is defined for the hook, use the `get_client_type` method
85+
# to retrieve the S3 client. Otherwise, create the client from the session
86+
# so that Airflow doesn't override the endpoint default we want on the S3 client
87+
hook = S3Hook(aws_conn_id=AWS_CONN_ID)
88+
if hook.conn_config.endpoint_url:
89+
get_client = hook.get_client_type
90+
else:
91+
get_client = hook.get_session().client
92+
s3_client = get_client("s3")
8093
with smart_open.open(
8194
f"{s3_bucket}/{s3_prefix}",
8295
transport_params={"buffer_size": file_buffer_size, "client": s3_client},

0 commit comments

Comments (0)