-
Notifications
You must be signed in to change notification settings - Fork 0
chore(refactor): simplify graph ops #102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a021789
a8bd619
07cb65c
d3afaad
43fd2ef
06821a4
7a39af0
ecb4f45
2557185
bb2a3f1
057a3c9
c2147cb
ab3b6ef
dfa7e69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -237,6 +237,24 @@ def _render_node(node): | |
| return out | ||
|
|
||
|
|
||
| def _compute_shadow(graph: nx.DiGraph, shadow_casters: set[str]) -> set[str]: | ||
| """Compute the 'shadow' of a set of nodes on a directed graph. | ||
|
|
||
| Shadowed nodes are those that are not reachable from any root of the | ||
| graph without passing through a shadow caster node.""" | ||
| shadowed: set[str] = set(graph.nodes) | ||
| unvisited: set[str] = {node for node, in_degree in graph.in_degree() if not in_degree} | ||
|
|
||
| while unvisited: | ||
| node = unvisited.pop() | ||
| shadowed.remove(node) | ||
| if node not in shadow_casters: | ||
| for descendant in graph.successors(node): | ||
| if descendant in shadowed: | ||
| unvisited.add(descendant) | ||
| return shadowed | ||
|
|
||
|
|
||
| class Workflow: | ||
| """ | ||
| Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups. | ||
|
|
@@ -284,8 +302,8 @@ def __init__( | |
| self.status_reporter = None | ||
| if config_retrieve(['workflow', 'status_reporter'], None) == 'metamist': | ||
| self.status_reporter = MetamistStatusReporter() | ||
| self._stages: list[StageDecorator] | None = stages | ||
| self.queued_stages: list[Stage] = [] | ||
| self._stages: list['StageDecorator'] | None = stages # noqa: UP037 | ||
| self.queued_stages: list['Stage'] = [] # noqa: UP037 | ||
|
|
||
| @property | ||
| def output_version(self) -> str: | ||
|
|
@@ -360,75 +378,24 @@ def _process_first_last_stages( | |
| before first_stages, and all stages after last_stages (i.e. descendants and | ||
| ancestors on the stages DAG.) | ||
| """ | ||
| stages_d = {s.name: s for s in stages} | ||
| stage_names = list(stg.name for stg in stages) | ||
| lower_names = {s.lower() for s in stage_names} | ||
|
|
||
| for param, _stage_list in [ | ||
| ('first_stages', first_stages), | ||
| ('last_stages', last_stages), | ||
| ]: | ||
| for _s_name in _stage_list: | ||
| if _s_name.lower() not in lower_names: | ||
| raise WorkflowError( | ||
| f'Value in workflow/{param} "{_s_name}" must be a stage name ' | ||
| f'or a subset of stages from the available list: ' | ||
| f'{", ".join(stage_names)}', | ||
| ) | ||
|
|
||
| if not (last_stages or first_stages): | ||
| if not (first_stages or last_stages): | ||
folded marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return | ||
|
|
||
| # E.g. if our last_stages is CramQc, MtToEs would still run because it's in | ||
| # a different branch. So we want to collect all stages after first_stages | ||
| # and before last_stages in their respective branches, and mark as skipped | ||
| # everything in other branches. | ||
| first_stages_keeps: list[str] = first_stages[:] | ||
| last_stages_keeps: list[str] = last_stages[:] | ||
|
|
||
| for fs in first_stages: | ||
| for descendant in nx.descendants(graph, fs): | ||
| if not stages_d[descendant].skipped: | ||
| logger.info( | ||
| f'Skipping stage {descendant} (precedes {fs} listed in first_stages)', | ||
| ) | ||
| stages_d[descendant].skipped = True | ||
| for grand_descendant in nx.descendants(graph, descendant): | ||
| if not stages_d[grand_descendant].assume_outputs_exist: | ||
| logger.info( | ||
| f'Not checking expected outputs of not immediately ' | ||
| f'required stage {grand_descendant} (< {descendant} < {fs})', | ||
| ) | ||
| stages_d[grand_descendant].assume_outputs_exist = True | ||
|
|
||
| for ancestor in nx.ancestors(graph, fs): | ||
| first_stages_keeps.append(ancestor) | ||
|
|
||
| for ls in last_stages: | ||
| # ancestors of this last_stage | ||
| ancestors = nx.ancestors(graph, ls) | ||
| if any(anc in last_stages for anc in ancestors): | ||
| # a downstream stage is also in last_stages, so this is not yet | ||
| # a "real" last stage that we want to run | ||
| continue | ||
| for ancestor in ancestors: | ||
| if stages_d[ancestor].skipped: | ||
| continue # already skipped | ||
| logger.info(f'Skipping stage {ancestor} (after last {ls})') | ||
| stages_d[ancestor].skipped = True | ||
| stages_d[ancestor].assume_outputs_exist = True | ||
|
|
||
| for ancestor in nx.descendants(graph, ls): | ||
| last_stages_keeps.append(ancestor) | ||
|
|
||
| for _stage in stages: | ||
| if _stage.name not in last_stages_keeps + first_stages_keeps: | ||
| _stage.skipped = True | ||
| _stage.assume_outputs_exist = True | ||
| pre_first = _compute_shadow(graph, set(first_stages)) | ||
| post_last = _compute_shadow(graph.reverse(), set(last_stages)) | ||
|
|
||
| for stage in stages: | ||
| if stage.skipped: | ||
| graph.nodes[stage.name]['skipped'] = True | ||
| kept = set() | ||
| for node in first_stages: | ||
| kept.update({node} | nx.ancestors(graph, node)) | ||
| for node in last_stages: | ||
| kept.update({node} | nx.descendants(graph, node)) | ||
folded marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| stage_d: dict[str, 'Stage'] = {s.name: s for s in stages} # noqa: UP037 | ||
| for node in pre_first | post_last | (set(stage_d.keys() - kept)): | ||
| stage = stage_d[node] | ||
| stage.skipped = True | ||
| stage.assume_outputs_exist = True | ||
| graph.nodes[node]['skipped'] = True | ||
|
|
||
| @staticmethod | ||
| def _process_only_stages( | ||
|
|
@@ -455,8 +422,7 @@ def _process_only_stages( | |
| # immediate predecessor stages, but skip everything else. | ||
| required_stages: set[str] = set() | ||
| for os in only_stages: | ||
| rs = nx.descendants_at_distance(graph, os, 1) | ||
| required_stages |= set(rs) | ||
| required_stages.update(nx.descendants_at_distance(graph, os, 1)) | ||
|
|
||
| for stage in stages: | ||
| # Skip stage not in only_stages, and assume outputs exist... | ||
|
|
@@ -500,30 +466,11 @@ def set_stages( | |
| logger.info(f' workflow/first_stages: {first_stages}') | ||
| logger.info(f' workflow/last_stages: {last_stages}') | ||
|
|
||
| # Round 1: initialising stage objects. | ||
| stages_dict: dict[str, Stage] = {} | ||
| for cls in requested_stages: | ||
| if cls.__name__ in stages_dict: | ||
| continue | ||
| stages_dict[cls.__name__] = cls() | ||
|
|
||
| # Round 2: depth search to find implicit stages. | ||
| stages_dict = self._resolve_implicit_stages( | ||
| stages_dict=stages_dict, | ||
| skip_stages=skip_stages, | ||
| only_stages=only_stages, | ||
| ) | ||
| stages_dict: dict[str, 'Stage'] = self._instantiate_stages(requested_stages, skip_stages, only_stages) # noqa: UP037 | ||
|
|
||
| # Round 3: set "stage.required_stages" fields to each stage. | ||
| for stg in stages_dict.values(): | ||
| stg.required_stages = [ | ||
| stages_dict[cls.__name__] for cls in stg.required_stages_classes if cls.__name__ in stages_dict | ||
| ] | ||
|
|
||
| # Round 4: determining order of execution. | ||
| stages, dag = self._determine_order_of_execution(stages_dict) | ||
|
|
||
| # Round 5: applying workflow options first_stages and last_stages. | ||
| # Apply workflow options first_stages and last_stages. | ||
| if first_stages or last_stages: | ||
| logger.info('Applying workflow/first_stages and workflow/last_stages') | ||
| self._process_first_last_stages(stages, dag, first_stages, last_stages) | ||
|
|
@@ -568,38 +515,40 @@ def set_stages( | |
| self._show_workflow(dag, skip_stages, only_stages, first_stages, last_stages) | ||
|
|
||
| @staticmethod | ||
| def _resolve_implicit_stages(stages_dict: dict, skip_stages: list[str], only_stages: list[str]): | ||
| implicit_stages = {'first': 'loop'} | ||
|
|
||
| while len(implicit_stages) > 0: | ||
| implicit_stages = dict() | ||
| for stg in stages_dict.values(): | ||
| if stg.name in skip_stages: | ||
| stg.skipped = True | ||
| continue # not searching deeper | ||
|
|
||
| if only_stages and stg.name not in only_stages: | ||
| stg.skipped = True | ||
|
|
||
| # Get all deps not already in stages_dict | ||
| not_in_stages_dict = { | ||
| cls().name: cls() for cls in stg.required_stages_classes if cls.__name__ not in stages_dict | ||
| } | ||
| implicit_stages |= not_in_stages_dict | ||
|
|
||
| # If there's nothing more to add, finish search | ||
| if not implicit_stages: | ||
| break | ||
|
|
||
| logger.info( | ||
| f'Additional implicit stages: {list(implicit_stages.keys())}', | ||
| ) | ||
| stages_dict |= implicit_stages | ||
| def _instantiate_stages( | ||
| requested_stages: list['StageDecorator'], skip_stages: list[str], only_stages: list[str] | ||
| ) -> dict[str, 'Stage']: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This method is nice and improves some inefficiencies in the previous logic (e.g., processing the same stage more than once). |
||
| stages_dict: dict[str, 'Stage'] = {} # noqa: UP037 | ||
|
|
||
| def _make_once(cls) -> tuple['Stage', bool]: | ||
| try: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead of catching the |
||
| return stages_dict[cls.__name__], False | ||
| except KeyError: | ||
| instance = stages_dict[cls.__name__] = cls() | ||
| return instance, True | ||
|
|
||
| def _recursively_make_stage(cls): | ||
| instance, is_new = _make_once(cls) | ||
| if is_new: | ||
| instance.skipped = cls.__name__ in skip_stages | ||
| if not instance.skipped: | ||
| instance.required_stages.extend( | ||
| filter(None, map(_recursively_make_stage, instance.required_stages_classes)), | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think, if we move the |
||
| return instance | ||
|
|
||
| for cls in requested_stages: | ||
| _recursively_make_stage(cls) | ||
|
|
||
| if only_stages: | ||
| for stage_name, stage in stages_dict.items(): | ||
| if stage_name not in only_stages: | ||
| stage.skipped = True | ||
|
|
||
| return stages_dict | ||
|
|
||
| @staticmethod | ||
| def _determine_order_of_execution(stages_dict: dict): | ||
| def _determine_order_of_execution(stages_dict: dict) -> tuple[list['Stage'], nx.DiGraph]: | ||
| dag_node2nodes = dict() # building a DAG | ||
| for stg in stages_dict.values(): | ||
| dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic looks more concise than the previous implementation. @folded, I've included a few differences I've noticed between the new logic and the previous implementation. Here, I did a 1:1 comparison, but if these changes are intentional, feel free to skip my comment.
(In the workflow examples, -> points to the execution order and not the edge direction in the DAG object)
1. When `last_stages` contains multiple stages on the same path, the previous logic picks the downstream stage (to skip the stages further downstream).
   Let's say we have a workflow `A->B->C->D`. If we define `last_stages = [B, C]`, the previous logic skips only `D`, but the new logic will skip `C`, `D`. This happens when `B` becomes a `shadow caster` with `shadowed = {C, D}`.
2. When both `last_stages` and `first_stages` are defined.
   Let's say we have a workflow with `first_stages = B` and `last_stages = F`. The previous logic will result in,
   But in the new logic,
   - `A` will not be skipped - even though `B` is a `shadow caster`, `C` will light up `A` and the `last_stage kept` logic will include `A`.
   - `G` will not be skipped - even though `F` is a `shadow caster`, `E` will light up `G` and the `first_stage kept` logic will include `G`.