Skip to content

Commit 996b0d0

Browse files
author
Szymon Szyszkowski
committed
feat: improved dag
1 parent b4615fd commit 996b0d0

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

src/ot_orchestration/task_groups/gwas_catalog/batch_processing.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
from airflow.operators.python import get_current_context
1111
from airflow.utils.helpers import chain
1212
from ot_orchestration.utils import create_batch_job, create_task_spec
13-
from ot_orchestration.utils import GCSPath
1413
from airflow.models import TaskInstance
1514
import logging
1615
import time
@@ -23,8 +22,10 @@ def gwas_catalog_batch_processing() -> None:
2322
@task(task_id="get_manifests_from_preparation", multiple_outputs=True)
2423
def get_batch_task_inputs(
2524
task_instance: TaskInstance | None = None,
26-
) -> list[GCSPath]:
25+
) -> dict[str, str]:
2726
"""Get manifests from preparation step."""
27+
if task_instance is None:
28+
raise ValueError("Task instance is None")
2829
manifest_paths = task_instance.xcom_pull(
2930
task_ids="manifest_preparation.choose_manifest_paths"
3031
)
@@ -38,9 +39,7 @@ def get_batch_task_inputs(
3839
batch_inputs = get_batch_task_inputs()
3940

4041
@task(task_id="batch_job", multiple_outputs=True)
41-
def execute_batch_job(
42-
manifest_paths: list[str], config_path: str
43-
) -> CloudBatchSubmitJobOperator:
42+
def execute_batch_job(manifest_paths: list[str], config_path: str):
4443
"""Create a harmonisation batch job."""
4544
params = get_step_params("batch_processing")
4645
logging.info("PARAMS: %s", params)

src/ot_orchestration/task_groups/gwas_catalog/manifest_preparation.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from airflow.decorators import task, task_group
44
from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator
55
from ot_orchestration.types import Manifest_Object
6-
from ot_orchestration.utils import GCSIOManager, get_step_params, get_full_config
6+
from ot_orchestration.utils import IOManager, get_step_params, get_full_config
77
from airflow.models.baseoperator import chain
88
from ot_orchestration.utils.manifest import extract_study_id_from_path
99
from airflow.utils.edgemodifier import Label
@@ -57,6 +57,8 @@ def get_new_sumstat_paths(
5757
def collect_sumstats_and_generate_new_manifests(
5858
ti: TaskInstance | None = None,
5959
) -> list[Manifest_Object]:
60+
if ti is None:
61+
raise ValueError("Task instance is None")
6062
task_id: str = ti.xcom_pull(task_ids="manifest_preparation.get_execution_mode")
6163
logging.info("TASK ID: %s", task_id)
6264
new_sumstats = ti.xcom_pull(task_ids=task_id)
@@ -97,6 +99,8 @@ def amend_curation_metadata(new_manifests: list[Manifest_Object]):
9799
params = get_step_params("manifest_preparation")
98100
logging.info("USING FOLLOWING PARAMS: %s", params)
99101
curation_path = params["manual_curation_manifest"]
102+
if not isinstance(curation_path, str):
103+
raise ValueError("Curation path is not a string")
100104
logging.info("CURATING MANIFESTS WITH: %s", curation_path)
101105
curation_df = pd.read_csv(curation_path, sep="\t").drop(
102106
columns=["publicationTitle", "traitFromSource", "qualityControl"]
@@ -116,18 +120,20 @@ def amend_curation_metadata(new_manifests: list[Manifest_Object]):
116120
def read_manifests(manifest_paths: list[str]) -> list[Manifest_Object]:
117121
"""Read manifests."""
118122
manifest_paths = [f"gs://{path}" for path in manifest_paths]
119-
return GCSIOManager().load_many(manifest_paths)
123+
return IOManager().load_many(manifest_paths)
120124

121125

122126
@task(task_id="save_config")
123127
def save_config(task_instance: TaskInstance | None = None) -> str:
124128
"""Save configuration for batch processing."""
129+
if task_instance is None:
130+
raise ValueError("Task instance is None")
125131
run_id = task_instance.run_id
126132
params = get_step_params("manifest_preparation")
127133
full_config = get_full_config().serialize()
128134
config_path = f"gs://{params['staging_bucket']}/{params['staging_prefix']}/{run_id}/config.yaml"
129135
logging.info("DUMPING CONFIG TO THE FOLLOWING PATH: %s", config_path)
130-
GCSIOManager().dump(gcs_path=config_path, data=full_config)
136+
IOManager().resolve(config_path).dump(full_config)
131137
return config_path
132138

133139

@@ -137,6 +143,8 @@ def save_config(task_instance: TaskInstance | None = None) -> str:
137143
)
138144
def choose_manifest_paths(ti: TaskInstance | None = None) -> list[str]:
139145
"""Choose manifests to pass to the next."""
146+
if ti is None:
147+
raise ValueError("Task instance is None")
140148
task_id: str = ti.xcom_pull(task_ids="manifest_preparation.get_execution_mode")
141149
logging.info("TASK ID: %s", task_id)
142150
if not task_id.endswith("read_manifests"):
@@ -150,7 +158,7 @@ def save_manifests(manifests: list[Manifest_Object]) -> list[Manifest_Object]:
150158
"""Write manifests to persistant storage."""
151159
manifest_paths = [manifest["manifestPath"] for manifest in manifests]
152160
logging.info("MANIFEST PATHS: %s", manifest_paths)
153-
GCSIOManager().dump_many(manifests, manifest_paths)
161+
IOManager().dump_many(manifests, manifest_paths)
154162
return manifests
155163

156164

@@ -164,7 +172,6 @@ def exit_when_no_new_sumstats(new_sumstats: dict[str, str]) -> bool:
164172
@task_group(group_id=TASK_GROUP_ID)
165173
def gwas_catalog_manifest_preparation():
166174
"""Prepare initial manifest."""
167-
168175
fetch_existing_manifests = GCSListObjectsOperator(
169176
task_id="list_existing_manifests",
170177
bucket="{{ params.steps.manifest_preparation.staging_bucket }}",

0 commit comments

Comments
 (0)