Add elasticsearch concurrency tags for Airflow (#3921)
* Update sensor utils to get list of dag_ids matching a specified tag

* Appropriately tag ES dags with their concurrency group, and use new sensor utils

* Update dag docs

* Automatically exclude the running DAG from the ids to check

* Fix bad merge

* Fix bad dag_id change, don't make db_restore wait unnecessarily

* Only pass required context, use set

* Rename variable for clarity

* Split concurrency groups out into ES and DB groups

* Clarify concurrency_tag variable name

* Fix typo in dag docs
stacimc authored Mar 26, 2024
1 parent a872fec commit ed267d8
Showing 13 changed files with 279 additions and 184 deletions.
19 changes: 19 additions & 0 deletions catalog/dags/common/sensors/constants.py
@@ -0,0 +1,19 @@
from common.constants import PRODUCTION, STAGING


# These DagTags are used to identify DAGs which should not be run concurrently
# with one another.

# Used to identify DAGs for each environment which affect the Elasticsearch cluster
# and should not be run simultaneously
PRODUCTION_ES_CONCURRENCY_TAG = "production_elasticsearch_concurrency"
STAGING_ES_CONCURRENCY_TAG = "staging_elasticsearch_concurrency"

# Used to identify DAGs which affect the staging API database in such a
# way that they should not be run simultaneously
STAGING_DB_CONCURRENCY_TAG = "staging_api_database_concurrency"

ES_CONCURRENCY_TAGS = {
PRODUCTION: PRODUCTION_ES_CONCURRENCY_TAG,
STAGING: STAGING_ES_CONCURRENCY_TAG,
}
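
For illustration, a minimal sketch of how a DAG opts into one of these concurrency groups by including the tag in its `tags` list; the DAG id and schedule below are assumptions for the example, not part of this change:

from datetime import datetime

from airflow import DAG

from common.constants import DAG_DEFAULT_ARGS
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG


# Hypothetical example: any DAG carrying this tag is treated as part of the
# production Elasticsearch concurrency group by the sensor utilities.
with DAG(
    dag_id="example_production_es_maintenance",  # illustrative id
    schedule=None,
    start_date=datetime(2024, 1, 1),
    default_args=DAG_DEFAULT_ARGS,
    catchup=False,
    tags=["elasticsearch", PRODUCTION_ES_CONCURRENCY_TAG],
):
    ...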
108 changes: 81 additions & 27 deletions catalog/dags/common/sensors/utils.py
@@ -2,14 +2,18 @@

from airflow.decorators import task, task_group
from airflow.exceptions import AirflowSensorTimeout
from airflow.models import DagRun
from airflow.models import DagModel, DagRun, DagTag
from airflow.sensors.external_task import ExternalTaskSensor
from airflow.utils.session import provide_session
from airflow.utils.state import State

from common.constants import REFRESH_POKE_INTERVAL


def get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
THREE_DAYS = 60 * 60 * 24 * 3


def _get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
"""
Retrieve the most recent DAG run's execution date.
@@ -35,9 +39,40 @@ def get_most_recent_dag_run(dag_id) -> list[datetime] | datetime:
return []


def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
@task
def get_dags_with_concurrency_tag(
tag: str, excluded_dag_ids: list[str], session=None, dag=None
):
"""
Get a list of DAG ids with the given tag. The id of the running DAG is excluded,
as well as any ids in the `excluded_dag_ids` list.
"""
Return a Sensor task which will wait if the given external DAG is
dags = session.query(DagModel).filter(DagModel.tags.any(DagTag.name == tag)).all()
dag_ids = [dag.dag_id for dag in dags]

running_dag_id = dag.dag_id
if running_dag_id not in dag_ids:
raise ValueError(
f"The `{running_dag_id}` DAG tried preventing concurrency with the `{tag}`,"
" tag, but does not have the tag itself. To ensure that other DAGs with this"
f" tag will also avoid running concurrently with `{running_dag_id}`, it must"
f"have the `{tag}` tag applied."
)

# Return just the ids of DAGs to prevent concurrency with. This excludes the running dag id,
# and any supplied `excluded_dag_ids`
return [id for id in dag_ids if id not in {*excluded_dag_ids, running_dag_id}]
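
For reference, a rough standalone equivalent of the tag lookup performed by this task, assuming direct access to the Airflow metadata database via `create_session`; it is a sketch of the same `DagModel.tags.any(...)` filter, not code from this change:

from airflow.models import DagModel, DagTag
from airflow.utils.session import create_session


def dag_ids_with_tag(tag: str) -> list[str]:
    # Query the Airflow metadata database for every DAG carrying the tag.
    with create_session() as session:
        dags = (
            session.query(DagModel)
            .filter(DagModel.tags.any(DagTag.name == tag))
            .all()
        )
        return [dag.dag_id for dag in dags]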


@task
def wait_for_external_dag(
external_dag_id: str,
task_id: str | None = None,
timeout: int | None = THREE_DAYS,
**context,
):
"""
Execute a Sensor task which will wait if the given external DAG is
running.
To fully ensure that the waiting DAG and the external DAG do not run
@@ -51,28 +86,39 @@ def wait_for_external_dag(external_dag_id: str, task_id: str | None = None):
if not task_id:
task_id = f"wait_for_{external_dag_id}"

return ExternalTaskSensor(
sensor = ExternalTaskSensor(
task_id=task_id,
poke_interval=REFRESH_POKE_INTERVAL,
external_dag_id=external_dag_id,
# Wait for the whole DAG, not just a part of it
external_task_id=None,
check_existence=False,
execution_date_fn=lambda _: get_most_recent_dag_run(external_dag_id),
execution_date_fn=lambda _: _get_most_recent_dag_run(external_dag_id),
mode="reschedule",
# Any "finished" state is sufficient for us to continue
allowed_states=[State.SUCCESS, State.FAILED],
# execution_timeout for the task does not include time that the sensor
# was up for reschedule but not actually running. `timeout` does
timeout=timeout,
)

sensor.execute(context)
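
Because `wait_for_external_dag` is now a TaskFlow task that builds and executes the sensor inline, it can be called directly inside a DAG definition. A hedged usage sketch, with an assumed external DAG id and a placeholder downstream task:

from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator

from common.sensors.utils import wait_for_external_dag

with DAG(
    dag_id="example_waiting_dag",  # illustrative
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
):
    # Reschedule-mode sensor that waits for the most recent run of the
    # external DAG to reach a finished state before continuing.
    wait = wait_for_external_dag(external_dag_id="some_other_dag")
    do_work = EmptyOperator(task_id="do_work")
    wait >> do_work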


@task_group(group_id="wait_for_external_dags")
def wait_for_external_dags(external_dag_ids: list[str]):
@provide_session
def wait_for_external_dags_with_tag(
tag: str, excluded_dag_ids: list[str] = None, session=None
):
"""
Wait for all DAGs with the given external DAG ids to no longer be
in a running state before continuing.
Wait until all DAGs with the given `tag`, excluding those identified by the
`excluded_dag_ids`, are no longer in the running state before continuing.
"""
for dag_id in external_dag_ids:
wait_for_external_dag(dag_id)
external_dag_ids = get_dags_with_concurrency_tag.override(
task_id=f"get_dags_in_{tag}_group"
)(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)

wait_for_external_dag.expand(external_dag_id=external_dag_ids)
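
A hedged usage sketch of this task group: given a concurrency tag, it first resolves the tagged DAG ids and then dynamically maps one `wait_for_external_dag` sensor per id. The calling DAG must itself carry the tag, or `get_dags_with_concurrency_tag` raises; the DAG id below is an assumption:

from datetime import datetime

from airflow import DAG

from common.sensors.constants import STAGING_ES_CONCURRENCY_TAG
from common.sensors.utils import wait_for_external_dags_with_tag

with DAG(
    dag_id="example_staging_es_dag",  # illustrative
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=[STAGING_ES_CONCURRENCY_TAG],  # required for the concurrency check
):
    # One mapped sensor task is created per tagged DAG, excluding this DAG
    # itself and anything passed via `excluded_dag_ids`.
    wait_for_external_dags_with_tag(tag=STAGING_ES_CONCURRENCY_TAG)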


@task(retries=0)
@@ -81,18 +127,35 @@ def prevent_concurrency_with_dag(external_dag_id: str, **context):
Prevent concurrency with the given external DAG, by failing
immediately if that DAG is running.
"""

wait_for_dag = wait_for_external_dag(
external_dag_id=external_dag_id,
task_id=f"check_for_running_{external_dag_id}",
)
wait_for_dag.timeout = 0
try:
wait_for_dag.execute(context)
wait_for_external_dag.function(
external_dag_id=external_dag_id,
task_id=f"check_for_running_{external_dag_id}",
timeout=0,
**context,
)
except AirflowSensorTimeout:
raise ValueError(f"Concurrency check with {external_dag_id} failed.")


@task_group(group_id="prevent_concurrency_with_dags")
@provide_session
def prevent_concurrency_with_dags_with_tag(
tag: str, excluded_dag_ids: list[str] = None, session=None
):
"""
Prevent concurrency with any DAGs that have the given `tag`, excluding
those identified by the `excluded_dag_ids`. Concurrency is prevented by
failing the task immediately if any of the tagged DAGs are in the running
state.
"""
external_dag_ids = get_dags_with_concurrency_tag.override(
task_id=f"get_dags_in_{tag}_group"
)(tag=tag, excluded_dag_ids=excluded_dag_ids or [], session=session)

prevent_concurrency_with_dag.expand(external_dag_id=external_dag_ids)
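
The two task groups differ in behavior: `wait_for_external_dags_with_tag` reschedules its sensors until the tagged DAGs finish, while `prevent_concurrency_with_dags_with_tag` maps the zero-timeout check above, so the run fails immediately if any tagged DAG is already running. A brief sketch with an assumed DAG id:

from datetime import datetime

from airflow import DAG

from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sensors.utils import prevent_concurrency_with_dags_with_tag

with DAG(
    dag_id="example_guarded_dag",  # illustrative; the guarded DAG must carry the tag
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
    tags=[PRODUCTION_ES_CONCURRENCY_TAG],
):
    # Fails this DAG run immediately if any other DAG in the production
    # Elasticsearch concurrency group is currently running.
    prevent_concurrency_with_dags_with_tag(tag=PRODUCTION_ES_CONCURRENCY_TAG)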


@task(retries=0)
def is_concurrent_with_any(external_dag_ids: list[str], **context):
"""
@@ -109,12 +172,3 @@ def is_concurrent_with_any(external_dag_ids: list[str], **context):

# Explicit return None to clarify expectations
return None


@task_group(group_id="prevent_concurrency")
def prevent_concurrency_with_dags(external_dag_ids: list[str]):
"""Fail immediately if any of the given external dags are in progress."""
for dag_id in external_dag_ids:
prevent_concurrency_with_dag.override(
task_id=f"prevent_concurrency_with_{dag_id}"
)(dag_id)
28 changes: 12 additions & 16 deletions catalog/dags/data_refresh/create_filtered_index_dag.py
@@ -42,8 +42,9 @@
There are two mechanisms that prevent this from happening:
1. The filtered index creation DAGs are not allowed to run if a data refresh
for the media type is already running.
1. The filtered index creation DAGs fail immediately if any of the DAGs that are
tagged as part of the `production-es-concurrency` group (including the data
refreshes) are currently running.
2. The data refresh DAGs will wait for any pre-existing filtered index creation
DAG runs for the media type to finish before continuing.
@@ -56,15 +57,13 @@
from airflow import DAG
from airflow.models.param import Param

from common.constants import DAG_DEFAULT_ARGS, PRODUCTION
from common.sensors.utils import prevent_concurrency_with_dags
from common.constants import DAG_DEFAULT_ARGS
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sensors.utils import prevent_concurrency_with_dags_with_tag
from data_refresh.create_filtered_index import (
create_filtered_index_creation_task_groups,
)
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
from elasticsearch_cluster.create_new_es_index.create_new_es_index_types import (
CREATE_NEW_INDEX_CONFIGS,
)


# Note: We can't use the TaskFlow `@dag` DAG factory decorator
@@ -88,7 +87,7 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
default_args=DAG_DEFAULT_ARGS,
schedule=None,
start_date=datetime(2023, 4, 1),
tags=["data_refresh"],
tags=["data_refresh", PRODUCTION_ES_CONCURRENCY_TAG],
max_active_runs=1,
catchup=False,
doc_md=__doc__,
@@ -117,14 +116,11 @@ def create_filtered_index_creation_dag(data_refresh: DataRefresh):
},
render_template_as_native_obj=True,
) as dag:
# Immediately fail if the associated data refresh is running, or the
# create_new_production_es_index DAG is running. This prevents multiple
# DAGs from reindexing from a single production index simultaneously.
prevent_concurrency = prevent_concurrency_with_dags(
external_dag_ids=[
data_refresh.dag_id,
CREATE_NEW_INDEX_CONFIGS[PRODUCTION].dag_id,
]
# Immediately fail if any DAG that operates on the production elasticsearch
# cluster is running. This prevents multiple DAGs from reindexing from a
# single production index simultaneously.
prevent_concurrency = prevent_concurrency_with_dags_with_tag(
tag=PRODUCTION_ES_CONCURRENCY_TAG,
)

# Once the concurrency check has passed, actually create the filtered
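
The remainder of this file is collapsed in the diff view; as a rough sketch, the concurrency check gates the downstream index-creation work, shown here with a placeholder operator standing in for the collapsed task groups (an assumption, not code from this change):

from airflow.operators.empty import EmptyOperator

# Placeholder standing in for the collapsed filtered-index creation task groups.
create_filtered_index = EmptyOperator(task_id="create_filtered_index")

# The index creation only begins once the concurrency check has passed.
prevent_concurrency >> create_filtered_index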
3 changes: 2 additions & 1 deletion catalog/dags/data_refresh/dag_factory.py
@@ -34,6 +34,7 @@
OPENLEDGER_API_CONN_ID,
XCOM_PULL_TEMPLATE,
)
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sql import PGExecuteQueryOperator, single_value
from data_refresh.data_refresh_task_factory import create_data_refresh_task_group
from data_refresh.data_refresh_types import DATA_REFRESH_CONFIGS, DataRefresh
@@ -70,7 +71,7 @@ def create_data_refresh_dag(data_refresh: DataRefresh, external_dag_ids: Sequenc
max_active_runs=1,
catchup=False,
doc_md=__doc__,
tags=["data_refresh"],
tags=["data_refresh", PRODUCTION_ES_CONCURRENCY_TAG],
)

with dag:
20 changes: 10 additions & 10 deletions catalog/dags/data_refresh/data_refresh_task_factory.py
@@ -55,16 +55,14 @@
from airflow.utils.trigger_rule import TriggerRule

from common import cloudwatch, ingestion_server
from common.constants import PRODUCTION, XCOM_PULL_TEMPLATE
from common.constants import XCOM_PULL_TEMPLATE
from common.sensors.constants import PRODUCTION_ES_CONCURRENCY_TAG
from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor
from common.sensors.utils import wait_for_external_dags
from common.sensors.utils import wait_for_external_dags_with_tag
from data_refresh.create_filtered_index import (
create_filtered_index_creation_task_groups,
)
from data_refresh.data_refresh_types import DataRefresh
from elasticsearch_cluster.create_new_es_index.create_new_es_index_types import (
CREATE_NEW_INDEX_CONFIGS,
)


logger = logging.getLogger(__name__)
@@ -123,11 +121,13 @@ def create_data_refresh_task_group(
# Realistically the data refresh is too slow to beat the index creation process,
# even if it was triggered immediately after one of these DAGs; however, it is
# always safer to avoid the possibility of the race condition altogether.
wait_for_es_dags = wait_for_external_dags.override(group_id="wait_for_es_dags")(
external_dag_ids=[
data_refresh.filtered_index_dag_id,
CREATE_NEW_INDEX_CONFIGS[PRODUCTION].dag_id,
]
wait_for_es_dags = wait_for_external_dags_with_tag.override(
group_id="wait_for_es_dags"
)(
tag=PRODUCTION_ES_CONCURRENCY_TAG,
# Exclude the other data refresh DAG ids, as waiting on these was handled in
# the previous task.
excluded_dag_ids=external_dag_ids,
)
tasks.append([wait_for_data_refresh, wait_for_es_dags])

@@ -1,5 +1,5 @@
"""
# Update the staging database
# Staging Database Restore DAG
This DAG is responsible for updating the staging database using the most recent
snapshot of the production database.
@@ -35,7 +35,8 @@
DAG_DEFAULT_ARGS,
POSTGRES_API_STAGING_CONN_ID,
)
from common.sensors.utils import wait_for_external_dag
from common.sensors.constants import STAGING_DB_CONCURRENCY_TAG
from common.sensors.utils import wait_for_external_dags_with_tag
from common.sql import PGExecuteQueryOperator
from database.staging_database_restore import constants
from database.staging_database_restore.staging_database_restore import (
@@ -48,9 +49,6 @@
restore_staging_from_snapshot,
skip_restore,
)
from elasticsearch_cluster.recreate_staging_index.recreate_full_staging_index import (
DAG_ID as RECREATE_STAGING_INDEX_DAG_ID,
)


log = logging.getLogger(__name__)
@@ -60,7 +58,7 @@
dag_id=constants.DAG_ID,
schedule="@monthly",
start_date=datetime(2023, 5, 1),
tags=["database"],
tags=["database", STAGING_DB_CONCURRENCY_TAG],
max_active_runs=1,
dagrun_timeout=timedelta(days=1),
catchup=False,
@@ -76,9 +74,10 @@
def restore_staging_database():
# If the `recreate_full_staging_index` DAG was manually triggered prior
# to the database restoration starting, we should wait for it to
# finish.
wait_for_recreate_full_staging_index = wait_for_external_dag(
external_dag_id=RECREATE_STAGING_INDEX_DAG_ID,
# finish. It is not necessary to wait on any of the other ES DAGs as
# they do not directly affect the database.
wait_for_recreate_full_staging_index = wait_for_external_dags_with_tag(
tag=STAGING_DB_CONCURRENCY_TAG
)
should_skip = skip_restore()
latest_snapshot = get_latest_prod_snapshot()