
Commit 683b008

Author: Szymon Szyszkowski (committed)
feat(gwas_catalog): harmonisation dag
1 parent 99518f0 commit 683b008

File tree

7 files changed: +619 -0 lines changed

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
nodes:
  - id: generate_sumstat_index
    kind: Task
    prerequisites: []
    google_batch_index_specs:
      manifest_generator_label: gwas_catalog_harmonisation
      max_task_count: 1000
      manifest_generator_specs:
        commands:
          - -c
          - ./harmonise-sumstats.sh
          - $RAW
          - $HARMONISED
          - $QC
          - 1.0e-8
        options:
          manifest_kwargs:
            qc_output_pattern: gs://gwas_catalog_inputs/summary_statistics_qc/**_SUCCESS
            harm_output_pattern: gs://gwas_catalog_inputs/harmonised_summary_statistics/**_SUCCESS
            raw_input_pattern: gs://gwas_catalog_inputs/raw_summary_statistics/**.h.tsv.gz
            manifest_output_uri: gs://gwas_catalog_inputs/harmonisation_manifest.csv

  - id: gwas_catalog_harmonisation
    kind: Task
    prerequisites:
      - generate_sumstat_index
    google_batch:
      entrypoint: /bin/sh
      image: europe-west1-docker.pkg.dev/open-targets-genetics-dev/gentropy-app/gentropy:orchestration
      resource_specs:
        cpu_milli: 4000
        memory_mib: 8000
        boot_disk_mib: 8000
      task_specs:
        max_retry_count: 2
        max_run_duration: "1h"
      policy_specs:
        machine_type: n1-standard-4
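
The generate_sumstat_index node only declares glob patterns and a command template; the actual manifest generation lives in HarmonisationManifestGenerator, which is not part of this diff. As a rough, hypothetical sketch of what such a generator has to do, the step is essentially a set difference between raw inputs and already-harmonised outputs (fnmatch stands in for GCS globbing; all paths and names below are made up for illustration):

from fnmatch import fnmatch

# Hypothetical illustration: compare raw inputs against completed outputs.
raw_input_pattern = "gs://gwas_catalog_inputs/raw_summary_statistics/*.h.tsv.gz"
harm_output_pattern = "gs://gwas_catalog_inputs/harmonised_summary_statistics/*/_SUCCESS"

# In reality these listings would come from GCS; here they are hard-coded.
bucket_listing = [
    "gs://gwas_catalog_inputs/raw_summary_statistics/GCST000001.h.tsv.gz",
    "gs://gwas_catalog_inputs/raw_summary_statistics/GCST000002.h.tsv.gz",
    "gs://gwas_catalog_inputs/harmonised_summary_statistics/GCST000001/_SUCCESS",
]

raw_inputs = [p for p in bucket_listing if fnmatch(p, raw_input_pattern)]
done = {p.split("/")[-2] for p in bucket_listing if fnmatch(p, harm_output_pattern)}

# One pending entry per study that has no _SUCCESS marker yet.
pending = [p for p in raw_inputs if p.split("/")[-1].split(".")[0] not in done]
print(pending)  # only GCST000002 still needs harmonisation

Presumably each remaining study becomes one row of harmonisation_manifest.csv and is later exposed to a batch task as the RAW/HARMONISED/QC environment variables consumed by harmonise-sumstats.sh.
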
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
"""Airflow DAG for GWAS Catalog sumstat harmonisation."""

from __future__ import annotations

import logging
from pathlib import Path

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG

from ot_orchestration.operators.batch.harmonisation import (
    BatchIndexOperator,
)
from ot_orchestration.utils import (
    find_node_in_config,
    read_yaml_config,
)
from ot_orchestration.utils.common import shared_dag_args, shared_dag_kwargs

SOURCE_CONFIG_FILE_PATH = (
    Path(__file__).parent / "config" / "gwas_catalog_harmonisation.yaml"
)
config = read_yaml_config(SOURCE_CONFIG_FILE_PATH)


@task(task_id="begin")
def begin():
    """Start the DAG execution."""
    logging.info("STARTING")
    logging.info(config)


@task(task_id="end")
def end():
    """Finish the DAG execution."""
    logging.info("FINISHED")


with DAG(
    dag_id=Path(__file__).stem,
    description="Open Targets Genetics — GWAS Catalog Sumstat Harmonisation",
    default_args=shared_dag_args,
    **shared_dag_kwargs,
):
    node_config = find_node_in_config(config["nodes"], "generate_sumstat_index")
    batch_index = BatchIndexOperator(
        task_id=node_config["id"],
        batch_index_specs=node_config["google_batch_index_specs"],
    )
    node_config = find_node_in_config(config["nodes"], "gwas_catalog_harmonisation")
    # harmonisation_batch_job = GeneticsBatchJobOperator.partial(
    #     task_id=node_config["id"],
    #     job_name="harmonisation",
    #     google_batch=node_config["google_batch"],
    # ).expand(batch_index_row=batch_index.output)

    chain(
        begin(),
        batch_index,
        # harmonisation_batch_job,
        end(),
    )
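
The commented-out harmonisation_batch_job relies on Airflow dynamic task mapping: partial() pins the arguments shared by every job, and expand() creates one mapped task instance per element of the upstream task's XCom output (here, one per BatchIndexRow). A minimal, self-contained sketch of the same mechanism, with made-up dag_id, task names, and row contents:

from datetime import datetime

from airflow.decorators import task
from airflow.models.dag import DAG

with DAG(
    dag_id="dynamic_mapping_sketch",  # illustrative only
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):

    @task
    def make_rows() -> list[dict]:
        # Stand-in for BatchIndexOperator: one dict per batch job to submit.
        return [{"idx": 1}, {"idx": 2}, {"idx": 3}]

    @task
    def submit(row: dict) -> None:
        # Stand-in for GeneticsBatchJobOperator: runs once per mapped row.
        print(f"submitting job {row['idx']}")

    submit.expand(row=make_rows())
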
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
"""Batch Index."""

from __future__ import annotations

import logging
from typing import TypedDict

from airflow.exceptions import AirflowSkipException
from google.cloud.batch import Environment

from ot_orchestration.utils.batch import create_task_commands, create_task_env


class BatchCommandsSerialized(TypedDict):
    """Serialized (plain-dict) representation of BatchCommands."""

    options: dict[str, str]
    commands: list[str]


class BatchEnvironmentsSerialized(TypedDict):
    """Serialized (plain-dict) representation of BatchEnvironments."""

    vars_list: list[dict[str, str]]


class BatchCommands:
    """Container for a batch task command template and its options."""

    def __init__(self, options: dict[str, str], commands: list[str]):
        self.options = options
        self.commands = commands

    def construct(self) -> list[str]:
        """Construct Batch commands from mapping."""
        logging.info(
            "Constructing batch task commands from commands: %s and options: %s",
            self.commands,
            self.options,
        )
        commands = create_task_commands(self.commands, self.options)
        return commands

    def serialize(self) -> BatchCommandsSerialized:
        """Serialize batch commands."""
        return BatchCommandsSerialized(options=self.options, commands=self.commands)

    @staticmethod
    def deserialize(data: BatchCommandsSerialized) -> BatchCommands:
        """Deserialize batch commands."""
        return BatchCommands(options=data["options"], commands=data["commands"])


class BatchEnvironments:
    """Container for per-task environment variable mappings."""

    def __init__(self, vars_list: list[dict[str, str]]):
        self.vars_list = vars_list

    def construct(self) -> list[Environment]:
        """Construct Batch Environment from list of mappings."""
        logging.info(
            "Constructing batch environments from vars_list: %s", self.vars_list
        )
        if not self.vars_list:
            logging.warning(
                "Cannot create Batch environments from an empty variable list, skipping"
            )
            raise AirflowSkipException(
                "Cannot create Batch environments from an empty variable list"
            )
        environments = create_task_env(self.vars_list)
        logging.info("Constructed environments: %s", environments)
        return environments

    def serialize(self) -> BatchEnvironmentsSerialized:
        """Serialize batch environments."""
        return BatchEnvironmentsSerialized(vars_list=self.vars_list)

    @staticmethod
    def deserialize(data: BatchEnvironmentsSerialized) -> BatchEnvironments:
        """Deserialize batch environments."""
        return BatchEnvironments(vars_list=data["vars_list"])


class BatchIndexRow(TypedDict):
    """Single row of the batch index: one batch job with its commands and environments."""

    idx: int
    command: BatchCommandsSerialized
    environment: BatchEnvironmentsSerialized


class BatchIndex:
    """Index of all batch jobs.

    This object contains paths to individual manifest objects.
    Each of the manifests will be a single batch job.
    Each line of the individual manifest is a representation of the batch job task.
    """

    def __init__(
        self,
        vars_list: list[dict[str, str]],
        options: dict[str, str],
        commands: list[str],
        max_task_count: int,
    ) -> None:
        self.vars_list = vars_list
        self.options = options
        self.commands = commands
        self.max_task_count = max_task_count
        self.vars_batches: list[BatchEnvironmentsSerialized] = []

    def partition(self) -> BatchIndex:
        """Partition the variable list into chunks of at most max_task_count items."""
        if not self.vars_list:
            msg = "BatchIndex cannot partition the variable list, as the list is empty."
            logging.warning(msg)
            return self

        if self.max_task_count > len(self.vars_list):
            logging.warning(
                "BatchIndex will use a single partition, as the dataset is smaller than max_task_count: %s < %s",
                len(self.vars_list),
                self.max_task_count,
            )
            self.max_task_count = len(self.vars_list)

        for i in range(0, len(self.vars_list), self.max_task_count):
            batch = self.vars_list[i : i + self.max_task_count]
            self.vars_batches.append(BatchEnvironmentsSerialized(vars_list=batch))

        logging.info("Created %s task list batches.", len(self.vars_batches))

        return self

    @property
    def rows(self) -> list[BatchIndexRow]:
        """Create the master manifest that gathers the information needed to create batch Environments."""
        rows: list[BatchIndexRow] = []
        logging.info("Preparing BatchIndexRows. Each row represents a batch job.")
        for idx, batch in enumerate(self.vars_batches):
            rows.append(
                {
                    "idx": idx + 1,
                    "command": BatchCommandsSerialized(
                        options=self.options, commands=self.commands
                    ),
                    "environment": batch,
                }
            )

        logging.info("Prepared %s BatchIndexRows", len(rows))
        if not rows:
            raise AirflowSkipException(
                "No BatchIndexRows were created, so no batch tasks can be submitted. Skipping downstream tasks."
            )
        return rows

    def __repr__(self) -> str:
        """Get batch index string representation."""
        return f"BatchIndex(vars_list={self.vars_list}, options={self.options}, commands={self.commands}, max_task_count={self.max_task_count})"
Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
"""Operators for Google Batch jobs."""

from __future__ import annotations

import logging
import time
from typing import Type

from airflow.models.baseoperator import BaseOperator
from airflow.providers.google.cloud.operators.cloud_batch import (
    CloudBatchSubmitJobOperator,
)

from ot_orchestration.operators.batch.batch_index import (
    BatchCommands,
    BatchEnvironments,
    BatchIndexRow,
)
from ot_orchestration.operators.batch.manifest_generators import ProtoManifestGenerator
from ot_orchestration.operators.batch.manifest_generators.harmonisation import (
    HarmonisationManifestGenerator,
)
from ot_orchestration.types import GoogleBatchIndexSpecs, GoogleBatchSpecs
from ot_orchestration.utils.batch import create_batch_job, create_task_spec
from ot_orchestration.utils.common import GCP_PROJECT_GENETICS, GCP_REGION

logging.basicConfig(level=logging.DEBUG)


class BatchIndexOperator(BaseOperator):
    """Operator to prepare the Google Batch job index.

    Each manifest prepared by the operator should create an environment for a single batch job.
    Each row of the individual manifest should represent an individual batch task.
    """

    # NOTE: register all manifest generators here.
    manifest_generator_registry: dict[str, Type[ProtoManifestGenerator]] = {
        "gwas_catalog_harmonisation": HarmonisationManifestGenerator
    }

    def __init__(
        self,
        batch_index_specs: GoogleBatchIndexSpecs,
        **kwargs,
    ) -> None:
        self.generator_label = batch_index_specs["manifest_generator_label"]
        self.manifest_generator = self.get_generator(self.generator_label)
        self.manifest_generator_specs = batch_index_specs["manifest_generator_specs"]
        self.max_task_count = batch_index_specs["max_task_count"]
        super().__init__(**kwargs)

    @classmethod
    def get_generator(cls, label: str) -> Type[ProtoManifestGenerator]:
        """Get the generator by its label in the registry."""
        return cls.manifest_generator_registry[label]

    def execute(self, context) -> list[BatchIndexRow]:
        """Execute the operator."""
        generator = self.manifest_generator.from_generator_config(
            self.manifest_generator_specs, max_task_count=self.max_task_count
        )
        index = generator.generate_batch_index()
        self.log.info(index)
        partitioned_index = index.partition()
        rows = partitioned_index.rows
        return rows


class GeneticsBatchJobOperator(CloudBatchSubmitJobOperator):
    """Submit a Google Batch job built from a single BatchIndexRow."""

    def __init__(
        self,
        job_name: str,
        batch_index_row: BatchIndexRow,
        google_batch: GoogleBatchSpecs,
        **kwargs,
    ):
        super().__init__(
            project_id=GCP_PROJECT_GENETICS,
            region=GCP_REGION,
            job_name=f"{job_name}-job-{batch_index_row['idx']}-{time.strftime('%Y%m%d-%H%M%S')}",
            job=create_batch_job(
                task=create_task_spec(
                    image=google_batch["image"],
                    commands=BatchCommands.deserialize(
                        batch_index_row["command"]
                    ).construct(),
                    task_specs=google_batch["task_specs"],
                    resource_specs=google_batch["resource_specs"],
                    entrypoint=google_batch["entrypoint"],
                ),
                task_env=BatchEnvironments.deserialize(
                    batch_index_row["environment"]
                ).construct(),
                policy_specs=google_batch["policy_specs"],
            ),
            deferrable=False,
            **kwargs,
        )
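
Because BatchIndexOperator.execute returns its rows through XCom, everything inside a BatchIndexRow stays a plain dict or list; BatchCommands and BatchEnvironments are only rehydrated on the consuming side, inside GeneticsBatchJobOperator. A small round-trip sketch, assuming the ot_orchestration package from this commit is importable (the empty options mapping is illustrative):

from ot_orchestration.operators.batch.batch_index import BatchCommands

# The command template from the DAG config; options left empty for illustration.
original = BatchCommands(
    options={},
    commands=["-c", "./harmonise-sumstats.sh", "$RAW", "$HARMONISED", "$QC", "1.0e-8"],
)

# serialize() produces the plain-dict BatchCommandsSerialized stored in a BatchIndexRow;
# deserialize() rebuilds an equivalent object before construct() is called for the job.
roundtrip = BatchCommands.deserialize(original.serialize())
assert roundtrip.commands == original.commands
assert roundtrip.options == original.options
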
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
"""Manifest generators."""

from __future__ import annotations

from typing import Protocol

from ot_orchestration.operators.batch.batch_index import BatchIndex
from ot_orchestration.types import ManifestGeneratorSpecs


class ProtoManifestGenerator(Protocol):
    """Protocol every manifest generator registered in BatchIndexOperator must satisfy."""

    @classmethod
    def from_generator_config(
        cls, specs: ManifestGeneratorSpecs, max_task_count: int
    ) -> ProtoManifestGenerator:
        """Constructor for Manifest Generator."""
        raise NotImplementedError("Implement it in subclasses")

    def generate_batch_index(self) -> BatchIndex:
        """Generate batch index."""
        raise NotImplementedError("Implement it in subclasses")
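
Adding a new pipeline to this machinery means writing a class that structurally satisfies ProtoManifestGenerator and registering it under a label in BatchIndexOperator.manifest_generator_registry, which a DAG config then selects via manifest_generator_label. A hypothetical skeleton (the class, its fields, literal values, and the example label are made up; only the two method signatures come from this commit):

from __future__ import annotations

from ot_orchestration.operators.batch.batch_index import BatchIndex
from ot_orchestration.types import ManifestGeneratorSpecs


class ExampleManifestGenerator:
    """Hypothetical generator that structurally satisfies ProtoManifestGenerator."""

    def __init__(self, specs: ManifestGeneratorSpecs, max_task_count: int) -> None:
        self.specs = specs
        self.max_task_count = max_task_count

    @classmethod
    def from_generator_config(
        cls, specs: ManifestGeneratorSpecs, max_task_count: int
    ) -> ExampleManifestGenerator:
        """Build the generator from a node's manifest_generator_specs."""
        return cls(specs, max_task_count)

    def generate_batch_index(self) -> BatchIndex:
        """Return a BatchIndex with one environment mapping per task."""
        return BatchIndex(
            vars_list=[{"RAW": "gs://bucket/raw/GCST000001.h.tsv.gz"}],  # illustrative
            options={},
            commands=["-c", "./some-script.sh", "$RAW"],  # illustrative
            max_task_count=self.max_task_count,
        )


# Registration would then be a one-line addition in the operators module:
# BatchIndexOperator.manifest_generator_registry["example_label"] = ExampleManifestGenerator
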
