From 91b3d6c7437767614707065fc56ced840dbb5f14 Mon Sep 17 00:00:00 2001 From: Chad Nelson Date: Fri, 29 Sep 2023 15:37:03 -0700 Subject: [PATCH] fix: factor out shared tasks --- dags/mapper_dag.py | 6 +++--- dags/shared_tasks.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 dags/shared_tasks.py diff --git a/dags/mapper_dag.py b/dags/mapper_dag.py index 6116d3398..000656590 100644 --- a/dags/mapper_dag.py +++ b/dags/mapper_dag.py @@ -2,9 +2,9 @@ from airflow.decorators import dag, task from airflow.models.param import Param -from rikolti.dags.harvest_dag import get_collection_metadata_task -from rikolti.dags.harvest_dag import map_page_task -from rikolti.dags.harvest_dag import get_mapping_status_task +from shared_tasks import get_collection_metadata_task +from shared_tasks import map_page_task +from shared_tasks import get_mapping_status_task from rikolti.metadata_mapper.lambda_shepherd import get_vernacular_pages diff --git a/dags/shared_tasks.py b/dags/shared_tasks.py new file mode 100644 index 000000000..a1eb19b3d --- /dev/null +++ b/dags/shared_tasks.py @@ -0,0 +1,35 @@ +import requests +from airflow.decorators import task +from rikolti.metadata_mapper.lambda_function import map_page +from rikolti.metadata_mapper.lambda_shepherd import get_mapping_status + +@task() +def get_collection_metadata_task(params=None): + if not params or not params.get('collection_id'): + raise ValueError("Collection ID not found in params") + collection_id = params.get('collection_id') + + resp = requests.get( + "https://registry.cdlib.org/api/v1/" + f"rikolticollection/{collection_id}/?format=json" + ) + resp.raise_for_status() + + return resp.json() + + +# max_active_tis_per_dag - setting on the task to restrict how many +# instances can be running at the same time, *across all DAG runs* +@task() +def map_page_task(page: str, collection: dict): + collection_id = collection.get('id') + if not collection_id: + return False + mapped_page = map_page(collection_id, page, collection) + return mapped_page + + +@task() +def get_mapping_status_task(collection: dict, mapped_pages: list): + mapping_status = get_mapping_status(collection, mapped_pages) + return mapping_status