Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/cpg_flow/metamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def get_analyses_by_sgid(

def create_analysis( # noqa: PLR0917
self,
output: Path | str,
outputs: dict[str, Path | str],
type_: str | AnalysisType,
status: str | AnalysisStatus,
cohort_ids: list[str] | None = None,
Expand All @@ -404,6 +404,12 @@ def create_analysis( # noqa: PLR0917
if isinstance(status, AnalysisStatus):
status = status.value

if len(outputs) == 1:
# If output is a single file, just use the string path
outputs_ = {'basename': str(next(iter(outputs.values())))}
else:
outputs_ = {k: {'basename': str(v)} for k, v in outputs.items()}

if not cohort_ids:
cohort_ids = []

Expand All @@ -413,7 +419,7 @@ def create_analysis( # noqa: PLR0917
am = models.Analysis(
type=type_,
status=models.AnalysisStatus(status),
output=str(output),
outputs=outputs_,
cohort_ids=list(cohort_ids),
sequencing_group_ids=list(sequencing_group_ids),
meta=meta or {},
Expand All @@ -425,11 +431,11 @@ def create_analysis( # noqa: PLR0917
)
if aid is None:
logger.error(
f'Failed to create Analysis(type={type_}, status={status}, output={output!s}) in {metamist_proj}',
f'Failed to create Analysis(type={type_}, status={status}, outputs={outputs_!s}) in {metamist_proj}',
)
return None
logger.info(
f'Created Analysis(id={aid}, type={type_}, status={status}, output={output!s}) in {metamist_proj}',
f'Created Analysis(id={aid}, type={type_}, status={status}, outputs={outputs_!s}) in {metamist_proj}',
)
return aid

Expand Down
76 changes: 52 additions & 24 deletions src/cpg_flow/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from cpg_flow.utils import ExpectedResultT, exists
from cpg_flow.workflow import Action, WorkflowError, get_workflow, path_walk
from cpg_utils import Path, to_path
from cpg_utils.config import get_config
from cpg_utils.config import config_retrieve, get_config
from cpg_utils.hail_batch import get_batch

StageDecorator = Callable[..., 'Stage']
Expand Down Expand Up @@ -362,11 +362,12 @@ def __init__(
required_stages: list[StageDecorator] | StageDecorator | None = None,
analysis_type: str | None = None,
analysis_keys: list[str] | None = None,
update_analysis_meta: Callable[[str], dict] | None = None,
update_analysis_meta: Callable[[dict], dict] | None = None,
tolerate_missing_output: bool = False,
skipped: bool = False,
assume_outputs_exist: bool = False,
forced: bool = False,
merge_analyses: bool = False,
):
self._name = name
self.required_stages_classes: list[StageDecorator] = []
Expand All @@ -384,9 +385,9 @@ def __init__(
# entries in Metamist.
self.analysis_type = analysis_type
# If `analysis_keys` are defined, it will be used to extract the value for
# `Analysis.output` if the Stage.expected_outputs() returns a dict.
# `Analysis.outputs` if the Stage.expected_outputs() returns a dict.
self.analysis_keys = analysis_keys
# if `update_analysis_meta` is defined, it is called on the `Analysis.output`
# if `update_analysis_meta` is defined, it is called on the `Analysis.outputs`
# field, and result is merged into the `Analysis.meta` dictionary.
self.update_analysis_meta = update_analysis_meta

Expand All @@ -402,6 +403,10 @@ def __init__(
)
self.assume_outputs_exist = assume_outputs_exist

# If true, a single merged analysis entry will be created covering all outputs of this stage.
self.merge_analyses = merge_analyses or self.name in config_retrieve(['workflow', 'merge_analyses_stages'], [])


@property
def tmp_prefix(self):
return get_workflow().tmp_prefix / self.name
Expand Down Expand Up @@ -511,6 +516,8 @@ def make_outputs(
"""
Create StageOutput for this stage.
"""
if isinstance(data, (str | Path)):
data = {'output': data}
return StageOutput(
target=target,
data=data,
Expand Down Expand Up @@ -564,7 +571,7 @@ def _queue_jobs_with_checks(

# Adding status reporter jobs
if self.analysis_type and self.status_reporter and action == Action.QUEUE and outputs.data:
analysis_outputs: list[str | Path] = []
analysis_outputs: dict[str, str | Path] = {}
if isinstance(outputs.data, dict):
if not self.analysis_keys:
raise WorkflowError(
Expand All @@ -584,12 +591,12 @@ def _queue_jobs_with_checks(
for analysis_key in self.analysis_keys:
data = outputs.data.get(analysis_key)
if isinstance(data, list):
analysis_outputs.extend(data)
analysis_outputs[analysis_key] = data
elif data is not None:
analysis_outputs.append(data)
analysis_outputs['output'] = data

else:
analysis_outputs.append(outputs.data)
analysis_outputs['output'] = outputs.data

project_name = None
if isinstance(target, SequencingGroup | Cohort):
Expand All @@ -605,20 +612,24 @@ def _queue_jobs_with_checks(
if get_config()['workflow']['access_level'] == 'test' and 'test' not in project_name:
project_name = f'{project_name}-test'

for analysis_output in analysis_outputs:
if not outputs.jobs:
continue
if not outputs.jobs:
pass # No jobs means no analyses

if outputs.meta is None:
outputs.meta = {}

elif self.merge_analyses and len(analysis_outputs) > 1:
# If merge_analyses is True, we will create a single analysis entry
# for all outputs of this stage.
assert isinstance(analysis_outputs, dict), (
f'Expected outputs for stage {self.name} should be a dict with string keys and str or Path values, '
f'but got {analysis_outputs}'
)

assert isinstance(
analysis_output,
str | Path,
), f'{analysis_output} should be a str or Path object'
if outputs.meta is None:
outputs.meta = {}

self.status_reporter.create_analysis(
b=get_batch(),
output=str(analysis_output),
outputs=analysis_outputs,
analysis_type=self.analysis_type,
target=target,
jobs=outputs.jobs,
Expand All @@ -629,6 +640,23 @@ def _queue_jobs_with_checks(
project_name=project_name,
)

else:
for k, analysis_output in analysis_outputs.items():
assert isinstance(analysis_output, (str | Path)), f'{analysis_output} should be a str or Path object'

self.status_reporter.create_analysis(
b=get_batch(),
outputs={k: analysis_output},
analysis_type=self.analysis_type,
target=target,
jobs=outputs.jobs,
job_attr=self.get_job_attrs(target) | {'stage': self.name, 'tool': 'metamist'},
meta=outputs.meta,
update_analysis_meta=self.update_analysis_meta,
tolerate_missing_output=self.tolerate_missing_output,
project_name=project_name,
)

return outputs

def _get_action(self, target: TargetT) -> Action:
Expand Down Expand Up @@ -776,7 +804,7 @@ def stage(
*,
analysis_type: str | None = None,
analysis_keys: list[str | Path] | None = None,
update_analysis_meta: Callable[[str], dict] | None = None,
update_analysis_meta: Callable[[dict], dict] | None = None,
tolerate_missing_output: bool = False,
required_stages: list[StageDecorator] | StageDecorator | None = None,
skipped: bool = False,
Expand All @@ -791,7 +819,7 @@ def stage(
*,
analysis_type: str | None = None,
analysis_keys: list[str | Path] | None = None,
update_analysis_meta: Callable[[str], dict] | None = None,
update_analysis_meta: Callable[[dict], dict] | None = None,
tolerate_missing_output: bool = False,
required_stages: list[StageDecorator] | StageDecorator | None = None,
skipped: bool = False,
Expand All @@ -812,12 +840,12 @@ def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> S

@analysis_type: if defined, will be used to create/update `Analysis` entries
using the status reporter.
@analysis_keys: is defined, will be used to extract the value for `Analysis.output`
@analysis_keys: if defined, will be used to extract the value for `Analysis.outputs`
if the Stage.expected_outputs() returns a dict.
@update_analysis_meta: if defined, this function is called on the `Analysis.output`
@update_analysis_meta: if defined, this function is called on the `Analysis.outputs`
field, and returns a dictionary to be merged into the `Analysis.meta`
@tolerate_missing_output: if True, when registering the output of this stage,
allow for the output file to be missing (only relevant for metamist entry)
@tolerate_missing_output: if True, when registering the outputs of this stage,
allow for the output files to be missing (only relevant for metamist entry)
@required_stages: list of other stage classes that are required prerequisites
for this stage. Outputs of those stages will be passed to
`Stage.queue_jobs(... , inputs)` as `inputs`, and all required
Expand Down
44 changes: 22 additions & 22 deletions src/cpg_flow/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from cpg_flow.targets import Target
from cpg_flow.targets.cohort import Cohort
from cpg_flow.targets.multicohort import MultiCohort
from cpg_utils import to_path
from cpg_utils import Path, to_path
from cpg_utils.config import get_config


def complete_analysis_job( # noqa: PLR0917
output: str,
outputs: dict[str, str | Path],
analysis_type: str,
cohort_ids: list[str],
sg_ids: list[str],
Expand All @@ -31,7 +31,8 @@ def complete_analysis_job( # noqa: PLR0917
this will register the analysis outputs from a Stage

Args:
output (str): path to the output file
outputs (dict[str, str | Path]): dict of output files,
where the keys are the output names and the values are the paths.
analysis_type (str): metamist analysis type
sg_ids (list[str]): all CPG IDs relevant to this target
project_name (str): project/dataset name
Expand All @@ -40,11 +41,12 @@ def complete_analysis_job( # noqa: PLR0917
tolerate_missing (bool): if True, allow missing output
"""

assert isinstance(output, str)
output_cloudpath = to_path(output)
assert isinstance(outputs, dict)
output_cloudpaths = {k: to_path(v) for k, v in outputs.items()}


if update_analysis_meta is not None:
meta |= update_analysis_meta(output)
meta = meta | update_analysis_meta(outputs)

# if SG IDs are listed in the meta, remove them
# these are already captured in the sg_ids list
Expand All @@ -64,18 +66,15 @@ def complete_analysis_job( # noqa: PLR0917
# we know that es indexes are registered names, not files/dirs
# skip all relevant checks for this output type
if analysis_type != 'es-index':
if not output_cloudpath.exists():
if tolerate_missing:
print(f"Output {output} doesn't exist, allowing silent return")
return
raise ValueError(f"Output {output} doesn't exist")

# add file size to meta
if not output_cloudpath.is_dir():
meta |= {'size': output_cloudpath.stat().st_size}
for k, output_cloudpath in output_cloudpaths.items():
if not output_cloudpath.exists():
if tolerate_missing:
print(f"Output {output_cloudpath} doesn't exist, allowing silent return")
return
raise ValueError(f"Output {k}: {output_cloudpath} doesn't exist")

a_id = get_metamist().create_analysis(
output=output,
outputs=outputs,
type_=analysis_type,
status=AnalysisStatus('completed'),
cohort_ids=cohort_ids,
Expand All @@ -84,11 +83,11 @@ def complete_analysis_job( # noqa: PLR0917
meta=meta,
)
if a_id is None:
msg = f'Creation of Analysis failed (type={analysis_type}, output={output}) in {project_name}'
msg = f'Creation of Analysis failed (type={analysis_type}, outputs={outputs}) in {project_name}'
print(msg)
raise ConnectionError(msg)
print(
f'Created Analysis(id={a_id}, type={analysis_type}, output={output}) in {project_name}',
f'Created Analysis(id={a_id}, type={analysis_type}, outputs={outputs}) in {project_name}',
)


Expand All @@ -107,7 +106,7 @@ class StatusReporter(ABC):
def create_analysis( # noqa: PLR0917
self,
b: Batch,
output: str,
outputs: dict[str, str | Path],
analysis_type: str,
target: Target,
jobs: list[Job] | None = None,
Expand All @@ -133,7 +132,7 @@ def __init__(self) -> None:
@staticmethod
def create_analysis( # noqa: PLR0917
b: Batch,
output: str,
outputs: dict[str, str | Path],
analysis_type: str,
target: Target,
jobs: list[Job] | None = None,
Expand Down Expand Up @@ -169,14 +168,15 @@ def create_analysis( # noqa: PLR0917
else:
sequencing_group_ids = target.get_sequencing_group_ids()

outputs_str = ', '.join(f'{k}: {str(v)}' for k, v in outputs.items())
py_job = b.new_python_job(
f'Register analysis output {output}',
f'Register analysis output {outputs_str}',
job_attr or {} | {'tool': 'metamist'},
)
py_job.image(get_config()['workflow']['driver_image'])
py_job.call(
complete_analysis_job,
str(output),
outputs,
analysis_type,
cohort_ids,
sequencing_group_ids,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_metamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_metamist_create_analysis(self, mocker: MockFixture, metamist: Metamist)
mock_aapi_create_analysis_err,
)
analysis = metamist.create_analysis(
output=to_path('test_output'),
outputs=to_path('test_output'),
type_='test',
status='completed',
sequencing_group_ids=['test'],
Expand All @@ -236,7 +236,7 @@ def test_metamist_create_analysis(self, mocker: MockFixture, metamist: Metamist)
mock_aapi_create_analysis_timeout,
)
analysis = metamist.create_analysis(
output=to_path('test_output'),
outputs=to_path('test_output'),
type_=AnalysisType.parse('custom'),
status=AnalysisStatus.parse('completed'),
sequencing_group_ids=['test'],
Expand All @@ -249,7 +249,7 @@ def test_metamist_create_analysis(self, mocker: MockFixture, metamist: Metamist)
mock_aapi_create_analysis_ok,
)
analysis = metamist.create_analysis(
output=to_path('test_output'),
outputs=to_path('test_output'),
type_=AnalysisType.parse('custom'),
status=AnalysisStatus.parse('completed'),
sequencing_group_ids=['test'],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> S
assert get_batch().job_by_tool['metamist']['job_n'] == len(get_multicohort().get_sequencing_groups()) * 2


def _update_meta(output_path: str) -> dict[str, Any]:
with to_path(output_path).open() as f:
def _update_meta(output_paths: dict[str, str | Path]) -> dict[str, Any]:
with to_path(output_paths['output']).open() as f:
return {'result': f.read().strip()}


Expand Down
Loading