From a021789d4427e634077c6ca19ab67119acf37f9c Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Mon, 26 May 2025 10:48:31 +1000 Subject: [PATCH 1/8] Simplify the implementations of graph walking. --- src/cpg_flow/workflow.py | 176 +++++++----------- .../test_first_last_stages_misconfigured.py | 4 +- tests/test_workflow.py | 30 ++- 3 files changed, 94 insertions(+), 116 deletions(-) diff --git a/src/cpg_flow/workflow.py b/src/cpg_flow/workflow.py index be0f811e..924c5b60 100644 --- a/src/cpg_flow/workflow.py +++ b/src/cpg_flow/workflow.py @@ -164,6 +164,7 @@ def run_workflow( return wfl + _TARGET: Final[str] = '\U0001f3af' _ONLY: Final[str] = '\U0001f449' _START: Final[str] = '\u23e9' @@ -233,6 +234,24 @@ def _render_node(node): return out +def _compute_shadow(graph: nx.DiGraph, shadow_casters: set[str]) -> set[str]: + """Compute the 'shadow' of a set of nodes on a directed graph. + + Shadowed nodes are those that are only not reachable from any root of the + graph without passing through a shadow caster node.""" + shadowed: set[str] = set(graph.nodes) + unvisited: set[str] = {node for node, in_degree in graph.in_degree() if not in_degree} + + while unvisited: + node = unvisited.pop() + shadowed.remove(node) + if node not in shadow_casters: + for descendant in graph.successors(node): + if descendant in shadowed: + unvisited.add(descendant) + return shadowed + + class Workflow: """ Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups. @@ -356,75 +375,24 @@ def _process_first_last_stages( before first_stages, and all stages after last_stages (i.e. descendants and ancestors on the stages DAG.) """ - stages_d = {s.name: s for s in stages} - stage_names = list(stg.name for stg in stages) - lower_names = {s.lower() for s in stage_names} - - for param, _stage_list in [ - ('first_stages', first_stages), - ('last_stages', last_stages), - ]: - for _s_name in _stage_list: - if _s_name.lower() not in lower_names: - raise WorkflowError( - f'Value in workflow/{param} "{_s_name}" must be a stage name ' - f'or a subset of stages from the available list: ' - f'{", ".join(stage_names)}', - ) - - if not (last_stages or first_stages): + if not (first_stages or last_stages): return - # E.g. if our last_stages is CramQc, MtToEs would still run because it's in - # a different branch. So we want to collect all stages after first_stages - # and before last_stages in their respective branches, and mark as skipped - # everything in other branches. - first_stages_keeps: list[str] = first_stages[:] - last_stages_keeps: list[str] = last_stages[:] - - for fs in first_stages: - for descendant in nx.descendants(graph, fs): - if not stages_d[descendant].skipped: - logger.info( - f'Skipping stage {descendant} (precedes {fs} listed in first_stages)', - ) - stages_d[descendant].skipped = True - for grand_descendant in nx.descendants(graph, descendant): - if not stages_d[grand_descendant].assume_outputs_exist: - logger.info( - f'Not checking expected outputs of not immediately ' - f'required stage {grand_descendant} (< {descendant} < {fs})', - ) - stages_d[grand_descendant].assume_outputs_exist = True - - for ancestor in nx.ancestors(graph, fs): - first_stages_keeps.append(ancestor) - - for ls in last_stages: - # ancestors of this last_stage - ancestors = nx.ancestors(graph, ls) - if any(anc in last_stages for anc in ancestors): - # a downstream stage is also in last_stages, so this is not yet - # a "real" last stage that we want to run - continue - for ancestor in ancestors: - if stages_d[ancestor].skipped: - continue # already skipped - logger.info(f'Skipping stage {ancestor} (after last {ls})') - stages_d[ancestor].skipped = True - stages_d[ancestor].assume_outputs_exist = True - - for ancestor in nx.descendants(graph, ls): - last_stages_keeps.append(ancestor) - - for _stage in stages: - if _stage.name not in last_stages_keeps + first_stages_keeps: - _stage.skipped = True - _stage.assume_outputs_exist = True + pre_first = _compute_shadow(graph, set(first_stages)) + post_last = _compute_shadow(graph.reverse(), set(last_stages)) - for stage in stages: - if stage.skipped: - graph.nodes[stage.name]['skipped'] = True + kept = set() + for node in first_stages: + kept.update({node} | nx.ancestors(graph, node)) + for node in last_stages: + kept.update({node} | nx.descendants(graph, node)) + + stage_d: dict[str, Stage] = {s.name: s for s in stages} + for node in pre_first | post_last | (set(stage_d.keys() - kept)): + stage = stage_d[node] + stage.skipped = True + stage.assume_outputs_exist = True + graph.nodes[node]['skipped'] = True @staticmethod def _process_only_stages( @@ -451,8 +419,7 @@ def _process_only_stages( # imediate predecessor stages, but skip everything else. required_stages: set[str] = set() for os in only_stages: - rs = nx.descendants_at_distance(graph, os, 1) - required_stages |= set(rs) + required_stages.update(nx.descendants_at_distance(graph, os, 1)) for stage in stages: # Skip stage not in only_stages, and assume outputs exist... @@ -497,24 +464,7 @@ def set_stages( logger.info(f' workflow/last_stages: {last_stages}') # Round 1: initialising stage objects. - stages_dict: dict[str, Stage] = {} - for cls in requested_stages: - if cls.__name__ in stages_dict: - continue - stages_dict[cls.__name__] = cls() - - # Round 2: depth search to find implicit stages. - stages_dict = self._resolve_implicit_stages( - stages_dict=stages_dict, - skip_stages=skip_stages, - only_stages=only_stages, - ) - - # Round 3: set "stage.required_stages" fields to each stage. - for stg in stages_dict.values(): - stg.required_stages = [ - stages_dict[cls.__name__] for cls in stg.required_stages_classes if cls.__name__ in stages_dict - ] + stages_dict: dict[str, Stage] = self._instantiate_stages(requested_stages, skip_stages, only_stages) # Round 4: determining order of execution. stages, dag = self._determine_order_of_execution(stages_dict) @@ -564,33 +514,33 @@ def set_stages( self._show_workflow(dag, skip_stages, only_stages, first_stages, last_stages) @staticmethod - def _resolve_implicit_stages(stages_dict: dict, skip_stages: list[str], only_stages: list[str]): - implicit_stages = {'first': 'loop'} - - while len(implicit_stages) > 0: - implicit_stages = dict() - for stg in stages_dict.values(): - if stg.name in skip_stages: - stg.skipped = True - continue # not searching deeper - - if only_stages and stg.name not in only_stages: - stg.skipped = True - - # Get all deps not already in stages_dict - not_in_stages_dict = { - cls().name: cls() for cls in stg.required_stages_classes if cls.__name__ not in stages_dict - } - implicit_stages |= not_in_stages_dict - - # If there's nothing more to add, finish search - if not implicit_stages: - break - - logger.info( - f'Additional implicit stages: {list(implicit_stages.keys())}', - ) - stages_dict |= implicit_stages + def _instantiate_stages(requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str]): + stages_dict: dict[str, Stage] = {} + + def _make_once(cls) -> tuple['Stage', bool]: + try: + return stages_dict[cls.__name__], False + except KeyError: + instance = stages_dict[cls.__name__] = cls() + return instance, True + + def _recursively_make_stage(cls): + instance, is_new = _make_once(cls) + if is_new: + instance.skipped = cls.__name__ in skip_stages + if not instance.skipped: + instance.required_stages.extend( + filter(None, map(_recursively_make_stage, instance.required_stages_classes)), + ) + return instance + + for cls in requested_stages: + _recursively_make_stage(cls) + + if only_stages: + for stage_name, stage in stages_dict.items(): + if stage_name not in only_stages: + stage.skipped = True return stages_dict diff --git a/tests/stages/test_first_last_stages_misconfigured.py b/tests/stages/test_first_last_stages_misconfigured.py index b444a06a..8671d3e3 100644 --- a/tests/stages/test_first_last_stages_misconfigured.py +++ b/tests/stages/test_first_last_stages_misconfigured.py @@ -6,7 +6,7 @@ from pytest_mock import MockFixture from tests import set_config -from tests.stages import run_workflow +from tests.stages import D, run_workflow def test_first_last_stages_misconfigured(mocker: MockFixture, tmp_path): @@ -47,4 +47,4 @@ def test_first_last_stages_misconfigured(mocker: MockFixture, tmp_path): from cpg_flow.workflow import WorkflowError with pytest.raises(WorkflowError, match='No stages to run'): - run_workflow(mocker) + run_workflow(mocker, stages=[D]) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index dfebba0d..4e23f1c4 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -2,7 +2,9 @@ Test building Workflow object. """ +import itertools import pathlib +import re from collections.abc import Collection, Mapping, Sequence from typing import Any, Final from unittest import mock @@ -19,7 +21,8 @@ stage, ) from cpg_flow.targets import Cohort, MultiCohort, SequencingGroup -from cpg_flow.workflow import _render_graph, path_walk, run_workflow +from cpg_flow.workflow import _compute_shadow, _render_graph, path_walk, run_workflow +from cpg_utils import Path, to_path from cpg_utils.config import dataset_path from cpg_utils.hail_batch import get_batch @@ -151,6 +154,7 @@ def test_path_walk(): } act = path_walk(exp) assert act == {pathlib.Path('this.txt'), pathlib.Path('that.txt'), pathlib.Path('the_other.txt')} + assert act == {to_path('this.txt'), to_path('that.txt'), to_path('the_other.txt')} @pytest.fixture() @@ -269,3 +273,27 @@ def test_render_graph_extra_args( graph = _create_graph_with_attrs(edges, skipped_nodes) result = ';'.join(_render_graph(graph, **extra_args)) assert result == expected + + +def _parse_graph(graph: str) -> nx.DiGraph: + g = nx.DiGraph() + for path in re.split(r'\s*;\s*', graph): + path = re.split(r'\s*->\s*', path) + for edge in itertools.pairwise(path): + g.add_edge(*edge) + return g + + +@pytest.mark.parametrize( + ['graph', 'casters', 'expected'], + [ + pytest.param('R->A->Caster->B->D', set(), set()), + pytest.param('R->A->Caster->B->D', {'X'}, set()), + pytest.param('R->A->Caster->B->D', {'Caster'}, {'B', 'D'}), + pytest.param('R->A->Caster->B->D;B->E', {'Caster'}, {'B', 'D', 'E'}), + pytest.param('R->A->Caster->B->D;A->D', {'Caster'}, {'B'}), + pytest.param('R1->A->B->D;R1->Caster1->B;R2->X->Y->Caster2->Z;R3->P', {'Caster1', 'Caster2'}, {'Z'}), + ], +) +def test_compute_shadow(graph: str, casters: set[str], expected: set[str]): + assert _compute_shadow(_parse_graph(graph), casters) == expected From a8bd619f8f40f4f0956f280783077cde358743a3 Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Mon, 26 May 2025 13:59:33 +1000 Subject: [PATCH 2/8] Update comments and move only_stages application Bring only_stages closer to first/last_stages. --- src/cpg_flow/workflow.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/cpg_flow/workflow.py b/src/cpg_flow/workflow.py index 924c5b60..2d0f898f 100644 --- a/src/cpg_flow/workflow.py +++ b/src/cpg_flow/workflow.py @@ -463,13 +463,11 @@ def set_stages( logger.info(f' workflow/first_stages: {first_stages}') logger.info(f' workflow/last_stages: {last_stages}') - # Round 1: initialising stage objects. stages_dict: dict[str, Stage] = self._instantiate_stages(requested_stages, skip_stages, only_stages) - # Round 4: determining order of execution. stages, dag = self._determine_order_of_execution(stages_dict) - # Round 5: applying workflow options first_stages and last_stages. + # Apply workflow options first_stages and last_stages. if first_stages or last_stages: logger.info('Applying workflow/first_stages and workflow/last_stages') self._process_first_last_stages(stages, dag, first_stages, last_stages) From 07cb65c0ac6fdcc54a03e42b436bafbed590083d Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Mon, 26 May 2025 14:28:32 +1000 Subject: [PATCH 3/8] Fix grammatical error in comment --- src/cpg_flow/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpg_flow/workflow.py b/src/cpg_flow/workflow.py index 2d0f898f..4a9bebb3 100644 --- a/src/cpg_flow/workflow.py +++ b/src/cpg_flow/workflow.py @@ -237,7 +237,7 @@ def _render_node(node): def _compute_shadow(graph: nx.DiGraph, shadow_casters: set[str]) -> set[str]: """Compute the 'shadow' of a set of nodes on a directed graph. - Shadowed nodes are those that are only not reachable from any root of the + Shadowed nodes are those that are not reachable from any root of the graph without passing through a shadow caster node.""" shadowed: set[str] = set(graph.nodes) unvisited: set[str] = {node for node, in_degree in graph.in_degree() if not in_degree} From d3afaadf411b978978bcd2a77d38bc9c68a5edd0 Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Tue, 10 Jun 2025 12:29:46 +1000 Subject: [PATCH 4/8] Improve type info --- src/cpg_flow/workflow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/cpg_flow/workflow.py b/src/cpg_flow/workflow.py index 4a9bebb3..7a261eed 100644 --- a/src/cpg_flow/workflow.py +++ b/src/cpg_flow/workflow.py @@ -512,10 +512,12 @@ def set_stages( self._show_workflow(dag, skip_stages, only_stages, first_stages, last_stages) @staticmethod - def _instantiate_stages(requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str]): + def _instantiate_stages( + requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str] + ) -> dict[str, Stage]: stages_dict: dict[str, Stage] = {} - def _make_once(cls) -> tuple['Stage', bool]: + def _make_once(cls) -> tuple[Stage, bool]: try: return stages_dict[cls.__name__], False except KeyError: @@ -543,7 +545,7 @@ def _recursively_make_stage(cls): return stages_dict @staticmethod - def _determine_order_of_execution(stages_dict: dict): + def _determine_order_of_execution(stages_dict: dict) -> tuple[list[Stage], nx.DiGraph]: dag_node2nodes = dict() # building a DAG for stg in stages_dict.values(): dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages) From 43fd2efe5eeeca6f200760929ecca17a7b8e1d8a Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Tue, 10 Jun 2025 12:33:42 +1000 Subject: [PATCH 5/8] remove incorrect merge conflict resolution --- tests/test_workflow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 4e23f1c4..f7e9331b 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -22,7 +22,6 @@ ) from cpg_flow.targets import Cohort, MultiCohort, SequencingGroup from cpg_flow.workflow import _compute_shadow, _render_graph, path_walk, run_workflow -from cpg_utils import Path, to_path from cpg_utils.config import dataset_path from cpg_utils.hail_batch import get_batch @@ -154,7 +153,6 @@ def test_path_walk(): } act = path_walk(exp) assert act == {pathlib.Path('this.txt'), pathlib.Path('that.txt'), pathlib.Path('the_other.txt')} - assert act == {to_path('this.txt'), to_path('that.txt'), to_path('the_other.txt')} @pytest.fixture() From 06821a4d91104ea7c705edcee2ad689fda23b573 Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Tue, 10 Jun 2025 12:40:37 +1000 Subject: [PATCH 6/8] Improve type info --- src/cpg_flow/workflow.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/cpg_flow/workflow.py b/src/cpg_flow/workflow.py index 7a261eed..ba62b1e2 100644 --- a/src/cpg_flow/workflow.py +++ b/src/cpg_flow/workflow.py @@ -299,8 +299,8 @@ def __init__( self.status_reporter = None if get_config()['workflow'].get('status_reporter') == 'metamist': self.status_reporter = MetamistStatusReporter() - self._stages: list[StageDecorator] | None = stages - self.queued_stages: list[Stage] = [] + self._stages: list['StageDecorator'] | None = stages + self.queued_stages: list['Stage'] = [] @property def output_version(self) -> str: @@ -387,7 +387,7 @@ def _process_first_last_stages( for node in last_stages: kept.update({node} | nx.descendants(graph, node)) - stage_d: dict[str, Stage] = {s.name: s for s in stages} + stage_d: dict[str, 'Stage'] = {s.name: s for s in stages} for node in pre_first | post_last | (set(stage_d.keys() - kept)): stage = stage_d[node] stage.skipped = True @@ -463,7 +463,7 @@ def set_stages( logger.info(f' workflow/first_stages: {first_stages}') logger.info(f' workflow/last_stages: {last_stages}') - stages_dict: dict[str, Stage] = self._instantiate_stages(requested_stages, skip_stages, only_stages) + stages_dict: dict[str, 'Stage'] = self._instantiate_stages(requested_stages, skip_stages, only_stages) stages, dag = self._determine_order_of_execution(stages_dict) @@ -514,10 +514,10 @@ def set_stages( @staticmethod def _instantiate_stages( requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str] - ) -> dict[str, Stage]: - stages_dict: dict[str, Stage] = {} + ) -> dict[str, 'Stage']: + stages_dict: dict[str, 'Stage'] = {} - def _make_once(cls) -> tuple[Stage, bool]: + def _make_once(cls) -> tuple['Stage', bool]: try: return stages_dict[cls.__name__], False except KeyError: @@ -545,7 +545,7 @@ def _recursively_make_stage(cls): return stages_dict @staticmethod - def _determine_order_of_execution(stages_dict: dict) -> tuple[list[Stage], nx.DiGraph]: + def _determine_order_of_execution(stages_dict: dict) -> tuple[list['Stage'], nx.DiGraph]: dag_node2nodes = dict() # building a DAG for stg in stages_dict.values(): dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages) From 7a39af0534bf8320e4e973f89e3a1626abd49421 Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Tue, 10 Jun 2025 12:45:53 +1000 Subject: [PATCH 7/8] whitespace --- src/cpg_flow/workflow.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/cpg_flow/workflow.py b/src/cpg_flow/workflow.py index ba62b1e2..ef1de666 100644 --- a/src/cpg_flow/workflow.py +++ b/src/cpg_flow/workflow.py @@ -164,7 +164,6 @@ def run_workflow( return wfl - _TARGET: Final[str] = '\U0001f3af' _ONLY: Final[str] = '\U0001f449' _START: Final[str] = '\u23e9' From ecb4f457d49222aa3693b86843f94098a2d96d54 Mon Sep 17 00:00:00 2001 From: Tobias Sargeant Date: Tue, 10 Jun 2025 13:01:52 +1000 Subject: [PATCH 8/8] disable lint: Remove quotes from type annotation --- src/cpg_flow/workflow.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/cpg_flow/workflow.py b/src/cpg_flow/workflow.py index ef1de666..59e7ff65 100644 --- a/src/cpg_flow/workflow.py +++ b/src/cpg_flow/workflow.py @@ -298,8 +298,8 @@ def __init__( self.status_reporter = None if get_config()['workflow'].get('status_reporter') == 'metamist': self.status_reporter = MetamistStatusReporter() - self._stages: list['StageDecorator'] | None = stages - self.queued_stages: list['Stage'] = [] + self._stages: list['StageDecorator'] | None = stages # noqa: UP037 + self.queued_stages: list['Stage'] = [] # noqa: UP037 @property def output_version(self) -> str: @@ -386,7 +386,7 @@ def _process_first_last_stages( for node in last_stages: kept.update({node} | nx.descendants(graph, node)) - stage_d: dict[str, 'Stage'] = {s.name: s for s in stages} + stage_d: dict[str, 'Stage'] = {s.name: s for s in stages} # noqa: UP037 for node in pre_first | post_last | (set(stage_d.keys() - kept)): stage = stage_d[node] stage.skipped = True @@ -462,7 +462,7 @@ def set_stages( logger.info(f' workflow/first_stages: {first_stages}') logger.info(f' workflow/last_stages: {last_stages}') - stages_dict: dict[str, 'Stage'] = self._instantiate_stages(requested_stages, skip_stages, only_stages) + stages_dict: dict[str, 'Stage'] = self._instantiate_stages(requested_stages, skip_stages, only_stages) # noqa: UP037 stages, dag = self._determine_order_of_execution(stages_dict) @@ -514,7 +514,7 @@ def set_stages( def _instantiate_stages( requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str] ) -> dict[str, 'Stage']: - stages_dict: dict[str, 'Stage'] = {} + stages_dict: dict[str, 'Stage'] = {} # noqa: UP037 def _make_once(cls) -> tuple['Stage', bool]: try: