Skip to content

Commit 8ab4add

Browse files
Improve safe sequence optimizations and reachability
Refactored safe sequence optimizations to fix edge variables directly via constraints and track zero/one assignments for use in product constraints. Added efficient node reachability queries to stDiGraph using SCC condensation, replacing networkx ancestor/descendant calls. Updated demo and test files to use new options and added a new cyclic graph test case.
1 parent 0b4ffc0 commit 8ab4add

File tree

7 files changed

+776
-36
lines changed

7 files changed

+776
-36
lines changed

examples/cycles_demo.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ def test_min_flow_decomp(filename: str):
2727
weight_type=int,
2828
subset_constraints=graph.graph["constraints"], # try with and without
2929
optimization_options={
30-
"optimize_with_safe_sequences": True, # set to false to deactivate the safe sequences optimization
30+
"optimize_with_safe_sequences": False, # set to false to deactivate the safe sequences optimization
3131
"optimize_with_safe_sequences_allow_geq_constraints": True,
32-
"optimize_with_safe_sequences_fix_via_bounds": True,
32+
# "optimize_with_safe_sequences_fix_via_bounds": True,
3333
"optimize_with_safe_sequences_fix_zero_edges": True,
3434
},
3535
solver_options={
3636
"external_solver": "gurobi", # we can try also "highs" at some point
3737
"time_limit": 300, # 300s = 5min, is it ok?
38+
"threads": 1
3839
},
3940
)
4041
mfd_model.solve()
@@ -56,7 +57,7 @@ def test_least_abs_errors(filename):
5657
"optimize_with_safe_sequences_allow_geq_constraints": False,
5758
},
5859
solver_options={
59-
"external_solver": "gurobi", # we can try also "highs" at some point
60+
"external_solver": "highs", # we can try also "highs" at some point
6061
"time_limit": 300, # 300s = 5min, is it ok?
6162
},
6263
trusted_edges_for_safety_percentile=0, # we trust for safety edges whose weight in >= 0 percentile, that is, all edges
@@ -75,7 +76,7 @@ def test_least_abs_errors(filename):
7576
"optimize_with_safe_sequences_allow_geq_constraints": False,
7677
},
7778
solver_options={
78-
"external_solver": "gurobi", # we can try also "highs" at some point
79+
"external_solver": "highs", # we can try also "highs" at some point
7980
"time_limit": 300, # 300s = 5min, is it ok?
8081
},
8182
trusted_edges_for_safety_percentile=25, # we trust for safety edges whose weight in >= 25 percentile, remove this if not using the safety optimization
@@ -99,7 +100,7 @@ def test_min_path_error(filename):
99100
"optimize_with_safe_sequences_allow_geq_constraints": True,
100101
},
101102
solver_options={
102-
"external_solver": "gurobi", # we can try also "highs" at some point
103+
"external_solver": "highs", # we can try also "highs" at some point
103104
"time_limit": 300, # 300s = 5min, is it ok?
104105
},
105106
)
@@ -115,10 +116,9 @@ def test_min_path_error(filename):
115116
optimization_options={
116117
"optimize_with_safe_sequences": True, # set to false to deactivate the safe sequences optimization
117118
"optimize_with_safe_sequences_allow_geq_constraints": False,
118-
"optimize_with_safe_sequences_fix_via_bounds": True,
119119
},
120120
solver_options={
121-
"external_solver": "gurobi", # we can try also "highs" at some point
121+
"external_solver": "highs", # we can try also "highs" at some point
122122
"time_limit": 300, # 300s = 5min, is it ok?
123123
},
124124
trusted_edges_for_safety_percentile=25, # we trust for safety edges whose weight in >= 25 percentile, remove this if not using the safety optimization
@@ -136,10 +136,9 @@ def test_min_path_error(filename):
136136
optimization_options={
137137
"optimize_with_safe_sequences": True, # set to false to deactivate the safe sequences optimization
138138
"optimize_with_safe_sequences_allow_geq_constraints": False,
139-
"optimize_with_safe_sequences_fix_via_bounds": True,
140139
},
141140
solver_options={
142-
"external_solver": "gurobi", # we can try also "highs" at some point
141+
"external_solver": "highs", # we can try also "highs" at some point
143142
"time_limit": 300, # 300s = 5min, is it ok?
144143
},
145144
)
@@ -155,20 +154,20 @@ def process_solution(model):
155154
else:
156155
print("Model could not be solved.")
157156

158-
fp.utils.draw(
159-
G=model.G,
160-
filename= "solution.pdf",
161-
flow_attr="flow",
162-
paths=model.get_solution().get('walks', None),
163-
weights=model.get_solution().get('weights', None),
164-
draw_options={
165-
"show_graph_edges": False,
166-
"show_edge_weights": False,
167-
"show_path_weights": False,
168-
"show_path_weight_on_first_edge": True,
169-
"pathwidth": 2,
170-
# "style": "points",
171-
})
157+
# fp.utils.draw(
158+
# G=model.G,
159+
# filename= "solution.pdf",
160+
# flow_attr="flow",
161+
# paths=model.get_solution().get('walks', None),
162+
# weights=model.get_solution().get('weights', None),
163+
# draw_options={
164+
# "show_graph_edges": False,
165+
# "show_edge_weights": False,
166+
# "show_path_weights": False,
167+
# "show_path_weight_on_first_edge": True,
168+
# "pathwidth": 2,
169+
# # "style": "points",
170+
# })
172171

173172
solve_statistics = model.solve_statistics
174173
print(solve_statistics)
@@ -189,7 +188,8 @@ def process_solution(model):
189188

190189
def main():
191190
test_min_flow_decomp(filename = "tests/cyclic_graphs/gt3.kmer15.(130000.132000).V23.E32.cyc100.graph")
192-
test_min_flow_decomp(filename = "tests/cyclic_graphs/gt5.kmer27.(1300000.1400000).V809.E1091.mincyc1000.graph")
191+
# test_min_flow_decomp(filename = "tests/cyclic_graphs/gt5.kmer27.(1300000.1400000).V809.E1091.mincyc1000.graph")
192+
test_min_flow_decomp(filename = "tests/cyclic_graphs/gt32.kmer63.(0.10000).V231.E336.mincyc1.e1.0.graph")
193193
# test_min_flow_decomp(filename = "tests/cyclic_graphs/gt4.kmer15.(0.10000).V1096.E1622.mincyc100.e1.0.graph")
194194
test_least_abs_errors(filename = "tests/cyclic_graphs/gt5.kmer27.(655000.660000).V18.E27.mincyc4.e0.75.graph")
195195
test_min_path_error(filename = "tests/cyclic_graphs/gt5.kmer27.(655000.660000).V18.E27.mincyc4.e0.75.graph")

flowpaths/abstractwalkmodeldigraph.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@ class AbstractWalkModelDiGraph(ABC):
1313
optimize_with_safe_sequences = True
1414
optimize_with_safe_sequences_allow_geq_constraints = True
1515
optimize_with_safe_sequences_fix_via_bounds = False
16-
optimize_with_safe_sequences_fix_zero_edges = False
16+
optimize_with_safe_sequences_fix_zero_edges = True
1717
# TODO: optimize_with_subset_constraints_as_safe_sequences = True
1818
optimize_with_safety_as_subset_constraints = False
1919
optimize_with_max_safe_antichain_as_subset_constraints = False
2020
allow_empty_walks = False
21-
allow_empty_walks = False
2221

2322
def __init__(
2423
self,
@@ -149,6 +148,9 @@ def __init__(
149148
self.solver_options = {}
150149
self.threads = self.solver_options.get("threads", sw.SolverWrapper.threads)
151150

151+
self.edges_set_to_zero = {}
152+
self.edges_set_to_one = {}
153+
152154
# optimizations
153155
if optimization_options is None:
154156
optimization_options = {}
@@ -445,6 +447,7 @@ def _apply_safety_optimizations(self):
445447
self.edge_vars[(u, v, i)] == 1,
446448
name=f"safe_list_u={u}_v={v}_i={i}_eq{m}",
447449
)
450+
self.edges_set_to_one[(u, v, i)] = True
448451
self.solve_statistics["edge_variables=1"] += 1
449452

450453
def _apply_safety_optimizations_fix_zero_edges(self):
@@ -496,18 +499,18 @@ def _apply_safety_optimizations_fix_zero_edges(self):
496499
# or that can reach the first node of the walk
497500
first_node = walk[0][0]
498501
last_node = walk[-1][1]
499-
reachable_from_last_walk = nx.descendants(self.G, last_node) | {last_node}
500-
can_reach_first_walk = nx.ancestors(self.G, first_node) | {first_node}
501502
for (u, v) in self.G.edges:
502-
if (u in reachable_from_last_walk) or (v in can_reach_first_walk):
503+
if (u in self.G.nodes_reachable(last_node)) or (v in self.G.nodes_reaching(first_node)):
503504
protected_edges.add((u, v))
504505

505506
# Collect pairs of non-contiguous consecutive edges (gaps)
506507
gap_pairs = []
507508
for idx in range(len(walk) - 1):
508509
end_prev = walk[idx][1]
509510
start_next = walk[idx + 1][0]
510-
if end_prev != start_next:
511+
# We consider all consecutive edges as gap pairs, because there could be a cycle
512+
# formed between them (this is not the case in DAGs)
513+
if True or end_prev != start_next:
511514
gap_pairs.append((end_prev, start_next))
512515

513516
# If there are no gaps, do not prune anything for this walk/layer
@@ -516,19 +519,22 @@ def _apply_safety_optimizations_fix_zero_edges(self):
516519

517520
# For each gap, add edges that can lie on some path bridging the gap
518521
for (current_last, current_start) in gap_pairs:
519-
reachable_from_last = nx.descendants(self.G, current_last) | {current_last}
520-
can_reach_start = nx.ancestors(self.G, current_start) | {current_start}
521-
522522
for (u, v) in self.G.edges:
523-
if (u in reachable_from_last) and (v in can_reach_start):
523+
if (u in self.G.nodes_reachable(current_last)) and (v in self.G.nodes_reaching(current_start)):
524+
# if (u in reachable_from_last) and (v in can_reach_start):
524525
protected_edges.add((u, v))
525526

526527
# Now fix every other edge to 0 for this layer i
527528
for (u, v) in self.G.edges:
528529
if (u, v) in protected_edges:
529530
continue
530531
# Queue zero-fix for batch bounds update
531-
self.solver.queue_fix_variable(self.edge_vars[(u, v, i)], int(0))
532+
# self.solver.queue_fix_variable(self.edge_vars[(u, v, i)], int(0))
533+
self.solver.add_constraint(
534+
self.edge_vars[(u, v, i)] == 0,
535+
name=f"i={i}_u={u}_v={v}_fix0",
536+
)
537+
self.edges_set_to_zero[(u, v, i)] = True
532538
fixed_zero_count += 1
533539

534540
if fixed_zero_count:

flowpaths/kflowdecompcycles.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,20 @@ def _encode_flow_decomposition(self):
228228
# We encode that edge_vars[(u,v,i)] * self.path_weights_vars[(i)] = self.pi_vars[(u,v,i)],
229229
# assuming self.w_max is a bound for self.path_weights_vars[(i)]
230230
for i in range(self.k):
231+
if (u, v, i) in self.edges_set_to_zero:
232+
self.solver.add_constraint(
233+
self.pi_vars[(u, v, i)] == 0,
234+
name=f"i={i}_u={u}_v={v}_10b",
235+
)
236+
continue
237+
238+
if (u, v, i) in self.edges_set_to_one:
239+
self.solver.add_constraint(
240+
self.pi_vars[(u, v, i)] == self.path_weights_vars[(i)],
241+
name=f"i={i}_u={u}_v={v}_10b",
242+
)
243+
continue
244+
231245
self.solver.add_integer_continuous_product_constraint(
232246
integer_var=self.edge_vars[(u, v, i)],
233247
continuous_var=self.path_weights_vars[(i)],

flowpaths/kleastabserrorscycles.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,20 @@ def _encode_leastabserrors_decomposition(self):
288288
# We encode that edge_vars[(u,v,i)] * self.path_weights_vars[(i)] = self.pi_vars[(u,v,i)],
289289
# assuming self.w_max is a bound for self.path_weights_vars[(i)]
290290
for i in range(self.k):
291+
if (u, v, i) in self.edges_set_to_zero:
292+
self.solver.add_constraint(
293+
self.pi_vars[(u, v, i)] == 0,
294+
name=f"i={i}_u={u}_v={v}_10b",
295+
)
296+
continue
297+
298+
if (u, v, i) in self.edges_set_to_one:
299+
self.solver.add_constraint(
300+
self.pi_vars[(u, v, i)] == self.path_weights_vars[(i)],
301+
name=f"i={i}_u={u}_v={v}_10b",
302+
)
303+
continue
304+
291305
self.solver.add_integer_continuous_product_constraint(
292306
integer_var=self.edge_vars[(u, v, i)],
293307
continuous_var=self.path_weights_vars[(i)],

flowpaths/kminpatherrorcycles.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,20 @@ def _encode_minpatherror_decomposition(self):
326326
# We encode that edge_vars[(u,v,i)] * self.path_weights_vars[(i)] = self.pi_vars[(u,v,i)],
327327
# assuming self.w_max is a bound for self.path_weights_vars[(i)]
328328
for i in range(self.k):
329+
if (u, v, i) in self.edges_set_to_zero:
330+
self.solver.add_constraint(
331+
self.pi_vars[(u, v, i)] == 0,
332+
name=f"i={i}_u={u}_v={v}_10b",
333+
)
334+
continue
335+
336+
if (u, v, i) in self.edges_set_to_one:
337+
self.solver.add_constraint(
338+
self.pi_vars[(u, v, i)] == self.path_weights_vars[(i)],
339+
name=f"i={i}_u={u}_v={v}_10b",
340+
)
341+
continue
342+
329343
self.solver.add_integer_continuous_product_constraint(
330344
integer_var=self.edge_vars[(u, v, i)],
331345
continuous_var=self.path_weights_vars[(i)],
@@ -338,6 +352,20 @@ def _encode_minpatherror_decomposition(self):
338352
# We encode that edge_vars[(u,v,i)] * self.path_slacks_vars[(i)] = self.gamma_vars[(u,v,i)],
339353
# assuming self.w_max is a bound for self.path_slacks_vars[(i)]
340354
for i in range(self.k):
355+
if (u, v, i) in self.edges_set_to_zero:
356+
self.solver.add_constraint(
357+
self.gamma_vars[(u, v, i)] == 0,
358+
name=f"i={i}_u={u}_v={v}_10b",
359+
)
360+
continue
361+
362+
if (u, v, i) in self.edges_set_to_one:
363+
self.solver.add_constraint(
364+
self.gamma_vars[(u, v, i)] == self.path_slacks_vars[(i)],
365+
name=f"i={i}_u={u}_v={v}_10b",
366+
)
367+
continue
368+
341369
self.solver.add_integer_continuous_product_constraint(
342370
integer_var=self.edge_vars[(u, v, i)],
343371
continuous_var=self.path_slacks_vars[i],

flowpaths/stdigraph.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from flowpaths.stdag import stDAG
55
import flowpaths.utils as utils
66
from flowpaths.abstractsourcesinkgraph import AbstractSourceSinkGraph
7-
from typing import Dict, Tuple, Optional
7+
from typing import Dict, Tuple, Optional, Set
88

99

1010
class stDiGraph(AbstractSourceSinkGraph):
@@ -59,6 +59,27 @@ def _post_build(self):
5959
raise ValueError("The graph passed to stDiGraph must have at least one sink, or at least one node in `additional_ends`.")
6060
self.condensation_width = None
6161
self._build_condensation_expanded()
62+
# Build indices and caches used by reachability queries
63+
C: nx.DiGraph = self._condensation
64+
mapping = C.graph["mapping"] # original node -> condensation node (int)
65+
# Per-SCC indices for fast unions
66+
# - edges by SCC (kept for other utilities in this class)
67+
self._edges_by_tail_scc: Dict[int, Set[Tuple[str, str]]] = {c: set() for c in C.nodes()}
68+
self._edges_by_head_scc: Dict[int, Set[Tuple[str, str]]] = {c: set() for c in C.nodes()}
69+
for a, b in self.edges():
70+
ca = mapping[a]
71+
cb = mapping[b]
72+
self._edges_by_tail_scc[ca].add((a, b))
73+
self._edges_by_head_scc[cb].add((a, b))
74+
75+
# - nodes by SCC (for fast node reachability queries)
76+
self._nodes_by_scc: Dict[int, Set[str]] = {c: set() for c in C.nodes()}
77+
for n in self.nodes():
78+
self._nodes_by_scc[mapping[n]].add(n)
79+
80+
# Per-node caches for reachability queries (now returning nodes)
81+
self._nodes_reachable_from_node_cache: Dict[str, Set[str]] = {}
82+
self._nodes_reaching_node_cache: Dict[str, Set[str]] = {}
6283

6384
def _expanded(self, v: int) -> str:
6485

@@ -466,4 +487,80 @@ def compute_edge_max_reachable_value(self, flow_attr: str) -> Dict[Tuple[str, st
466487
result[(u, v)] = max(edge_weight[(u, v)], max_desc[cv], max_anc[cu])
467488

468489
return result
490+
491+
def nodes_reachable(self, node: str) -> Set[str]:
492+
"""Return the set of nodes reachable from ``node`` (including itself).
493+
494+
The result is cached per query node. Reachability is computed on the SCC
495+
condensation DAG: for the SCC containing ``node``, take all SCCs reachable
496+
in the condensation (including itself) and return the union of original
497+
nodes lying in any of those SCCs.
498+
499+
Parameters
500+
----------
501+
node: str
502+
The node ``v`` in this graph from which to evaluate forward reachability.
503+
504+
Returns
505+
-------
506+
Set[str]
507+
All nodes ``a`` such that there exists a path from ``node`` to ``a``.
508+
"""
509+
if node not in self.nodes():
510+
utils.logger.error(f"{__name__}: Node {node} is not in the graph.")
511+
raise ValueError(f"Node {node} is not in the graph.")
512+
if node in self._nodes_reachable_from_node_cache:
513+
return self._nodes_reachable_from_node_cache[node]
514+
515+
C: nx.DiGraph = self._condensation
516+
mapping = C.graph["mapping"]
517+
cv = mapping[node]
518+
519+
# All SCCs reachable from cv (descendants) plus itself
520+
reachable_sccs = set(nx.descendants(C, cv)) | {cv}
521+
522+
result: Set[str] = set()
523+
for c in reachable_sccs:
524+
result |= self._nodes_by_scc.get(c, set())
525+
526+
self._nodes_reachable_from_node_cache[node] = result
527+
return result
528+
529+
def nodes_reaching(self, node: str) -> Set[str]:
530+
"""Return the set of nodes that can reach ``node`` (including itself).
531+
532+
The result is cached per query node. Reachability is computed on the SCC
533+
condensation DAG: for the SCC containing ``node``, take all SCCs that can
534+
reach it (ancestors, including itself) and return the union of original
535+
nodes lying in any of those SCCs.
536+
537+
Parameters
538+
----------
539+
node: str
540+
The node ``u`` in this graph to evaluate backward reachability to ``u``.
541+
542+
Returns
543+
-------
544+
Set[str]
545+
All nodes ``a`` such that there exists a path from ``a`` to ``node``.
546+
"""
547+
if node not in self.nodes():
548+
utils.logger.error(f"{__name__}: Node {node} is not in the graph.")
549+
raise ValueError(f"Node {node} is not in the graph.")
550+
if node in self._nodes_reaching_node_cache:
551+
return self._nodes_reaching_node_cache[node]
552+
553+
C: nx.DiGraph = self._condensation
554+
mapping = C.graph["mapping"]
555+
cu = mapping[node]
556+
557+
# All SCCs that can reach cu (ancestors) plus itself
558+
ancestor_sccs = set(nx.ancestors(C, cu)) | {cu}
559+
560+
result: Set[str] = set()
561+
for c in ancestor_sccs:
562+
result |= self._nodes_by_scc.get(c, set())
563+
564+
self._nodes_reaching_node_cache[node] = result
565+
return result
469566

0 commit comments

Comments
 (0)