From 7f60bb59aab6bb4f935f8c2f5020d4e2bfe960ee Mon Sep 17 00:00:00 2001 From: Staci Cooper Date: Mon, 11 Mar 2024 16:11:08 -0700 Subject: [PATCH] Automatically exclude the running DAG from the ids to check --- catalog/dags/common/sensors/utils.py | 91 ++++++++++++------- .../data_refresh/create_filtered_index_dag.py | 1 - .../data_refresh/data_refresh_task_factory.py | 6 +- .../staging_database_restore_dag.py | 1 - .../create_new_es_index_dag.py | 2 +- ...roportional_by_source_staging_index_dag.py | 2 +- .../point_es_alias/point_es_alias_dag.py | 10 +- .../recreate_full_staging_index_dag.py | 2 +- 8 files changed, 68 insertions(+), 47 deletions(-) diff --git a/catalog/dags/common/sensors/utils.py b/catalog/dags/common/sensors/utils.py index 90cf9266ef5..0d0085295ee 100644 --- a/catalog/dags/common/sensors/utils.py +++ b/catalog/dags/common/sensors/utils.py @@ -10,6 +10,9 @@ from common.constants import REFRESH_POKE_INTERVAL +THREE_DAYS = 60 * 60 * 24 * 3 + + def _get_most_recent_dag_run(dag_id) -> list[datetime] | datetime: """ Retrieve the most recent DAG run's execution date. @@ -36,21 +39,40 @@ def _get_most_recent_dag_run(dag_id) -> list[datetime] | datetime: return [] -def _get_dags_with_tag(tag: str, excluded_dag_ids: list[str], session=None): - """Get a list of DAG ids with the given tag, optionally excluding certain ids.""" - if not excluded_dag_ids: - excluded_dag_ids = [] - +@task +def get_dags_with_concurrency_tag( + tag: str, excluded_dag_ids: list[str], session=None, **context +): + """ + Get a list of DAG ids with the given tag. The id of the running DAG is excluded, + as well as any ids in the `excluded_dag_ids` list. + """ dags = session.query(DagModel).filter(DagModel.tags.any(DagTag.name == tag)).all() - - # Return just the ids, excluding excluded_dag_ids - ids = [dag.dag_id for dag in dags if dag.dag_id not in excluded_dag_ids] - return ids - - -def wait_for_external_dag(external_dag_id: str, task_id: str | None = None): + dag_ids = [dag.dag_id for dag in dags] + + running_dag_id = context["dag"].dag_id + if running_dag_id not in dag_ids: + raise ValueError( + f"The `{running_dag_id}` DAG tried preventing concurrency with the `{tag}`," + " tag, but does not have the tag itself. To ensure that other DAGs with this" + f" tag will also avoid running concurrently with `{running_dag_id}`, it must" + f"have the `{tag}` tag applied." + ) + + # Return just the ids of DAGs to prevent concurrency with. This excludes the running dag id, + # and any supplied `excluded_dag_ids` + return [id for id in dag_ids if id not in [*excluded_dag_ids, running_dag_id]] + + +@task +def wait_for_external_dag( + external_dag_id: str, + task_id: str | None = None, + timeout: int | None = THREE_DAYS, + **context, +): """ - Return a Sensor task which will wait if the given external DAG is + Execute a Sensor task which will wait if the given external DAG is running. To fully ensure that the waiting DAG and the external DAG do not run @@ -64,7 +86,7 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None): if not task_id: task_id = f"wait_for_{external_dag_id}" - return ExternalTaskSensor( + sensor = ExternalTaskSensor( task_id=task_id, poke_interval=REFRESH_POKE_INTERVAL, external_dag_id=external_dag_id, @@ -75,24 +97,28 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None): mode="reschedule", # Any "finished" state is sufficient for us to continue allowed_states=[State.SUCCESS, State.FAILED], + # execution_timeout for the task does not include time that the sensor + # was up for reschedule but not actually running. `timeout` does + timeout=timeout, ) + sensor.execute(context) + @task_group(group_id="wait_for_external_dags") @provide_session def wait_for_external_dags_with_tag( - tag: str, excluded_dag_ids: list[str], session=None, **context + tag: str, excluded_dag_ids: list[str] = None, session=None ): """ Wait until all DAGs with the given `tag`, excluding those identified by the `excluded_dag_ids`, are no longer in the running state before continuing. """ - external_dag_ids = _get_dags_with_tag( - tag=tag, excluded_dag_ids=excluded_dag_ids, session=session - ) + external_dag_ids = get_dags_with_concurrency_tag.override( + task_id=f"get_dags_in_{tag}_group" + )(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session) - for dag_id in external_dag_ids: - wait_for_external_dag(dag_id) + wait_for_external_dag.expand(external_dag_id=external_dag_ids) @task(retries=0) @@ -101,13 +127,13 @@ def prevent_concurrency_with_dag(external_dag_id: str, **context): Prevent concurrency with the given external DAG, by failing immediately if that DAG is running. """ - wait_for_dag = wait_for_external_dag( - external_dag_id=external_dag_id, - task_id=f"check_for_running_{external_dag_id}", - ) - wait_for_dag.timeout = 0 try: - wait_for_dag.execute(context) + wait_for_external_dag.function( + external_dag_id=external_dag_id, + task_id=f"check_for_running_{external_dag_id}", + timeout=0, + **context, + ) except AirflowSensorTimeout: raise ValueError(f"Concurrency check with {external_dag_id} failed.") @@ -115,7 +141,7 @@ def prevent_concurrency_with_dag(external_dag_id: str, **context): @task_group(group_id="prevent_concurrency_with_dags") @provide_session def prevent_concurrency_with_dags_with_tag( - tag: str, excluded_dag_ids: list[str], session=None, **context + tag: str, excluded_dag_ids: list[str] = None, session=None ): """ Prevent concurrency with any DAGs that have the given `tag`, excluding @@ -123,14 +149,11 @@ def prevent_concurrency_with_dags_with_tag( failing the task immediately if any of the tagged DAGs are in the running state. """ - external_dag_ids = _get_dags_with_tag( - tag=tag, excluded_dag_ids=excluded_dag_ids, session=session - ) + external_dag_ids = get_dags_with_concurrency_tag.override( + task_id=f"get_dags_in_{tag}_group" + )(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session) - for external_dag_id in external_dag_ids: - prevent_concurrency_with_dag.override( - task_id=f"prevent_concurrency_with_{external_dag_id}" - )(external_dag_id) + prevent_concurrency_with_dag.expand(external_dag_id=external_dag_ids) @task(retries=0) diff --git a/catalog/dags/data_refresh/create_filtered_index_dag.py b/catalog/dags/data_refresh/create_filtered_index_dag.py index 9dccaa58774..0cfea3fe7db 100644 --- a/catalog/dags/data_refresh/create_filtered_index_dag.py +++ b/catalog/dags/data_refresh/create_filtered_index_dag.py @@ -121,7 +121,6 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh): # single production index simultaneously. prevent_concurrency = prevent_concurrency_with_dags_with_tag( tag=PRODUCTION_ES_CONCURRENCY_TAG, - excluded_dag_ids=[data_refresh.filtered_index_dag_id], ) # Once the concurrency check has passed, actually create the filtered diff --git a/catalog/dags/data_refresh/data_refresh_task_factory.py b/catalog/dags/data_refresh/data_refresh_task_factory.py index 2a6165199e9..c1e09066be8 100644 --- a/catalog/dags/data_refresh/data_refresh_task_factory.py +++ b/catalog/dags/data_refresh/data_refresh_task_factory.py @@ -125,9 +125,9 @@ def create_data_refresh_task_group( group_id="wait_for_es_dags" )( tag=PRODUCTION_ES_CONCURRENCY_TAG, - # Exclude the current DAG id, as well as all other data refresh DAG ids (these - # are waited on in the previous task) - excluded_dag_ids=[*external_dag_ids, data_refresh.dag_id], + # Exclude the other data refresh DAG ids, as waiting on these was handled in + # the previous task. + excluded_dag_ids=external_dag_ids, ) tasks.append([wait_for_data_refresh, wait_for_es_dags]) diff --git a/catalog/dags/database/staging_database_restore/staging_database_restore_dag.py b/catalog/dags/database/staging_database_restore/staging_database_restore_dag.py index e60a049399d..e03c16d2d85 100644 --- a/catalog/dags/database/staging_database_restore/staging_database_restore_dag.py +++ b/catalog/dags/database/staging_database_restore/staging_database_restore_dag.py @@ -80,7 +80,6 @@ def restore_staging_database(): # Wait for any DAGs that operate on the staging elasticsearch cluster wait_for_recreate_full_staging_index = wait_for_external_dags_with_tag( tag=STAGING_ES_CONCURRENCY_TAG, - excluded_dag_ids=[constants.DAG_ID], ) should_skip = skip_restore() latest_snapshot = get_latest_prod_snapshot() diff --git a/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py b/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py index a35c4f02b7d..ab73f0c5201 100644 --- a/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py +++ b/catalog/dags/elasticsearch_cluster/create_new_es_index/create_new_es_index_dag.py @@ -228,7 +228,7 @@ def create_new_es_index_dag(config: CreateNewIndex): # Fail early if any other DAG that operates on the relevant elasticsearch cluster # is running prevent_concurrency = prevent_concurrency_with_dags_with_tag( - tag=config.concurrency_tag, excluded_dag_ids=[config.dag_id] + tag=config.concurrency_tag, ) es_host = es.get_es_host(environment=config.environment) diff --git a/catalog/dags/elasticsearch_cluster/create_proportional_by_source_staging_index/create_proportional_by_source_staging_index_dag.py b/catalog/dags/elasticsearch_cluster/create_proportional_by_source_staging_index/create_proportional_by_source_staging_index_dag.py index 3d7fe127f9d..aeecd9381ab 100644 --- a/catalog/dags/elasticsearch_cluster/create_proportional_by_source_staging_index/create_proportional_by_source_staging_index_dag.py +++ b/catalog/dags/elasticsearch_cluster/create_proportional_by_source_staging_index/create_proportional_by_source_staging_index_dag.py @@ -106,7 +106,7 @@ def create_proportional_by_source_staging_index(): # Fail early if any conflicting DAGs are running prevent_concurrency = prevent_concurrency_with_dags_with_tag( - tag=STAGING_ES_CONCURRENCY_TAG, excluded_dag_ids=[DAG_ID] + tag=STAGING_ES_CONCURRENCY_TAG, ) es_host = es.get_es_host(environment=STAGING) diff --git a/catalog/dags/elasticsearch_cluster/point_es_alias/point_es_alias_dag.py b/catalog/dags/elasticsearch_cluster/point_es_alias/point_es_alias_dag.py index c7a452db61b..cf3608c4f1e 100644 --- a/catalog/dags/elasticsearch_cluster/point_es_alias/point_es_alias_dag.py +++ b/catalog/dags/elasticsearch_cluster/point_es_alias/point_es_alias_dag.py @@ -37,9 +37,9 @@ from common.sensors.utils import prevent_concurrency_with_dags_with_tag -def point_es_alias_dag(environment: str, dag_id: str): - dag = DAG( - dag_id=dag_id, +def point_es_alias_dag(environment: str): + dag = DAG( + dag_id=f"point_{environment}_alias", default_args=DAG_DEFAULT_ARGS, schedule=None, start_date=datetime(2024, 1, 31), @@ -79,7 +79,7 @@ def point_es_alias_dag(environment: str, dag_id: str): # Fail early if any other DAG that operates on the elasticsearch cluster for # this environment is running prevent_concurrency = prevent_concurrency_with_dags_with_tag( - tag=ES_CONCURRENCY_TAGS[environment], excluded_dag_ids=[dag_id] + tag=ES_CONCURRENCY_TAGS[environment], ) es_host = es.get_es_host(environment=environment) @@ -104,4 +104,4 @@ def point_es_alias_dag(environment: str, dag_id: str): for environment in ENVIRONMENTS: - point_es_alias_dag(environment, f"point_{environment}_alias") \ No newline at end of file + point_es_alias_dag(environment) diff --git a/catalog/dags/elasticsearch_cluster/recreate_staging_index/recreate_full_staging_index_dag.py b/catalog/dags/elasticsearch_cluster/recreate_staging_index/recreate_full_staging_index_dag.py index 8283da49c38..ddcd94b5b10 100644 --- a/catalog/dags/elasticsearch_cluster/recreate_staging_index/recreate_full_staging_index_dag.py +++ b/catalog/dags/elasticsearch_cluster/recreate_staging_index/recreate_full_staging_index_dag.py @@ -103,7 +103,7 @@ def recreate_full_staging_index(): # Fail early if any other DAG that operates on the staging elasticsearch cluster # is running prevent_concurrency = prevent_concurrency_with_dags_with_tag( - tag=STAGING_ES_CONCURRENCY_TAG, excluded_dag_ids=[DAG_ID] + tag=STAGING_ES_CONCURRENCY_TAG, ) target_alias = get_target_alias(