diff --git a/config/step/gwas_catalog_sumstat_preprocess.yaml b/config/step/gwas_catalog_sumstat_preprocess.yaml
new file mode 100644
index 000000000..29d3486e8
--- /dev/null
+++ b/config/step/gwas_catalog_sumstat_preprocess.yaml
@@ -0,0 +1,3 @@
+_target_: otg.gwas_catalog_sumstat_preprocess.GWASCatalogSumstatsPreprocessStep
+raw_sumstats_path: ???
+out_sumstats_path: ???
diff --git a/src/airflow/dags/common_airflow.py b/src/airflow/dags/common_airflow.py
index 4feed13db..7fde7ebe4 100644
--- a/src/airflow/dags/common_airflow.py
+++ b/src/airflow/dags/common_airflow.py
@@ -13,6 +13,7 @@
     DataprocSubmitJobOperator,
 )
 from airflow.utils.trigger_rule import TriggerRule
+from google.cloud import dataproc_v1
 
 if TYPE_CHECKING:
     from pathlib import Path
@@ -64,6 +65,7 @@ def create_cluster(
     master_machine_type: str = "n1-highmem-8",
     worker_machine_type: str = "n1-standard-16",
     num_workers: int = 2,
+    num_preemptible_workers: int = 0,
     num_local_ssds: int = 1,
     autoscaling_policy: str = GCP_AUTOSCALING_POLICY,
 ) -> DataprocCreateClusterOperator:
@@ -74,6 +76,7 @@
         master_machine_type (str): Machine type for the master node. Defaults to "n1-highmem-8".
         worker_machine_type (str): Machine type for the worker nodes. Defaults to "n1-standard-16".
         num_workers (int): Number of worker nodes. Defaults to 2.
+        num_preemptible_workers (int): Number of preemptible worker nodes. Defaults to 0.
         num_local_ssds (int): How many local SSDs to attach to each worker node, both primary and secondary. Defaults to 1.
         autoscaling_policy (str): Name of the autoscaling policy to use. Defaults to GCP_AUTOSCALING_POLICY.
 
@@ -88,6 +91,7 @@
         worker_machine_type=worker_machine_type,
         master_disk_size=500,
         worker_disk_size=500,
+        num_preemptible_workers=num_preemptible_workers,
         num_workers=num_workers,
         image_version=GCP_DATAPROC_IMAGE,
         enable_component_gateway=True,
@@ -303,3 +307,49 @@ def generate_dag(cluster_name: str, tasks: list[DataprocSubmitJobOperator]) -> A
         >> tasks
         >> delete_cluster(cluster_name)
     )
+
+
+def submit_pyspark_job_no_operator(
+    cluster_name: str,
+    step_id: str,
+    other_args: Optional[list[str]] = None,
+) -> None:
+    """Submits the PySpark job to the cluster directly, bypassing the Airflow operator.
+
+    Args:
+        cluster_name (str): Cluster name
+        step_id (str): Step id
+        other_args (Optional[list[str]]): Other arguments to pass to the CLI step. Defaults to None.
+    """
+    # Create the job client.
+    job_client = dataproc_v1.JobControllerClient(
+        client_options={"api_endpoint": f"{GCP_REGION}-dataproc.googleapis.com:443"}
+    )
+
+    python_uri = f"{INITIALISATION_BASE_PATH}/{PYTHON_CLI}"
+    # Create the job config. 'main_python_file_uri' can also be a
+    # Google Cloud Storage URL.
+    job_description = {
+        "placement": {"cluster_name": cluster_name},
+        "pyspark_job": {
+            "main_python_file_uri": python_uri,
+            "args": [f"step={step_id}"]
+            + (other_args if other_args is not None else [])
+            + [
+                f"--config-dir={CLUSTER_CONFIG_DIR}",
+                f"--config-name={CONFIG_NAME}",
+            ],
+            "properties": {
+                "spark.jars": "/opt/conda/miniconda3/lib/python3.10/site-packages/hail/backend/hail-all-spark.jar",
+                "spark.driver.extraClassPath": "/opt/conda/miniconda3/lib/python3.10/site-packages/hail/backend/hail-all-spark.jar",
+                "spark.executor.extraClassPath": "./hail-all-spark.jar",
+                "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
+                "spark.kryo.registrator": "is.hail.kryo.HailKryoRegistrator",
+            },
+        },
+    }
+    res = job_client.submit_job(
+        project_id=GCP_PROJECT, region=GCP_REGION, job=job_description
+    )
+    job_id = res.reference.job_id
+    print(f"Submitted job ID {job_id}.")
diff --git a/src/airflow/dags/gwas_catalog_harmonisation.py b/src/airflow/dags/gwas_catalog_harmonisation.py
new file mode 100644
index 000000000..048c65a01
--- /dev/null
+++ b/src/airflow/dags/gwas_catalog_harmonisation.py
@@ -0,0 +1,117 @@
+"""Airflow DAG for the harmonisation part of the pipeline."""
+from __future__ import annotations
+
+import re
+import time
+from pathlib import Path
+from typing import Any
+
+import common_airflow as common
+from airflow.decorators import task
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator
+
+CLUSTER_NAME = "otg-gwascatalog-harmonisation"
+AUTOSCALING = "gwascatalog-harmonisation"
+
+SUMMARY_STATS_BUCKET_NAME = "open-targets-gwas-summary-stats"
+
+with DAG(
+    dag_id=Path(__file__).stem,
+    description="Open Targets Genetics — GWAS Catalog harmonisation",
+    default_args=common.shared_dag_args,
+    **common.shared_dag_kwargs,
+):
+    # List raw harmonised files from GWAS Catalog
+    list_inputs = GCSListObjectsOperator(
+        task_id="list_raw_harmonised",
+        bucket=SUMMARY_STATS_BUCKET_NAME,
+        prefix="raw-harmonised",
+        match_glob="**/*.h.tsv.gz",
+    )
+    # List parquet files that have been previously processed
+    list_outputs = GCSListObjectsOperator(
+        task_id="list_harmonised_parquet",
+        bucket=SUMMARY_STATS_BUCKET_NAME,
+        prefix="harmonised",
+        match_glob="**/_SUCCESS",
+    )
+
+    # Create list of pending jobs
+    @task(task_id="create_to_do_list")
+    def create_to_do_list(**kwargs: Any) -> Any:
+        """Create the to-do list of studies.
+
+        Args:
+            **kwargs (Any): Keyword arguments.
+
+        Returns:
+            Any: To-do list.
+        """
+        ti = kwargs["ti"]
+        raw_harmonised = ti.xcom_pull(
+            task_ids="list_raw_harmonised", key="return_value"
+        )
+        print("Number of raw harmonised files: ", len(raw_harmonised))
+        to_do_list = []
+        # Remove the ones that have been processed
+        parquets = ti.xcom_pull(task_ids="list_harmonised_parquet", key="return_value")
+        print("Number of parquet files: ", len(parquets))
+        for path in raw_harmonised:
+            match_result = re.search(
+                r"raw-harmonised/(.*)/(GCST\d+)/harmonised/(.*)\.h\.tsv\.gz", path
+            )
+            if match_result:
+                study_id = match_result.group(2)
+                if f"harmonised/{study_id}.parquet/_SUCCESS" not in parquets:
+                    to_do_list.append(path)
+        print("Number of jobs to submit: ", len(to_do_list))
+        ti.xcom_push(key="to_do_list", value=to_do_list)
+
+    # Submit jobs to dataproc
+    @task(task_id="submit_jobs")
+    def submit_jobs(**kwargs: Any) -> None:
+        """Submit jobs to dataproc.
+
+        Args:
+            **kwargs (Any): Keyword arguments.
+ """ + ti = kwargs["ti"] + todo = ti.xcom_pull(task_ids="create_to_do_list", key="to_do_list") + print("Number of jobs to submit: ", len(todo)) + for i in range(len(todo)): + # Not to exceed default quota 400 jobs per minute + if i > 0 and i % 399 == 0: + time.sleep(60) + input_path = todo[i] + match_result = re.search( + "raw-harmonised/(.*)/(GCST\d+)/harmonised/(.*)\.h\.tsv\.gz", input_path + ) + if match_result: + study_id = match_result.group(2) + print("Submitting job for study: ", study_id) + common.submit_pyspark_job_no_operator( + cluster_name=CLUSTER_NAME, + step_id="gwas_catalog_sumstat_preprocess", + other_args=[ + f"step.raw_sumstats_path=gs://{SUMMARY_STATS_BUCKET_NAME}/{input_path}", + f"step.out_sumstats_path=gs://{SUMMARY_STATS_BUCKET_NAME}/harmonised/{study_id}.parquet", + ], + ) + + # list_inputs >> + ( + [list_inputs, list_outputs] + >> create_to_do_list() + >> common.create_cluster( + CLUSTER_NAME, + autoscaling_policy=AUTOSCALING, + num_workers=8, + num_preemptible_workers=8, + master_machine_type="n1-highmem-64", + worker_machine_type="n1-standard-2", + ) + >> common.install_dependencies(CLUSTER_NAME) + >> submit_jobs() + # >> common.delete_cluster(CLUSTER_NAME) + )