From fb5b9c981a5530bd11d6af8b9da0dc5b1e2d6862 Mon Sep 17 00:00:00 2001 From: EddieLF Date: Thu, 4 Sep 2025 10:41:28 +1000 Subject: [PATCH] Allow writing multiple files into a single analysis record --- src/cpg_flow/metamist.py | 14 +++++--- src/cpg_flow/stage.py | 76 +++++++++++++++++++++++++++------------- src/cpg_flow/status.py | 44 +++++++++++------------ tests/test_metamist.py | 6 ++-- tests/test_status.py | 4 +-- 5 files changed, 89 insertions(+), 55 deletions(-) diff --git a/src/cpg_flow/metamist.py b/src/cpg_flow/metamist.py index dd6bdbf7..21a3f96a 100644 --- a/src/cpg_flow/metamist.py +++ b/src/cpg_flow/metamist.py @@ -386,7 +386,7 @@ def get_analyses_by_sgid( def create_analysis( # noqa: PLR0917 self, - output: Path | str, + outputs: dict[str, Path | str], type_: str | AnalysisType, status: str | AnalysisStatus, cohort_ids: list[str] | None = None, @@ -404,6 +404,12 @@ def create_analysis( # noqa: PLR0917 if isinstance(status, AnalysisStatus): status = status.value + if len(outputs) == 1: + # If output is a single file, just use the string path + outputs_ = {'basename': str(next(iter(outputs.values())))} + else: + outputs_ = {k: {'basename': str(v)} for k, v in outputs.items()} + if not cohort_ids: cohort_ids = [] @@ -413,7 +419,7 @@ def create_analysis( # noqa: PLR0917 am = models.Analysis( type=type_, status=models.AnalysisStatus(status), - output=str(output), + outputs=outputs_, cohort_ids=list(cohort_ids), sequencing_group_ids=list(sequencing_group_ids), meta=meta or {}, @@ -425,11 +431,11 @@ def create_analysis( # noqa: PLR0917 ) if aid is None: logger.error( - f'Failed to create Analysis(type={type_}, status={status}, output={output!s}) in {metamist_proj}', + f'Failed to create Analysis(type={type_}, status={status}, outputs={outputs_!s}) in {metamist_proj}', ) return None logger.info( - f'Created Analysis(id={aid}, type={type_}, status={status}, output={output!s}) in {metamist_proj}', + f'Created Analysis(id={aid}, type={type_}, status={status}, 
outputs={outputs_!s}) in {metamist_proj}', ) return aid diff --git a/src/cpg_flow/stage.py b/src/cpg_flow/stage.py index e16097ba..49144404 100644 --- a/src/cpg_flow/stage.py +++ b/src/cpg_flow/stage.py @@ -30,7 +30,7 @@ from cpg_flow.utils import ExpectedResultT, exists from cpg_flow.workflow import Action, WorkflowError, get_workflow, path_walk from cpg_utils import Path, to_path -from cpg_utils.config import get_config +from cpg_utils.config import config_retrieve, get_config from cpg_utils.hail_batch import get_batch StageDecorator = Callable[..., 'Stage'] @@ -362,11 +362,12 @@ def __init__( required_stages: list[StageDecorator] | StageDecorator | None = None, analysis_type: str | None = None, analysis_keys: list[str] | None = None, - update_analysis_meta: Callable[[str], dict] | None = None, + update_analysis_meta: Callable[[dict], dict] | None = None, tolerate_missing_output: bool = False, skipped: bool = False, assume_outputs_exist: bool = False, forced: bool = False, + merge_analyses: bool = False, ): self._name = name self.required_stages_classes: list[StageDecorator] = [] @@ -384,9 +385,9 @@ def __init__( # entries in Metamist. self.analysis_type = analysis_type # If `analysis_keys` are defined, it will be used to extract the value for - # `Analysis.output` if the Stage.expected_outputs() returns a dict. + # `Analysis.outputs` if the Stage.expected_outputs() returns a dict. self.analysis_keys = analysis_keys - # if `update_analysis_meta` is defined, it is called on the `Analysis.output` + # if `update_analysis_meta` is defined, it is called on the `Analysis.outputs` # field, and result is merged into the `Analysis.meta` dictionary. 
self.update_analysis_meta = update_analysis_meta @@ -402,6 +403,10 @@ def __init__( ) self.assume_outputs_exist = assume_outputs_exist + # If true, a single analysis entry is created for all outputs of this stage. + self.merge_analyses = merge_analyses or self.name in config_retrieve(['workflow', 'merge_analyses_stages'], []) + + @property def tmp_prefix(self): return get_workflow().tmp_prefix / self.name @@ -511,6 +516,8 @@ def make_outputs( """ Create StageOutput for this stage. """ + if isinstance(data, (str | Path)): + data = {'output': data} return StageOutput( target=target, data=data, @@ -564,7 +571,7 @@ def _queue_jobs_with_checks( # Adding status reporter jobs if self.analysis_type and self.status_reporter and action == Action.QUEUE and outputs.data: - analysis_outputs: list[str | Path] = [] + analysis_outputs: dict[str, str | Path] = {} if isinstance(outputs.data, dict): if not self.analysis_keys: raise WorkflowError( @@ -584,12 +591,12 @@ def _queue_jobs_with_checks( for analysis_key in self.analysis_keys: data = outputs.data.get(analysis_key) if isinstance(data, list): - analysis_outputs.extend(data) + analysis_outputs[analysis_key] = data elif data is not None: - analysis_outputs.append(data) + analysis_outputs['output'] = data else: - analysis_outputs.append(outputs.data) + analysis_outputs['output'] = outputs.data project_name = None if isinstance(target, SequencingGroup | Cohort): @@ -605,20 +612,24 @@ def _queue_jobs_with_checks( if get_config()['workflow']['access_level'] == 'test' and 'test' not in project_name: project_name = f'{project_name}-test' - for analysis_output in analysis_outputs: - if not outputs.jobs: - continue + if not outputs.jobs: + pass # No jobs means no analyses + + if outputs.meta is None: + outputs.meta = {} + + elif self.merge_analyses and len(analysis_outputs) > 1: + # If merge_analyses is True, we will create a single analysis entry + # for all outputs of this stage.
+ assert isinstance(analysis_outputs, dict), ( + f'Expected outputs for stage {self.name} should be a dict with string keys and str or Path values, ' + f'but got {analysis_outputs}' + ) - assert isinstance( - analysis_output, - str | Path, - ), f'{analysis_output} should be a str or Path object' - if outputs.meta is None: - outputs.meta = {} self.status_reporter.create_analysis( b=get_batch(), - output=str(analysis_output), + outputs=analysis_outputs, analysis_type=self.analysis_type, target=target, jobs=outputs.jobs, @@ -629,6 +640,23 @@ def _queue_jobs_with_checks( project_name=project_name, ) + else: + for k, analysis_output in analysis_outputs.items(): + assert isinstance(analysis_output, (str | Path)), f'{analysis_output} should be a str or Path object' + + self.status_reporter.create_analysis( + b=get_batch(), + outputs={k: analysis_output}, + analysis_type=self.analysis_type, + target=target, + jobs=outputs.jobs, + job_attr=self.get_job_attrs(target) | {'stage': self.name, 'tool': 'metamist'}, + meta=outputs.meta, + update_analysis_meta=self.update_analysis_meta, + tolerate_missing_output=self.tolerate_missing_output, + project_name=project_name, + ) + return outputs def _get_action(self, target: TargetT) -> Action: @@ -776,7 +804,7 @@ def stage( *, analysis_type: str | None = None, analysis_keys: list[str | Path] | None = None, - update_analysis_meta: Callable[[str], dict] | None = None, + update_analysis_meta: Callable[[dict], dict] | None = None, tolerate_missing_output: bool = False, required_stages: list[StageDecorator] | StageDecorator | None = None, skipped: bool = False, @@ -791,7 +819,7 @@ def stage( *, analysis_type: str | None = None, analysis_keys: list[str | Path] | None = None, - update_analysis_meta: Callable[[str], dict] | None = None, + update_analysis_meta: Callable[[dict], dict] | None = None, tolerate_missing_output: bool = False, required_stages: list[StageDecorator] | StageDecorator | None = None, skipped: bool = False, @@ -812,12 
+840,12 @@ def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> S @analysis_type: if defined, will be used to create/update `Analysis` entries using the status reporter. - @analysis_keys: is defined, will be used to extract the value for `Analysis.output` + @analysis_keys: if defined, will be used to extract the value for `Analysis.outputs` if the Stage.expected_outputs() returns a dict. - @update_analysis_meta: if defined, this function is called on the `Analysis.output` + @update_analysis_meta: if defined, this function is called on the `Analysis.outputs` field, and returns a dictionary to be merged into the `Analysis.meta` - @tolerate_missing_output: if True, when registering the output of this stage, - allow for the output file to be missing (only relevant for metamist entry) + @tolerate_missing_output: if True, when registering the outputs of this stage, + allow for the output files to be missing (only relevant for metamist entry) @required_stages: list of other stage classes that are required prerequisites for this stage. Outputs of those stages will be passed to `Stage.queue_jobs(...
, inputs)` as `inputs`, and all required diff --git a/src/cpg_flow/status.py b/src/cpg_flow/status.py index f8c6e1e9..be6c093b 100644 --- a/src/cpg_flow/status.py +++ b/src/cpg_flow/status.py @@ -12,12 +12,12 @@ from cpg_flow.targets import Target from cpg_flow.targets.cohort import Cohort from cpg_flow.targets.multicohort import MultiCohort -from cpg_utils import to_path +from cpg_utils import Path, to_path from cpg_utils.config import get_config def complete_analysis_job( # noqa: PLR0917 - output: str, + outputs: dict[str, str | Path], analysis_type: str, cohort_ids: list[str], sg_ids: list[str], @@ -31,7 +31,8 @@ def complete_analysis_job( # noqa: PLR0917 this will register the analysis outputs from a Stage Args: - output (str): path to the output file + outputs (dict[str, str | Path]): dict of output files, + where the keys are the output names and the values are the paths. analysis_type (str): metamist analysis type sg_ids (list[str]): all CPG IDs relevant to this target project_name (str): project/dataset name @@ -40,11 +41,12 @@ def complete_analysis_job( # noqa: PLR0917 tolerate_missing (bool): if True, allow missing output """ - assert isinstance(output, str) - output_cloudpath = to_path(output) + assert isinstance(outputs, dict) + output_cloudpaths = {k: to_path(v) for k, v in outputs.items()} + if update_analysis_meta is not None: - meta |= update_analysis_meta(output) + meta = meta | update_analysis_meta(outputs) # if SG IDs are listed in the meta, remove them # these are already captured in the sg_ids list @@ -64,18 +66,15 @@ def complete_analysis_job( # noqa: PLR0917 # we know that es indexes are registered names, not files/dirs # skip all relevant checks for this output type if analysis_type != 'es-index': - if not output_cloudpath.exists(): - if tolerate_missing: - print(f"Output {output} doesn't exist, allowing silent return") - return - raise ValueError(f"Output {output} doesn't exist") - - # add file size to meta - if not output_cloudpath.is_dir():
- meta |= {'size': output_cloudpath.stat().st_size} + for k, output_cloudpath in output_cloudpaths.items(): + if not output_cloudpath.exists(): + if tolerate_missing: + print(f"Output {output_cloudpath} doesn't exist, allowing silent return") + return + raise ValueError(f"Output {k}: {output_cloudpath} doesn't exist") a_id = get_metamist().create_analysis( - output=output, + outputs=outputs, type_=analysis_type, status=AnalysisStatus('completed'), cohort_ids=cohort_ids, @@ -84,11 +83,11 @@ def complete_analysis_job( # noqa: PLR0917 meta=meta, ) if a_id is None: - msg = f'Creation of Analysis failed (type={analysis_type}, output={output}) in {project_name}' + msg = f'Creation of Analysis failed (type={analysis_type}, outputs={outputs}) in {project_name}' print(msg) raise ConnectionError(msg) print( - f'Created Analysis(id={a_id}, type={analysis_type}, output={output}) in {project_name}', + f'Created Analysis(id={a_id}, type={analysis_type}, outputs={outputs}) in {project_name}', ) @@ -107,7 +106,7 @@ class StatusReporter(ABC): def create_analysis( # noqa: PLR0917 self, b: Batch, - output: str, + outputs: dict[str, str | Path], analysis_type: str, target: Target, jobs: list[Job] | None = None, @@ -133,7 +132,7 @@ def __init__(self) -> None: @staticmethod def create_analysis( # noqa: PLR0917 b: Batch, - output: str, + outputs: dict[str, str | Path], analysis_type: str, target: Target, jobs: list[Job] | None = None, @@ -169,14 +168,15 @@ def create_analysis( # noqa: PLR0917 else: sequencing_group_ids = target.get_sequencing_group_ids() + outputs_str = ', '.join(f'{k}: {str(v)}' for k, v in outputs.items()) py_job = b.new_python_job( - f'Register analysis output {output}', + f'Register analysis output {outputs_str}', job_attr or {} | {'tool': 'metamist'}, ) py_job.image(get_config()['workflow']['driver_image']) py_job.call( complete_analysis_job, - str(output), + outputs, analysis_type, cohort_ids, sequencing_group_ids, diff --git a/tests/test_metamist.py 
b/tests/test_metamist.py index a7ca5a47..fcfe9327 100644 --- a/tests/test_metamist.py +++ b/tests/test_metamist.py @@ -222,7 +222,7 @@ def test_metamist_create_analysis(self, mocker: MockFixture, metamist: Metamist) mock_aapi_create_analysis_err, ) analysis = metamist.create_analysis( - output=to_path('test_output'), + outputs=to_path('test_output'), type_='test', status='completed', sequencing_group_ids=['test'], @@ -236,7 +236,7 @@ def test_metamist_create_analysis(self, mocker: MockFixture, metamist: Metamist) mock_aapi_create_analysis_timeout, ) analysis = metamist.create_analysis( - output=to_path('test_output'), + outputs=to_path('test_output'), type_=AnalysisType.parse('custom'), status=AnalysisStatus.parse('completed'), sequencing_group_ids=['test'], @@ -249,7 +249,7 @@ def test_metamist_create_analysis(self, mocker: MockFixture, metamist: Metamist) mock_aapi_create_analysis_ok, ) analysis = metamist.create_analysis( - output=to_path('test_output'), + outputs=to_path('test_output'), type_=AnalysisType.parse('custom'), status=AnalysisStatus.parse('completed'), sequencing_group_ids=['test'], diff --git a/tests/test_status.py b/tests/test_status.py index e47bfff2..522df708 100644 --- a/tests/test_status.py +++ b/tests/test_status.py @@ -142,8 +142,8 @@ def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> S assert get_batch().job_by_tool['metamist']['job_n'] == len(get_multicohort().get_sequencing_groups()) * 2 -def _update_meta(output_path: str) -> dict[str, Any]: - with to_path(output_path).open() as f: +def _update_meta(output_paths: dict[str, str | Path]) -> dict[str, Any]: + with to_path(output_paths['output']).open() as f: return {'result': f.read().strip()}