Commit
…thon into il-ammend-tests
ireneisdoomed committed Dec 8, 2023
2 parents ab29359 + e0297cb commit 72c13c3
Showing 3 changed files with 171 additions and 1 deletion.
3 changes: 3 additions & 0 deletions config/step/gwas_catalog_sumstat_preprocess.yaml
@@ -0,0 +1,3 @@
_target_: otg.gwas_catalog_sumstat_preprocess.GWASCatalogSumstatsPreprocessStep
raw_sumstats_path: ???
out_sumstats_path: ???
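
For context, the two "???" values follow Hydra/OmegaConf conventions: they mark mandatory fields that must be supplied as overrides when the step is launched, which is what the new harmonisation DAG below does via step.raw_sumstats_path and step.out_sumstats_path. A minimal sketch of that behaviour, assuming OmegaConf is available; the override value is hypothetical:

# Minimal sketch of the "???" (mandatory value) semantics, assuming Hydra/OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"raw_sumstats_path": "???", "out_sumstats_path": "???"})
assert OmegaConf.is_missing(cfg, "raw_sumstats_path")  # no value provided yet
cfg.raw_sumstats_path = "gs://some-bucket/some-study.h.tsv.gz"  # hypothetical override
assert not OmegaConf.is_missing(cfg, "raw_sumstats_path")
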
52 changes: 51 additions & 1 deletion src/airflow/dags/common_airflow.py
@@ -13,6 +13,7 @@
DataprocSubmitJobOperator,
)
from airflow.utils.trigger_rule import TriggerRule
from google.cloud import dataproc_v1

if TYPE_CHECKING:
from pathlib import Path
@@ -54,7 +55,7 @@
shared_dag_kwargs = dict(
tags=["genetics_etl", "experimental"],
start_date=pendulum.now(tz="Europe/London").subtract(days=1),
schedule_interval="@once",
schedule="@once",
catchup=False,
)

@@ -64,6 +65,7 @@ def create_cluster(
master_machine_type: str = "n1-highmem-8",
worker_machine_type: str = "n1-standard-16",
num_workers: int = 2,
num_preemptible_workers: int = 0,
num_local_ssds: int = 1,
autoscaling_policy: str = GCP_AUTOSCALING_POLICY,
) -> DataprocCreateClusterOperator:
@@ -74,6 +76,7 @@ def create_cluster(
master_machine_type (str): Machine type for the master node. Defaults to "n1-highmem-8".
worker_machine_type (str): Machine type for the worker nodes. Defaults to "n1-standard-16".
num_workers (int): Number of worker nodes. Defaults to 2.
num_preemptible_workers (int): Number of preemptible worker nodes. Defaults to 0.
num_local_ssds (int): How many local SSDs to attach to each worker node, both primary and secondary. Defaults to 1.
autoscaling_policy (str): Name of the autoscaling policy to use. Defaults to GCP_AUTOSCALING_POLICY.
@@ -88,6 +91,7 @@ def create_cluster(
worker_machine_type=worker_machine_type,
master_disk_size=500,
worker_disk_size=500,
num_preemptible_workers=num_preemptible_workers,
num_workers=num_workers,
image_version=GCP_DATAPROC_IMAGE,
enable_component_gateway=True,
@@ -303,3 +307,49 @@ def generate_dag(cluster_name: str, tasks: list[DataprocSubmitJobOperator]) -> A
>> tasks
>> delete_cluster(cluster_name)
)


def submit_pyspark_job_no_operator(
cluster_name: str,
step_id: str,
other_args: Optional[list[str]] = None,
) -> None:
"""Submits the Pyspark job to the cluster.
Args:
cluster_name (str): Cluster name
step_id (str): Step id
other_args (Optional[list[str]]): Other arguments to pass to the CLI step. Defaults to None.
"""
# Create the job client.
job_client = dataproc_v1.JobControllerClient(
client_options={"api_endpoint": f"{GCP_REGION}-dataproc.googleapis.com:443"}
)

python_uri = f"{INITIALISATION_BASE_PATH}/{PYTHON_CLI}"
    # Create the job config. 'main_python_file_uri' can be a
    # Google Cloud Storage URL.
job_description = {
"placement": {"cluster_name": cluster_name},
"pyspark_job": {
"main_python_file_uri": python_uri,
"args": [f"step={step_id}"]
+ (other_args if other_args is not None else [])
+ [
f"--config-dir={CLUSTER_CONFIG_DIR}",
f"--config-name={CONFIG_NAME}",
],
"properties": {
"spark.jars": "/opt/conda/miniconda3/lib/python3.10/site-packages/hail/backend/hail-all-spark.jar",
"spark.driver.extraClassPath": "/opt/conda/miniconda3/lib/python3.10/site-packages/hail/backend/hail-all-spark.jar",
"spark.executor.extraClassPath": "./hail-all-spark.jar",
"spark.serializer": "org.apache.spark.serializer.KryoSerializer",
"spark.kryo.registrator": "is.hail.kryo.HailKryoRegistrator",
},
},
}
res = job_client.submit_job(
project_id=GCP_PROJECT, region=GCP_REGION, job=job_description
)
job_id = res.reference.job_id
print(f"Submitted job ID {job_id}.")
117 changes: 117 additions & 0 deletions src/airflow/dags/gwas_catalog_harmonisation.py
@@ -0,0 +1,117 @@
"""Airflow DAG for the harmonisation part of the pipeline."""
from __future__ import annotations

import re
import time
from pathlib import Path
from typing import Any

import common_airflow as common
from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator

CLUSTER_NAME = "otg-gwascatalog-harmonisation"
AUTOSCALING = "gwascatalog-harmonisation"

SUMMARY_STATS_BUCKET_NAME = "open-targets-gwas-summary-stats"

with DAG(
dag_id=Path(__file__).stem,
description="Open Targets Genetics — GWAS Catalog harmonisation",
default_args=common.shared_dag_args,
**common.shared_dag_kwargs,
):
# List raw harmonised files from GWAS Catalog
list_inputs = GCSListObjectsOperator(
task_id="list_raw_harmonised",
bucket=SUMMARY_STATS_BUCKET_NAME,
prefix="raw-harmonised",
match_glob="**/*.h.tsv.gz",
)
# List parquet files that have been previously processed
list_outputs = GCSListObjectsOperator(
task_id="list_harmonised_parquet",
bucket=SUMMARY_STATS_BUCKET_NAME,
prefix="harmonised",
match_glob="**/_SUCCESS",
)

# Create list of pending jobs
@task(task_id="create_to_do_list")
def create_to_do_list(**kwargs: Any) -> Any:
"""Create the to-do list of studies.
Args:
**kwargs (Any): Keyword arguments.
Returns:
Any: To-do list.
"""
ti = kwargs["ti"]
raw_harmonised = ti.xcom_pull(
task_ids="list_raw_harmonised", key="return_value"
)
print("Number of raw harmonised files: ", len(raw_harmonised))
to_do_list = []
# Remove the ones that have been processed
parquets = ti.xcom_pull(task_ids="list_harmonised_parquet", key="return_value")
print("Number of parquet files: ", len(parquets))
for path in raw_harmonised:
            match_result = re.search(
                r"raw-harmonised/(.*)/(GCST\d+)/harmonised/(.*)\.h\.tsv\.gz", path
            )
if match_result:
study_id = match_result.group(2)
if f"harmonised/{study_id}.parquet/_SUCCESS" not in parquets:
to_do_list.append(path)
print("Number of jobs to submit: ", len(to_do_list))
ti.xcom_push(key="to_do_list", value=to_do_list)

# Submit jobs to dataproc
@task(task_id="submit_jobs")
def submit_jobs(**kwargs: Any) -> None:
"""Submit jobs to dataproc.
Args:
**kwargs (Any): Keyword arguments.
"""
ti = kwargs["ti"]
todo = ti.xcom_pull(task_ids="create_to_do_list", key="to_do_list")
print("Number of jobs to submit: ", len(todo))
for i in range(len(todo)):
            # Stay under the default quota of 400 job submissions per minute
if i > 0 and i % 399 == 0:
time.sleep(60)
input_path = todo[i]
            match_result = re.search(
                r"raw-harmonised/(.*)/(GCST\d+)/harmonised/(.*)\.h\.tsv\.gz", input_path
            )
if match_result:
study_id = match_result.group(2)
print("Submitting job for study: ", study_id)
common.submit_pyspark_job_no_operator(
cluster_name=CLUSTER_NAME,
step_id="gwas_catalog_sumstat_preprocess",
other_args=[
f"step.raw_sumstats_path=gs://{SUMMARY_STATS_BUCKET_NAME}/{input_path}",
f"step.out_sumstats_path=gs://{SUMMARY_STATS_BUCKET_NAME}/harmonised/{study_id}.parquet",
],
)

# list_inputs >>
(
[list_inputs, list_outputs]
>> create_to_do_list()
>> common.create_cluster(
CLUSTER_NAME,
autoscaling_policy=AUTOSCALING,
num_workers=8,
num_preemptible_workers=8,
master_machine_type="n1-highmem-64",
worker_machine_type="n1-standard-2",
)
>> common.install_dependencies(CLUSTER_NAME)
>> submit_jobs()
# >> common.delete_cluster(CLUSTER_NAME)
)
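
For reference, a small standalone sketch of what the study-ID regex in create_to_do_list and submit_jobs extracts. The object path below is hypothetical but follows the raw-harmonised layout the DAG expects:

# Standalone sketch; the path is hypothetical.
import re

path = "raw-harmonised/some-batch/GCST000123/harmonised/GCST000123.h.tsv.gz"
match = re.search(r"raw-harmonised/(.*)/(GCST\d+)/harmonised/(.*)\.h\.tsv\.gz", path)
if match:
    study_id = match.group(2)  # "GCST000123"
    print(f"harmonised/{study_id}.parquet")  # where the DAG writes the output
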
