diff --git a/Cargo.lock b/Cargo.lock
index 46f2e35..707fbcd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -93,6 +93,7 @@ dependencies = [
"ordered-float",
"pico-args",
"rpds",
+ "rustc-hash",
]
[[package]]
@@ -237,6 +238,12 @@ dependencies = [
"archery",
]
+[[package]]
+name = "rustc-hash"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+
[[package]]
name = "ryu"
version = "1.0.14"
diff --git a/src/extract/ilp_cbc.rs b/src/extract/ilp_cbc.rs
index c48b04b..82522bd 100644
--- a/src/extract/ilp_cbc.rs
+++ b/src/extract/ilp_cbc.rs
@@ -1,267 +1,208 @@
-use core::panic;
+/* An ILP extractor that returns the optimal DAG-extraction.
+
+This extractor is simple so that it's easy to see that it's correct.
+
+If the timeout is reached, it will return the result of the faster-greedy-dag extractor.
+*/
use super::*;
use coin_cbc::{Col, Model, Sense};
use indexmap::IndexSet;
-const INITIALISE_WITH_BOTTOM_UP: bool = false;
-
struct ClassVars {
active: Col,
nodes: Vec
,
}
-pub struct CbcExtractor;
+pub struct CbcExtractorWithTimeout;
-impl Extractor for CbcExtractor {
+impl Extractor for CbcExtractorWithTimeout {
fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult {
- let mut model = Model::default();
-
- let true_literal = model.add_binary();
- model.set_col_lower(true_literal, 1.0);
-
- let vars: IndexMap = egraph
- .classes()
- .values()
- .map(|class| {
- let cvars = ClassVars {
- active: if roots.contains(&class.id) {
- // Roots must be active.
- true_literal
- } else {
- model.add_binary()
- },
- nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
- };
- (class.id.clone(), cvars)
- })
- .collect();
-
- for (class_id, class) in &vars {
- // class active == some node active
- // sum(for node_active in class) == class_active
- let row = model.add_row();
- model.set_row_equal(row, 0.0);
- model.set_weight(row, class.active, -1.0);
- for &node_active in &class.nodes {
- model.set_weight(row, node_active, 1.0);
- }
+ return extract(egraph, roots, TIMEOUT_IN_SECONDS);
+ }
+}
- let childrens_classes_var = |nid: NodeId| {
- egraph[&nid]
- .children
- .iter()
- .map(|n| egraph[n].eclass.clone())
- .map(|n| vars[&n].active)
- .collect::>()
- };
+pub struct CbcExtractor;
- let mut intersection: IndexSet =
- childrens_classes_var(egraph[class_id].nodes[0].clone());
+impl Extractor for CbcExtractor {
+ fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult {
+ return extract(egraph, roots, std::u32::MAX);
+ }
+}
- for node in &egraph[class_id].nodes[1..] {
- intersection = intersection
- .intersection(&childrens_classes_var(node.clone()))
- .cloned()
- .collect();
- }
+fn extract(egraph: &EGraph, roots: &[ClassId], timeout_seconds: u32) -> ExtractionResult {
+ let mut model = Model::default();
- // A class being active implies that all in the intersection
- // of it's children are too.
- for c in &intersection {
- let row = model.add_row();
- model.set_row_upper(row, 0.0);
- model.set_weight(row, class.active, 1.0);
- model.set_weight(row, *c, -1.0);
- }
+ model.set_parameter("seconds", &timeout_seconds.to_string());
- for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) {
- for child_active in childrens_classes_var(node_id.clone()) {
- // node active implies child active, encoded as:
- // node_active <= child_active
- // node_active - child_active <= 0
- if !intersection.contains(&child_active) {
- let row = model.add_row();
- model.set_row_upper(row, 0.0);
- model.set_weight(row, node_active, 1.0);
- model.set_weight(row, child_active, -1.0);
- }
- }
- }
+ let vars: IndexMap = egraph
+ .classes()
+ .values()
+ .map(|class| {
+ let cvars = ClassVars {
+ active: model.add_binary(),
+ nodes: class.nodes.iter().map(|_| model.add_binary()).collect(),
+ };
+ (class.id.clone(), cvars)
+ })
+ .collect();
+
+ for (class_id, class) in &vars {
+ // class active == some node active
+ // sum(for node_active in class) == class_active
+ let row = model.add_row();
+ model.set_row_equal(row, 0.0);
+ model.set_weight(row, class.active, -1.0);
+ for &node_active in &class.nodes {
+ model.set_weight(row, node_active, 1.0);
}
- model.set_obj_sense(Sense::Minimize);
- for class in egraph.classes().values() {
- let min_cost = class
- .nodes
+ let childrens_classes_var = |nid: NodeId| {
+ egraph[&nid]
+ .children
.iter()
- .map(|n_id| egraph[n_id].cost)
- .min()
- .unwrap_or(Cost::default())
- .into_inner();
-
- // Most helpful when the members of the class all have the same cost.
- // For example if the members' costs are [1,1,1], three terms get
- // replaced by one in the objective function.
- if min_cost != 0.0 {
- model.set_obj_coeff(vars[&class.id].active, min_cost);
- }
-
- for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
- let node = &egraph[node_id];
- let node_cost = node.cost.into_inner() - min_cost;
- assert!(node_cost >= 0.0);
-
- if node_cost != 0.0 {
- model.set_obj_coeff(node_active, node_cost);
- }
+ .map(|n| egraph[n].eclass.clone())
+ .map(|n| vars[&n].active)
+ .collect::>()
+ };
+
+ for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) {
+ for child_active in childrens_classes_var(node_id.clone()) {
+ // node active implies child active, encoded as:
+ // node_active <= child_active
+ // node_active - child_active <= 0
+ let row = model.add_row();
+ model.set_row_upper(row, 0.0);
+ model.set_weight(row, node_active, 1.0);
+ model.set_weight(row, child_active, -1.0);
}
}
+ }
- // set initial solution based on bottom up extractor
- if INITIALISE_WITH_BOTTOM_UP {
- let initial_result = super::bottom_up::BottomUpExtractor.extract(egraph, roots);
- for (class, class_vars) in egraph.classes().values().zip(vars.values()) {
- if let Some(node_id) = initial_result.choices.get(&class.id) {
- model.set_col_initial_solution(class_vars.active, 1.0);
- for col in &class_vars.nodes {
- model.set_col_initial_solution(*col, 0.0);
- }
- let node_idx = class.nodes.iter().position(|n| n == node_id).unwrap();
- model.set_col_initial_solution(class_vars.nodes[node_idx], 1.0);
- } else {
- model.set_col_initial_solution(class_vars.active, 0.0);
- }
- }
- }
+ model.set_obj_sense(Sense::Minimize);
+ for class in egraph.classes().values() {
+ for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) {
+ let node = &egraph[node_id];
+ let node_cost = node.cost.into_inner();
+ assert!(node_cost >= 0.0);
- let mut banned_cycles: IndexSet<(ClassId, usize)> = Default::default();
- find_cycles(egraph, |id, i| {
- banned_cycles.insert((id, i));
- });
- for (class_id, class_vars) in &vars {
- for (i, &node_active) in class_vars.nodes.iter().enumerate() {
- if banned_cycles.contains(&(class_id.clone(), i)) {
- model.set_col_upper(node_active, 0.0);
- model.set_col_lower(node_active, 0.0);
- }
- }
- }
- log::info!("@blocked {}", banned_cycles.len());
-
- let solution = model.solve();
- log::info!(
- "CBC status {:?}, {:?}, obj = {}",
- solution.raw().status(),
- solution.raw().secondary_status(),
- solution.raw().obj_value(),
- );
-
- let mut result = ExtractionResult::default();
-
- for (id, var) in &vars {
- let active = solution.col(var.active) > 0.0;
- if active {
- let node_idx = var
- .nodes
- .iter()
- .position(|&n| solution.col(n) > 0.0)
- .unwrap();
- let node_id = egraph[id].nodes[node_idx].clone();
- result.choose(id.clone(), node_id);
+ if node_cost != 0.0 {
+ model.set_obj_coeff(node_active, node_cost);
}
}
-
- let cycles = result.find_cycles(egraph, roots);
- assert!(cycles.is_empty());
- return result;
}
-}
-// from @khaki3
-// fixes bug in egg 0.9.4's version
-// https://github.com/egraphs-good/egg/issues/207#issuecomment-1264737441
-fn find_cycles(egraph: &EGraph, mut f: impl FnMut(ClassId, usize)) {
- let mut pending: IndexMap> = IndexMap::default();
+ for root in roots {
+ model.set_col_lower(vars[root].active, 1.0);
+ }
- let mut order: IndexMap = IndexMap::default();
+ block_cycles(&mut model, &vars, &egraph);
- let mut memo: IndexMap<(ClassId, usize), bool> = IndexMap::default();
+ let solution = model.solve();
+ log::info!(
+ "CBC status {:?}, {:?}, obj = {}",
+ solution.raw().status(),
+ solution.raw().secondary_status(),
+ solution.raw().obj_value(),
+ );
- let mut stack: Vec<(ClassId, usize)> = vec![];
+ if solution.raw().status() != coin_cbc::raw::Status::Finished {
+ assert!(timeout_seconds != std::u32::MAX);
- let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
+ let initial_result =
+ super::faster_greedy_dag::FasterGreedyDagExtractor.extract(egraph, roots);
+ log::info!("Unfinished CBC solution");
+ return initial_result;
+ }
- for class in egraph.classes().values() {
- let id = &class.id;
- for (i, node_id) in egraph[id].nodes.iter().enumerate() {
- let node = &egraph[node_id];
- for child in &node.children {
- let child = n2c(child).clone();
- pending
- .entry(child)
- .or_insert_with(Vec::new)
- .push((id.clone(), i));
- }
+ let mut result = ExtractionResult::default();
- if node.is_leaf() {
- stack.push((id.clone(), i));
- }
+ for (id, var) in &vars {
+ let active = solution.col(var.active) > 0.0;
+ if active {
+ let node_idx = var
+ .nodes
+ .iter()
+ .position(|&n| solution.col(n) > 0.0)
+ .unwrap();
+ let node_id = egraph[id].nodes[node_idx].clone();
+ result.choose(id.clone(), node_id);
}
}
- let mut count = 0;
-
- while let Some((id, i)) = stack.pop() {
- if memo.get(&(id.clone(), i)).is_some() {
- continue;
- }
+ return result;
+}
- let node_id = &egraph[&id].nodes[i];
- let node = &egraph[node_id];
- let mut update = false;
-
- if node.is_leaf() {
- update = true;
- } else if node.children.iter().all(|x| order.get(n2c(x)).is_some()) {
- if let Some(ord) = order.get(&id) {
- update = node.children.iter().all(|x| &order[n2c(x)] < ord);
- if !update {
- memo.insert((id, i), false);
- continue;
- }
- } else {
- update = true;
- }
- }
+/*
+
+ To block cycles, we enforce that a topological ordering exists on the extraction.
+ Each class is mapped to a variable (called its level). Then for each node,
+ we add a constraint that if a node is active, then the level of the class the node
+ belongs to must be less than the level of each of the node's children.
+
+ A cycle would need the levels to decrease somewhere, so cycles are blocked. For example,
+ given a two-class cycle: if class A has level 'l' and class B has level 'm', then
+ 'l' must be less than 'm', but because there is also an active node in class B that
+ has class A as a child, 'm' must be less than 'l', which is a contradiction.
+*/
+
+fn block_cycles(model: &mut Model, vars: &IndexMap, egraph: &EGraph) {
+ let mut levels: IndexMap = Default::default();
+ for c in vars.keys() {
+ let var = model.add_col();
+ levels.insert(c.clone(), var);
+ //model.set_col_lower(var, 0.0);
+ // It solves the benchmarks about 5% faster without this
+ //model.set_col_upper(var, vars.len() as f64);
+ }
- if update {
- if order.get(&id).is_none() {
- if egraph[node_id].is_leaf() {
- order.insert(id.clone(), 0);
- } else {
- order.insert(id.clone(), count);
- count += 1;
- }
- }
- memo.insert((id.clone(), i), true);
- if let Some(mut v) = pending.remove(&id) {
- stack.append(&mut v);
- stack.sort();
- stack.dedup();
- };
+ // If n.variable is true, opposite_col will be false and vice versa.
+ let mut opposite: IndexMap = Default::default();
+ for c in vars.values() {
+ for n in &c.nodes {
+ let opposite_col = model.add_binary();
+ opposite.insert(*n, opposite_col);
+ let row = model.add_row();
+ model.set_row_equal(row, 1.0);
+ model.set_weight(row, opposite_col, 1.0);
+ model.set_weight(row, *n, 1.0);
}
}
- for class in egraph.classes().values() {
- let id = &class.id;
- for (i, node) in class.nodes.iter().enumerate() {
- if let Some(true) = memo.get(&(id.clone(), i)) {
+ for (class_id, c) in vars {
+ for i in 0..c.nodes.len() {
+ let n_id = &egraph[class_id].nodes[i];
+ let n = &egraph[n_id];
+ let var = c.nodes[i];
+
+ let children_classes = n
+ .children
+ .iter()
+ .map(|n| egraph[n].eclass.clone())
+ .collect::>();
+
+ if children_classes.contains(class_id) {
+ // Self loop - disable this node.
+ // This is clumsier than calling set_col_upper(var, 0.0),
+ // but means it'll be infeasible (rather than producing an
+ // incorrect solution) if var corresponds to a root node.
+ let row = model.add_row();
+ model.set_weight(row, var, 1.0);
+ model.set_row_equal(row, 0.0);
continue;
}
- assert!(!egraph[node].is_leaf());
- f(id.clone(), i);
+
+ for cc in children_classes {
+ assert!(*levels.get(class_id).unwrap() != *levels.get(&cc).unwrap());
+
+ let row = model.add_row();
+ model.set_row_lower(row, 1.0);
+ model.set_weight(row, *levels.get(class_id).unwrap(), -1.0);
+ model.set_weight(row, *levels.get(&cc).unwrap(), 1.0);
+
+ // If n.variable is 0, then disable the constraint.
+ model.set_weight(row, *opposite.get(&var).unwrap(), (vars.len() + 1) as f64);
+ }
}
}
- assert!(pending.is_empty());
}
diff --git a/src/main.rs b/src/main.rs
index bddc552..fab1624 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -36,6 +36,11 @@ fn main() {
"global-greedy-dag",
extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(),
),
+ #[cfg(feature = "ilp-cbc")]
+ (
+ "ilp-cbc-timeout",
+ extract::ilp_cbc::CbcExtractorWithTimeout::<10>.boxed(),
+ ),
]
.into_iter()
.collect();