Skip to content

Commit

Permalink
Extend update_license_url tasks timeout to a day and a half (#4209)
Browse files Browse the repository at this point in the history
  • Loading branch information
krysal authored Apr 26, 2024
1 parent 25b9137 commit b55d6d5
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions catalog/dags/maintenance/add_license_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ def run_sql(


@task
def get_license_groups(
query: str, dag_task: AbstractOperator = None
) -> list[tuple[str, str]]:
def get_license_groups(query: str, ti=None) -> list[tuple[str, str]]:
"""
Get license groups of rows that don't have a `license_url` in their
`meta_data` field.
:return: List of (license, version) tuples.
"""
license_groups = run_sql(query, dag_task=dag_task)
license_groups = run_sql(query, dag_task=ti.task)

total_nulls = sum(group[2] for group in license_groups)
licenses_detailed = "\n".join(
Expand All @@ -84,18 +82,14 @@ def get_license_groups(
return [(group[0], group[1]) for group in license_groups]


@task(max_active_tis_per_dag=1)
def update_license_url(
license_group: tuple[str, str],
batch_size: int,
dag_task: AbstractOperator = None,
) -> int:
@task(max_active_tis_per_dag=1, execution_timeout=timedelta(hours=36))
def update_license_url(license_group: tuple[str, str], batch_size: int, ti=None) -> int:
"""
Add license_url to meta_data batching all records with the same license.
:param license_group: tuple of license and version
:param batch_size: number of records to update in one update statement
:param dag_task: automatically passed by Airflow, used to set the execution timeout.
:param ti: automatically passed by Airflow, used to set the execution timeout.
"""
license_, version = license_group
license_info = get_license_info_from_license_pair(license_, version)
Expand Down Expand Up @@ -135,7 +129,7 @@ def update_license_url(
method="run",
handler=RETURN_ROW_COUNT,
autocommit=True,
dag_task=dag_task,
dag_task=ti.task,
)
total_updated += updated_count
logger.info(f"Updated {total_updated} rows with {license_url}.")
Expand All @@ -144,18 +138,18 @@ def update_license_url(


@task(trigger_rule=TriggerRule.ALL_DONE)
def report_completion(updated, query: str, dag_task: AbstractOperator = None):
def report_completion(updated, query: str, ti=None):
"""
Check for null in `meta_data` and send a message to Slack with the statistics
of the DAG run.
:param updated: total number of records updated
:param query: SQL query to get the count of records left with `license_url` as NULL
:param dag_task: automatically passed by Airflow, used to set the execution timeout.
:param ti: automatically passed by Airflow, used to set the execution timeout.
"""
total_updated = sum(updated) if updated else 0

license_groups = run_sql(query, dag_task=dag_task)
license_groups = run_sql(query, dag_task=ti.task)
total_nulls = sum(group[2] for group in license_groups)
licenses_detailed = "\n".join(
f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups
Expand Down

0 comments on commit b55d6d5

Please sign in to comment.