def _compute_shadow(graph: nx.DiGraph, shadow_casters: set[str]) -> set[str]:
    """Return the nodes hidden behind ``shadow_casters`` when viewed from the roots.

    The walk starts at every node that has no incoming edges and follows
    outgoing edges, but never continues past a shadow caster. Whatever the
    walk fails to reach is the shadow.

    A caster that is itself reached is *not* part of the shadow; only nodes
    that can be reached exclusively through casters are. Names in
    ``shadow_casters`` that are not nodes of ``graph`` have no effect. If the
    graph has no node with in-degree zero, nothing is reached and every node
    is returned.
    """
    lit: set[str] = set()
    frontier: list[str] = [name for name, indegree in graph.in_degree() if not indegree]

    while frontier:
        current = frontier.pop()
        if current in lit:
            # Queued more than once via different parents; already handled.
            continue
        lit.add(current)
        if current in shadow_casters:
            # Light stops here: do not look at what lies behind a caster.
            continue
        frontier.extend(nxt for nxt in graph.successors(current) if nxt not in lit)

    return set(graph.nodes) - lit
""" - stages_d = {s.name: s for s in stages} - stage_names = list(stg.name for stg in stages) - lower_names = {s.lower() for s in stage_names} - - for param, _stage_list in [ - ('first_stages', first_stages), - ('last_stages', last_stages), - ]: - for _s_name in _stage_list: - if _s_name.lower() not in lower_names: - raise WorkflowError( - f'Value in workflow/{param} "{_s_name}" must be a stage name ' - f'or a subset of stages from the available list: ' - f'{", ".join(stage_names)}', - ) - - if not (last_stages or first_stages): + if not (first_stages or last_stages): return - # E.g. if our last_stages is CramQc, MtToEs would still run because it's in - # a different branch. So we want to collect all stages after first_stages - # and before last_stages in their respective branches, and mark as skipped - # everything in other branches. - first_stages_keeps: list[str] = first_stages[:] - last_stages_keeps: list[str] = last_stages[:] - - for fs in first_stages: - for descendant in nx.descendants(graph, fs): - if not stages_d[descendant].skipped: - logger.info( - f'Skipping stage {descendant} (precedes {fs} listed in first_stages)', - ) - stages_d[descendant].skipped = True - for grand_descendant in nx.descendants(graph, descendant): - if not stages_d[grand_descendant].assume_outputs_exist: - logger.info( - f'Not checking expected outputs of not immediately ' - f'required stage {grand_descendant} (< {descendant} < {fs})', - ) - stages_d[grand_descendant].assume_outputs_exist = True - - for ancestor in nx.ancestors(graph, fs): - first_stages_keeps.append(ancestor) - - for ls in last_stages: - # ancestors of this last_stage - ancestors = nx.ancestors(graph, ls) - if any(anc in last_stages for anc in ancestors): - # a downstream stage is also in last_stages, so this is not yet - # a "real" last stage that we want to run - continue - for ancestor in ancestors: - if stages_d[ancestor].skipped: - continue # already skipped - logger.info(f'Skipping stage {ancestor} (after 
last {ls})') - stages_d[ancestor].skipped = True - stages_d[ancestor].assume_outputs_exist = True - - for ancestor in nx.descendants(graph, ls): - last_stages_keeps.append(ancestor) - - for _stage in stages: - if _stage.name not in last_stages_keeps + first_stages_keeps: - _stage.skipped = True - _stage.assume_outputs_exist = True + pre_first = _compute_shadow(graph, set(first_stages)) + post_last = _compute_shadow(graph.reverse(), set(last_stages)) - for stage in stages: - if stage.skipped: - graph.nodes[stage.name]['skipped'] = True + kept = set() + for node in first_stages: + kept.update({node} | nx.ancestors(graph, node)) + for node in last_stages: + kept.update({node} | nx.descendants(graph, node)) + + stage_d: dict[str, 'Stage'] = {s.name: s for s in stages} # noqa: UP037 + for node in pre_first | post_last | (set(stage_d.keys() - kept)): + stage = stage_d[node] + stage.skipped = True + stage.assume_outputs_exist = True + graph.nodes[node]['skipped'] = True @staticmethod def _process_only_stages( @@ -455,8 +422,7 @@ def _process_only_stages( # imediate predecessor stages, but skip everything else. required_stages: set[str] = set() for os in only_stages: - rs = nx.descendants_at_distance(graph, os, 1) - required_stages |= set(rs) + required_stages.update(nx.descendants_at_distance(graph, os, 1)) for stage in stages: # Skip stage not in only_stages, and assume outputs exist... @@ -500,30 +466,11 @@ def set_stages( logger.info(f' workflow/first_stages: {first_stages}') logger.info(f' workflow/last_stages: {last_stages}') - # Round 1: initialising stage objects. - stages_dict: dict[str, Stage] = {} - for cls in requested_stages: - if cls.__name__ in stages_dict: - continue - stages_dict[cls.__name__] = cls() - - # Round 2: depth search to find implicit stages. 
@staticmethod
def _instantiate_stages(
    requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str]
) -> dict[str, 'Stage']:
    """Instantiate the requested stages together with the stages they require.

    Every stage class is instantiated at most once and stored under its
    ``__name__``. Requirements listed in ``required_stages_classes`` are
    followed recursively, and the resulting instances are appended to the
    dependent stage's ``required_stages``.

    Args:
        requested_stages: stage classes to start from.
        skip_stages: class names of stages to mark as skipped. The
            requirements of such a stage are not followed.
        only_stages: if non-empty, every instantiated stage whose class name
            is not listed is marked as skipped once all stages have been
            created. Requirements of those stages are still instantiated.

    Returns:
        Mapping of class name to stage instance, in order of creation.
    """
    # Membership is checked once per stage; build the lookup sets up front and
    # tolerate ``None`` for either option.
    skip_names = frozenset(skip_stages or ())
    only_names = frozenset(only_stages or ())
    stages_dict: dict[str, 'Stage'] = {}  # noqa: UP037

    def _get_or_create(cls) -> 'Stage':  # noqa: UP037
        name = cls.__name__
        if name in stages_dict:
            return stages_dict[name]

        # Register the instance before recursing so a stage that is reachable
        # along several paths is only ever built once.
        instance = stages_dict[name] = cls()
        instance.skipped = name in skip_names
        if not instance.skipped:
            # Append every requirement unconditionally: this helper always
            # returns an instance, so filtering on truthiness could only ever
            # drop a stage object that happens to evaluate as falsy.
            instance.required_stages.extend(_get_or_create(dep) for dep in instance.required_stages_classes)
        return instance

    for cls in requested_stages:
        _get_or_create(cls)

    if only_names:
        for stage_name, stage in stages_dict.items():
            if stage_name not in only_names:
                stage.skipped = True

    return stages_dict
def _parse_graph(graph: str) -> nx.DiGraph:
    """Build a directed graph from a compact textual description.

    Chains are separated by ``;`` and the nodes within a chain by ``->``,
    e.g. ``'R->A->B;A->C'``. Whitespace around names and separators is
    ignored and empty names are discarded. A chain consisting of a single
    name adds that node without any edges, so isolated nodes can be
    described too.
    """
    g = nx.DiGraph()
    for chain in re.split(r'\s*;\s*', graph.strip()):
        nodes = [name for name in re.split(r'\s*->\s*', chain) if name]
        # Register the nodes explicitly: pairwise() yields nothing for a
        # one-element chain, which would otherwise drop the node.
        g.add_nodes_from(nodes)
        g.add_edges_from(itertools.pairwise(nodes))
    return g