Update init scripts to use new data refresh (#4962)
* Fix task dependencies, allow overriding of index suffix

The TaskGroups that created new indices were returning the index name from within the TaskGroup so that other parts of the DAG could access it later. Unfortunately, this broke the expected dependency chain in the DAG: only the task that produced the returned value was set upstream of downstream tasks, rather than the last task in the TaskGroup.

TL;DR: it was possible, for example, for creation of the filtered index to fail and yet for table/index promotion to proceed anyway!

This commit fixes that by pulling the tasks that generate the index names out of the TaskGroups, so they can be used directly at any point in the DAG.
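For illustration, a minimal TaskFlow sketch of the pitfall (task names here are made up for the example and are not the actual data refresh tasks):

from airflow.decorators import dag, task, task_group
from pendulum import datetime


@task
def generate_index_name() -> str:
    return "image-abc123"


@task
def populate_index(index_name: str) -> str:
    # Stand-in for the long-running reindex that can fail.
    return index_name


@task
def promote_index(index_name: str):
    print(f"Promoting {index_name}")


@dag(start_date=datetime(2024, 1, 1), schedule=None, catchup=False)
def taskgroup_return_pitfall():
    @task_group
    def create_new_index():
        index_name = generate_index_name()
        populate_index(index_name)
        # Returning the XCom so callers can reuse the index name...
        return index_name

    # ...gives promote_index a dependency on generate_index_name only, so it
    # can still run even if populate_index fails inside the group.
    promote_index(create_new_index())


taskgroup_return_pitfall()

The fix is to call the name-generating task outside the TaskGroup and pass its output in, so the name can be consumed anywhere while downstream tasks are wired to depend on the group as a whole.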

* Set schedule to None for staging data refreshes

This is necessary for the load_sample_data script, so we can load sample data by running the staging DAGs. If they were on any kind of automated schedule, we would have to contend with scheduled runs in addition to our manual ones.
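For reference, the relevant configuration is just the DAG-level schedule. A sketch with an assumed DAG id, not the actual factory code:

from airflow.decorators import dag
from pendulum import datetime


@dag(
    dag_id="staging_image_data_refresh",  # assumed id for illustration
    start_date=datetime(2024, 1, 1),
    schedule=None,  # never scheduled automatically; runs only when triggered
    catchup=False,
)
def staging_image_data_refresh():
    ...


staging_image_data_refresh()

With schedule=None, runs happen only when something triggers the DAG explicitly, which is exactly what the load_sample_data script does.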

* Add just recipe to run single airflow cli command

* Start using the new DAGs in load_sample_data

* Update DAG docs

* Add catalog dependency to api tests in ci

* Start catalog before initializing api

* Remove accidental copy/paste

* Populate AIRFLOW_CONN_SENSITIVE_TERMS in the init scripts if missing, using @sarayourfriend's implementation

* Add option to allow concurrent refreshes, enable during load_sample_data

This allows the init scripts to run faster, as the DAGs can run concurrently when there is no risk of a negative impact on ES CPU.

* Avoid naming collisions in the FDW, schema, and connection ID when data refreshes run concurrently
stacimc authored Oct 8, 2024
1 parent f7a7cce commit 424702c
Showing 14 changed files with 175 additions and 116 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci_cd.yml
@@ -457,7 +457,7 @@ jobs:
uses: ./.github/actions/load-img
with:
run_id: ${{ github.run_id }}
setup_images: upstream_db ingestion_server
setup_images: upstream_db ingestion_server catalog

# Sets build args specifying versions needed to build Docker image.
- name: Prepare build args
@@ -486,6 +486,9 @@
API_PY_VERSION=${{ steps.prepare-build-args.outputs.api_py_version }}
PDM_INSTALL_ARGS=--dev
- name: Start Catalog
run: just catalog/up

- name: Start API, ingest and index test data
run: just api/init

29 changes: 21 additions & 8 deletions catalog/dags/common/sensors/single_run_external_dags_sensor.py
@@ -21,19 +21,25 @@ class SingleRunExternalDAGsSensor(BaseSensorOperator):
:param external_dag_ids: A list of dag_ids that you want to wait for
:param check_existence: Set to `True` to check if the external DAGs exist,
and immediately cease waiting if not (default value: False).
:param allow_concurrent_runs: Used to force the Sensor to pass, even
if there are concurrent runs.
"""

def __init__(
self,
*,
external_dag_ids: Iterable[str],
check_existence: bool = False,
allow_concurrent_runs: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.external_dag_ids = external_dag_ids
self.check_existence = check_existence
self._has_checked_existence = False
self.allow_concurrent_runs = allow_concurrent_runs

# Used to ensure some checks are only evaluated on the first poke
self._has_checked_params = False

@provide_session
def poke(self, context, session=None):
@@ -42,19 +48,27 @@ def poke(self, context, session=None):
self.external_dag_ids,
)

if self.check_existence:
self._check_for_existence(session=session)
if not self._has_checked_params:
if self.allow_concurrent_runs:
self.log.info(
"`allow_concurrent_runs` is enabled. Returning without"
" checking for running DAGs."
)
return True

if self.check_existence:
self._check_for_existence(session=session)

# Only check DAG existence and `allow_concurrent_runs`
# on the first execution.
self._has_checked_params = True

count_running = self.get_count(session)

self.log.info("%s DAGs are in the running state", count_running)
return count_running == 0

def _check_for_existence(self, session) -> None:
# Check DAG existence only once, on the first execution.
if self._has_checked_existence:
return

for dag_id in self.external_dag_ids:
dag_to_wait = (
session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
@@ -72,7 +86,6 @@ def _check_for_existence(self, session) -> None:
f"The external DAG {dag_id} does not have a task "
f"with id {self.task_id}."
)
self._has_checked_existence = True

def get_count(self, session) -> int:
# Get the count of running DAGs. A DAG is considered 'running' if
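A hedged usage sketch of the new flag; the surrounding DAG, ids, and timings below are assumptions for illustration, not code from this repository:

from pendulum import datetime

from airflow import DAG

from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor

with DAG(
    dag_id="example_staging_data_refresh",  # illustrative only
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    wait_for_other_refreshes = SingleRunExternalDAGsSensor(
        task_id="wait_for_data_refresh",
        external_dag_ids=["staging_audio_data_refresh"],
        check_existence=True,
        # When True, the first poke logs a message and returns immediately,
        # so this run is not serialized behind other running refresh DAGs.
        # That is what lets load_sample_data run refreshes concurrently.
        allow_concurrent_runs=True,
        poke_interval=60,
        timeout=60 * 60,
    )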
18 changes: 16 additions & 2 deletions catalog/dags/data_refresh/copy_data.py
@@ -40,28 +40,35 @@
def initialize_fdw(
upstream_conn_id: str,
downstream_conn_id: str,
media_type: str,
task: AbstractOperator = None,
):
"""Create the FDW and prepare it for copying."""
upstream_connection = Connection.get_connection_from_secrets(upstream_conn_id)
fdw_name = f"upstream_{media_type}"

run_sql.function(
postgres_conn_id=downstream_conn_id,
sql_template=queries.CREATE_FDW_QUERY,
task=task,
fdw_name=fdw_name,
host=upstream_connection.host,
port=upstream_connection.port,
dbname=upstream_connection.schema,
user=upstream_connection.login,
password=upstream_connection.password,
)

return fdw_name


@task(
max_active_tis_per_dagrun=1,
map_index_template="{{ task.op_kwargs['upstream_table_name'] }}",
)
def create_schema(downstream_conn_id: str, upstream_table_name: str) -> str:
def create_schema(
downstream_conn_id: str, upstream_table_name: str, fdw_name: str
) -> str:
"""
Create a new schema in the downstream DB through which the upstream table
can be accessed. Returns the schema name.
@@ -73,7 +80,9 @@ def create_schema(downstream_conn_id: str, upstream_table_name: str) -> str:
schema_name = f"upstream_{upstream_table_name}_schema"
downstream_pg.run(
queries.CREATE_SCHEMA_QUERY.format(
schema_name=schema_name, upstream_table_name=upstream_table_name
fdw_name=fdw_name,
schema_name=schema_name,
upstream_table_name=upstream_table_name,
)
)
return schema_name
@@ -183,6 +192,7 @@ def copy_data(
def copy_upstream_table(
upstream_conn_id: str,
downstream_conn_id: str,
fdw_name: str,
timeout: timedelta,
limit: int,
upstream_table_name: str,
@@ -206,6 +216,7 @@ def copy_upstream_table(
schema = create_schema(
downstream_conn_id=downstream_conn_id,
upstream_table_name=upstream_table_name,
fdw_name=fdw_name,
)

create_temp_table = run_sql.override(
@@ -286,6 +297,7 @@ def copy_upstream_tables(
init_fdw = initialize_fdw(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
media_type=data_refresh_config.media_type,
)

limit = get_record_limit()
@@ -294,13 +306,15 @@
copy_tables = copy_upstream_table.partial(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
fdw_name=init_fdw,
timeout=data_refresh_config.copy_data_timeout,
limit=limit,
).expand_kwargs([asdict(tm) for tm in data_refresh_config.table_mappings])

drop_fdw = run_sql.override(task_id="drop_fdw")(
postgres_conn_id=downstream_conn_id,
sql_template=queries.DROP_SERVER_QUERY,
fdw_name=init_fdw,
)

# Set up dependencies
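The collision being avoided is roughly the following. The SQL is an assumed approximation and the connection details are placeholders; the real templates live in the project's queries module:

# Assumed approximation of the per-media-type naming; host, dbname, and table
# values are placeholders, not the project's real configuration.
media_type = "image"
upstream_table_name = "image"

# The foreign server presumably used one fixed name before, so concurrent
# image and audio refreshes could create or drop each other's server
# mid-copy. Now each refresh gets its own server and schema.
fdw_name = f"upstream_{media_type}"                    # upstream_image / upstream_audio
schema_name = f"upstream_{upstream_table_name}_schema"

create_fdw_sql = f"""
CREATE SERVER {fdw_name} FOREIGN DATA WRAPPER postgres_fdw
    OPTIONS (host 'upstream_db', port '5432', dbname 'openledger');
"""

create_schema_sql = f"""
CREATE SCHEMA {schema_name};
IMPORT FOREIGN SCHEMA public LIMIT TO ({upstream_table_name})
    FROM SERVER {fdw_name} INTO {schema_name};
"""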
13 changes: 3 additions & 10 deletions catalog/dags/data_refresh/create_and_populate_filtered_index.py
@@ -49,17 +49,13 @@ def create_and_populate_filtered_index(
es_host: str,
media_type: MediaType,
origin_index_name: str,
filtered_index_name: str,
timeout: timedelta,
destination_index_name: str | None = None,
):
"""
Create and populate a filtered index based on the given origin index, excluding
documents with sensitive terms.
"""
filtered_index_name = get_filtered_index_name(
media_type=media_type, destination_index_name=destination_index_name
)

create_filtered_index = es.create_index.override(
trigger_rule=TriggerRule.NONE_FAILED,
)(
@@ -76,7 +72,6 @@
method="GET",
response_check=lambda response: response.status_code == 200,
response_filter=response_filter_sensitive_terms_endpoint,
trigger_rule=TriggerRule.NONE_FAILED,
)

populate_filtered_index = es.trigger_and_wait_for_reindex(
@@ -99,7 +94,5 @@

refresh_index = es.refresh_index(es_host=es_host, index_name=filtered_index_name)

sensitive_terms >> populate_filtered_index
create_filtered_index >> populate_filtered_index >> refresh_index

return filtered_index_name
# sensitive_terms >> populate_filtered_index
create_filtered_index >> sensitive_terms >> populate_filtered_index >> refresh_index
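Correspondingly, a simplified sketch of the new call site; argument names beyond those visible in the diff above are assumptions, and the real DAG wires more tasks than shown:

# Simplified, hypothetical call pattern for the updated signature.
filtered_index_name = get_filtered_index_name(
    media_type=media_type,
    destination_index_name=destination_index_name,
)

create_and_populate_filtered_index(
    es_host=es_host,
    media_type=media_type,
    origin_index_name=origin_index_name,
    filtered_index_name=filtered_index_name,
    timeout=timeout,
)

# Because the name now comes from its own task outside the group, promotion
# steps can be made downstream of the entire TaskGroup instead of only the
# task that happened to return the name.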
44 changes: 0 additions & 44 deletions catalog/dags/data_refresh/create_index.py

This file was deleted.

