Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 68 additions & 119 deletions src/cpg_flow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,24 @@ def _render_node(node):
return out


def _compute_shadow(graph: nx.DiGraph, shadow_casters: set[str]) -> set[str]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic looks more concise than the previous implementation. @folded, I've included a few differences I've noticed between the new logic and the previous implementation. Here, I did a 1:1 comparison, but if these changes are intentional, feel free to skip my comment.

(In the workflow examples, -> points to the execution order and not the edge direction in the DAG object)

  1. When last_stages contains multiple stages on the same path, the previous logic picks the downstream stage (to skip the stages further downstream).
    Let's say we have a workflow A -> B -> C -> D. If we define last_stages = [B, C], the previous logic skips only D, but the new logic will skip C, D.

This happens when B becomes a shadow caster with shadowed={C, D}.

  2. Stage skipping when both last_stages and first_stages are defined.
    Let's say we have a workflow with first_stages = B and last_stages=F
    Let's say we have a workflow with first_stages = B and last_stages=F
A->C
 ->B->D
    ->E->F->G  
    ->G

The previous logic will result in,

B->D
 ->E->F

But in the new logic,

  • A will not be skipped - Even though B is a shadow caster, C will light up A and the first_stage kept logic will include A.
  • G will not be skipped - Even though F is a shadow caster, E will light up G and the last_stage kept logic will include G.

"""Compute the 'shadow' of a set of nodes on a directed graph.

Shadowed nodes are those that are not reachable from any root of the
graph without passing through a shadow caster node."""
shadowed: set[str] = set(graph.nodes)
unvisited: set[str] = {node for node, in_degree in graph.in_degree() if not in_degree}

while unvisited:
node = unvisited.pop()
shadowed.remove(node)
if node not in shadow_casters:
for descendant in graph.successors(node):
if descendant in shadowed:
unvisited.add(descendant)
return shadowed


class Workflow:
"""
Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups.
Expand Down Expand Up @@ -284,8 +302,8 @@ def __init__(
self.status_reporter = None
if config_retrieve(['workflow', 'status_reporter'], None) == 'metamist':
self.status_reporter = MetamistStatusReporter()
self._stages: list[StageDecorator] | None = stages
self.queued_stages: list[Stage] = []
self._stages: list['StageDecorator'] | None = stages # noqa: UP037
self.queued_stages: list['Stage'] = [] # noqa: UP037

@property
def output_version(self) -> str:
Expand Down Expand Up @@ -360,75 +378,24 @@ def _process_first_last_stages(
before first_stages, and all stages after last_stages (i.e. descendants and
ancestors on the stages DAG.)
"""
stages_d = {s.name: s for s in stages}
stage_names = list(stg.name for stg in stages)
lower_names = {s.lower() for s in stage_names}

for param, _stage_list in [
('first_stages', first_stages),
('last_stages', last_stages),
]:
for _s_name in _stage_list:
if _s_name.lower() not in lower_names:
raise WorkflowError(
f'Value in workflow/{param} "{_s_name}" must be a stage name '
f'or a subset of stages from the available list: '
f'{", ".join(stage_names)}',
)

if not (last_stages or first_stages):
if not (first_stages or last_stages):
return

# E.g. if our last_stages is CramQc, MtToEs would still run because it's in
# a different branch. So we want to collect all stages after first_stages
# and before last_stages in their respective branches, and mark as skipped
# everything in other branches.
first_stages_keeps: list[str] = first_stages[:]
last_stages_keeps: list[str] = last_stages[:]

for fs in first_stages:
for descendant in nx.descendants(graph, fs):
if not stages_d[descendant].skipped:
logger.info(
f'Skipping stage {descendant} (precedes {fs} listed in first_stages)',
)
stages_d[descendant].skipped = True
for grand_descendant in nx.descendants(graph, descendant):
if not stages_d[grand_descendant].assume_outputs_exist:
logger.info(
f'Not checking expected outputs of not immediately '
f'required stage {grand_descendant} (< {descendant} < {fs})',
)
stages_d[grand_descendant].assume_outputs_exist = True

for ancestor in nx.ancestors(graph, fs):
first_stages_keeps.append(ancestor)

for ls in last_stages:
# ancestors of this last_stage
ancestors = nx.ancestors(graph, ls)
if any(anc in last_stages for anc in ancestors):
# a downstream stage is also in last_stages, so this is not yet
# a "real" last stage that we want to run
continue
for ancestor in ancestors:
if stages_d[ancestor].skipped:
continue # already skipped
logger.info(f'Skipping stage {ancestor} (after last {ls})')
stages_d[ancestor].skipped = True
stages_d[ancestor].assume_outputs_exist = True

for ancestor in nx.descendants(graph, ls):
last_stages_keeps.append(ancestor)

for _stage in stages:
if _stage.name not in last_stages_keeps + first_stages_keeps:
_stage.skipped = True
_stage.assume_outputs_exist = True
pre_first = _compute_shadow(graph, set(first_stages))
post_last = _compute_shadow(graph.reverse(), set(last_stages))

for stage in stages:
if stage.skipped:
graph.nodes[stage.name]['skipped'] = True
kept = set()
for node in first_stages:
kept.update({node} | nx.ancestors(graph, node))
for node in last_stages:
kept.update({node} | nx.descendants(graph, node))

stage_d: dict[str, 'Stage'] = {s.name: s for s in stages} # noqa: UP037
for node in pre_first | post_last | (set(stage_d.keys() - kept)):
stage = stage_d[node]
stage.skipped = True
stage.assume_outputs_exist = True
graph.nodes[node]['skipped'] = True

@staticmethod
def _process_only_stages(
Expand All @@ -455,8 +422,7 @@ def _process_only_stages(
# immediate predecessor stages, but skip everything else.
required_stages: set[str] = set()
for os in only_stages:
rs = nx.descendants_at_distance(graph, os, 1)
required_stages |= set(rs)
required_stages.update(nx.descendants_at_distance(graph, os, 1))

for stage in stages:
# Skip stage not in only_stages, and assume outputs exist...
Expand Down Expand Up @@ -500,30 +466,11 @@ def set_stages(
logger.info(f' workflow/first_stages: {first_stages}')
logger.info(f' workflow/last_stages: {last_stages}')

# Round 1: initialising stage objects.
stages_dict: dict[str, Stage] = {}
for cls in requested_stages:
if cls.__name__ in stages_dict:
continue
stages_dict[cls.__name__] = cls()

# Round 2: depth search to find implicit stages.
stages_dict = self._resolve_implicit_stages(
stages_dict=stages_dict,
skip_stages=skip_stages,
only_stages=only_stages,
)
stages_dict: dict[str, 'Stage'] = self._instantiate_stages(requested_stages, skip_stages, only_stages) # noqa: UP037

# Round 3: set "stage.required_stages" fields to each stage.
for stg in stages_dict.values():
stg.required_stages = [
stages_dict[cls.__name__] for cls in stg.required_stages_classes if cls.__name__ in stages_dict
]

# Round 4: determining order of execution.
stages, dag = self._determine_order_of_execution(stages_dict)

# Round 5: applying workflow options first_stages and last_stages.
# Apply workflow options first_stages and last_stages.
if first_stages or last_stages:
logger.info('Applying workflow/first_stages and workflow/last_stages')
self._process_first_last_stages(stages, dag, first_stages, last_stages)
Expand Down Expand Up @@ -568,38 +515,40 @@ def set_stages(
self._show_workflow(dag, skip_stages, only_stages, first_stages, last_stages)

@staticmethod
def _resolve_implicit_stages(stages_dict: dict, skip_stages: list[str], only_stages: list[str]):
implicit_stages = {'first': 'loop'}

while len(implicit_stages) > 0:
implicit_stages = dict()
for stg in stages_dict.values():
if stg.name in skip_stages:
stg.skipped = True
continue # not searching deeper

if only_stages and stg.name not in only_stages:
stg.skipped = True

# Get all deps not already in stages_dict
not_in_stages_dict = {
cls().name: cls() for cls in stg.required_stages_classes if cls.__name__ not in stages_dict
}
implicit_stages |= not_in_stages_dict

# If there's nothing more to add, finish search
if not implicit_stages:
break

logger.info(
f'Additional implicit stages: {list(implicit_stages.keys())}',
)
stages_dict |= implicit_stages
def _instantiate_stages(
requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str]
) -> dict[str, 'Stage']:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is nice and improves some inefficiencies in the previous logic (eg, processing the same stage more than once).

stages_dict: dict[str, 'Stage'] = {} # noqa: UP037

def _make_once(cls) -> tuple['Stage', bool]:
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of catching the KeyError, we can simplify this by using a safe lookup on stages_dict. Something like:

instance = stages_dict.get(cls.__name__)
if instance is not None:
    return instance, False

instance = stages_dict[cls.__name__] = cls()
return instance, True

return stages_dict[cls.__name__], False
except KeyError:
instance = stages_dict[cls.__name__] = cls()
return instance, True

def _recursively_make_stage(cls):
instance, is_new = _make_once(cls)
if is_new:
instance.skipped = cls.__name__ in skip_stages
if not instance.skipped:
instance.required_stages.extend(
filter(None, map(_recursively_make_stage, instance.required_stages_classes)),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, if we move the only_stages logic here, we can avoid re-iterating over the stages_dict logic (between lines 543-546).
Something like:

if only_stages:
    if cls.__name__ not in only_stages:
        instance.skipped = True

return instance

for cls in requested_stages:
_recursively_make_stage(cls)

if only_stages:
for stage_name, stage in stages_dict.items():
if stage_name not in only_stages:
stage.skipped = True

return stages_dict

@staticmethod
def _determine_order_of_execution(stages_dict: dict):
def _determine_order_of_execution(stages_dict: dict) -> tuple[list['Stage'], nx.DiGraph]:
dag_node2nodes = dict() # building a DAG
for stg in stages_dict.values():
dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages)
Expand Down
4 changes: 2 additions & 2 deletions tests/stages/test_first_last_stages_misconfigured.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytest_mock import MockFixture

from tests import set_config
from tests.stages import run_workflow
from tests.stages import D, run_workflow


def test_first_last_stages_misconfigured(mocker: MockFixture, tmp_path):
Expand Down Expand Up @@ -47,4 +47,4 @@ def test_first_last_stages_misconfigured(mocker: MockFixture, tmp_path):
from cpg_flow.workflow import WorkflowError

with pytest.raises(WorkflowError, match='No stages to run'):
run_workflow(mocker)
run_workflow(mocker, stages=[D])
28 changes: 27 additions & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
Test building Workflow object.
"""

import itertools
import pathlib
import re
from collections.abc import Collection, Mapping, Sequence
from typing import Any, Final
from unittest import mock
Expand All @@ -19,7 +21,7 @@
stage,
)
from cpg_flow.targets import Cohort, MultiCohort, SequencingGroup
from cpg_flow.workflow import _render_graph, get_workflow, path_walk, run_workflow
from cpg_flow.workflow import _compute_shadow, _render_graph, get_workflow, path_walk, run_workflow
from cpg_utils.config import dataset_path
from cpg_utils.hail_batch import get_batch

Expand Down Expand Up @@ -271,3 +273,27 @@ def test_render_graph_extra_args(
graph = _create_graph_with_attrs(edges, skipped_nodes)
result = ';'.join(_render_graph(graph, **extra_args))
assert result == expected


def _parse_graph(graph: str) -> nx.DiGraph:
    """Build a DiGraph from compact 'A->B->C;X->Y' edge-path notation."""
    dag = nx.DiGraph()
    for chain in re.split(r'\s*;\s*', graph):
        nodes = re.split(r'\s*->\s*', chain)
        dag.add_edges_from(itertools.pairwise(nodes))
    return dag


@pytest.mark.parametrize(
['graph', 'casters', 'expected'],
[
# No casters: nothing is shadowed.
pytest.param('R->A->Caster->B->D', set(), set()),
# A caster name absent from the graph has no effect.
pytest.param('R->A->Caster->B->D', {'X'}, set()),
# Everything strictly downstream of the caster is shadowed.
pytest.param('R->A->Caster->B->D', {'Caster'}, {'B', 'D'}),
pytest.param('R->A->Caster->B->D;B->E', {'Caster'}, {'B', 'D', 'E'}),
# D stays lit: reachable from the root via A->D, bypassing the caster.
pytest.param('R->A->Caster->B->D;A->D', {'Caster'}, {'B'}),
# Multiple roots/components: only Z is cut off from every root.
pytest.param('R1->A->B->D;R1->Caster1->B;R2->X->Y->Caster2->Z;R3->P', {'Caster1', 'Caster2'}, {'Z'}),
],
)
def test_compute_shadow(graph: str, casters: set[str], expected: set[str]):
# The compact path string is parsed into a DiGraph; the computed shadow
# must match the expected node set exactly.
assert _compute_shadow(_parse_graph(graph), casters) == expected
Loading