Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test code #31

Merged
merged 18 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ indexmap = "2.0.0"
log = "0.4.19"
ordered-float = "3"
pico-args = { version = "0.5.0", features = ["eq-separator"] }
rand = "0.8.5"
walkdir = "2.4.0"

anyhow = "1.0.71"
coin_cbc = { version = "0.1.6", optional = true }
Expand Down
70 changes: 70 additions & 0 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ pub mod greedy_dag;
#[cfg(feature = "ilp-cbc")]
pub mod ilp_cbc;

// Allowance for floating point values to be considered equal
pub const EPSILON_ALLOWANCE: f64 = 0.00001;

pub trait Extractor: Sync {
fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult;

Expand Down Expand Up @@ -198,3 +201,70 @@ impl ExtractionResult {
.sum::<Cost>()
}
}

use ordered_float::NotNan;
use rand::Rng;

// generates a float between 0 and 1
fn generate_random_not_nan() -> NotNan<f64> {
let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
let random_float: f64 = rng.gen();
NotNan::new(random_float).unwrap()
}

//make a random egraph
pub fn generate_random_egraph() -> EGraph {
let mut rng = rand::thread_rng();
let mut egraph = EGraph::default();
let mut nodes = Vec::<Node>::default();
let mut eclass = 0;

let mut n2nid = IndexMap::<Node, NodeId>::default();
let mut count = 0;

for _ in 0..rng.gen_range(1..100) {
let mut children = Vec::<NodeId>::default();
for node in &nodes {
if rng.gen_bool(0.1) {
children.push(n2nid.get(node).unwrap().clone());
}
}

if rng.gen_bool(0.2) {
eclass += 1;
}

let node = Node {
op: "operation".to_string(),
children: children,
eclass: eclass.to_string().clone().into(),
cost: (generate_random_not_nan() * 100.0),
};

nodes.push(node.clone());
let id = "node_".to_owned() + &count.to_string();
count += 1;
egraph.add_node(id.clone(), node.clone());
n2nid.insert(node.clone(), id.clone().into());
}

//I've not seen this generate an infeasible egraph, and don't undertand why.
let len = nodes.len();
for n in &mut nodes {
if rng.gen_bool(0.5) {
n.children.push(n2nid[rng.gen_range(0..len)].clone());
}
}

// Get roots, potentially duplicate.
for _ in 1..rng.gen_range(2..11) {
egraph.root_eclasses.push(
nodes
.get(rng.gen_range(0..nodes.len()))
.unwrap()
.eclass
.clone(),
);
}
egraph
}
230 changes: 218 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,73 @@ use std::path::PathBuf;
pub type Cost = NotNan<f64>;
pub const INFINITY: Cost = unsafe { NotNan::new_unchecked(std::f64::INFINITY) };

fn main() {
env_logger::init();
struct ExtractorDetail {
extractor: Box<dyn Extractor>,
is_dag_optimal: bool,
is_tree_optimal: bool,
}

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
fn extractors() -> IndexMap<&'static str, ExtractorDetail> {
let extractors: IndexMap<&'static str, ExtractorDetail> = [
(
"faster-bottom-up",
extract::faster_bottom_up::FasterBottomUpExtractor.boxed(),
"bottom-up",
ExtractorDetail {
extractor: extract::bottom_up::BottomUpExtractor.boxed(),
is_dag_optimal: false,
is_tree_optimal: true,
},
),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
"faster-bottom-up",
ExtractorDetail {
extractor: extract::faster_bottom_up::FasterBottomUpExtractor.boxed(),
is_dag_optimal: false,
is_tree_optimal: true,
},
),
(
"faster-greedy-dag",
extract::faster_greedy_dag::FasterGreedyDagExtractor.boxed(),
ExtractorDetail {
extractor: extract::faster_greedy_dag::FasterGreedyDagExtractor.boxed(),
is_dag_optimal: false,
is_tree_optimal: false,
},
),
(
"greedy-dag",
ExtractorDetail {
extractor: extract::greedy_dag::GreedyDagExtractor.boxed(),
is_dag_optimal: false,
is_tree_optimal: false,
},
),
/*(
"global-greedy-dag",
extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(),
ExtractorDetail {
extractor: extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(),
is_dag_optimal: false,
is_tree_optimal: false,
},
),*/
#[cfg(feature = "ilp-cbc")]
(
"ilp-cbc",
ExtractorDetail {
extractor: extract::ilp_cbc::CbcExtractor.boxed(),
is_dag_optimal: true,
is_tree_optimal: false,
},
),
]
.into_iter()
.collect();
return extractors;
}

fn main() {
env_logger::init();

let extractors = extractors();

let mut args = pico_args::Arguments::from_env();

Expand Down Expand Up @@ -71,13 +114,13 @@ fn main() {
.with_context(|| format!("Failed to parse {filename}"))
.unwrap();

let extractor = extractors
let ed = extractors
.get(extractor_name.as_str())
.with_context(|| format!("Unknown extractor: {extractor_name}"))
.unwrap();

let start_time = std::time::Instant::now();
let result = extractor.extract(&egraph, &egraph.root_eclasses);
let result = ed.extractor.extract(&egraph, &egraph.root_eclasses);
let us = start_time.elapsed().as_micros();

result.check(&egraph);
Expand All @@ -98,3 +141,166 @@ fn main() {
)
.unwrap();
}

/*
* Checks that no extractors produce better results than the extractors that produce optimal results.
* Checks that the extractions are valid.
*/

fn check_optimal_results() {
let optimal_dag_found = extractors().into_iter().any(|(_, ed)| ed.is_dag_optimal);

let iterations = if optimal_dag_found { 100 } else { 10000 };

let egraphs = (0..iterations).map(|_| generate_random_egraph());

check_optimal_results2(egraphs);
}

fn check_optimal_results2<I: Iterator<Item = EGraph>>(egraphs: I) {
let optimal_dag: Vec<Box<dyn Extractor>> = extractors()
.into_iter()
.filter(|(_, ed)| ed.is_dag_optimal)
.map(|(_, ed)| ed.extractor)
.collect();

let optimal_tree: Vec<Box<dyn Extractor>> = extractors()
.into_iter()
.filter(|(_, ed)| ed.is_tree_optimal)
.map(|(_, ed)| ed.extractor)
.collect();

let others: Vec<Box<dyn Extractor>> = extractors()
.into_iter()
.filter(|(_, ed)| !ed.is_dag_optimal || !ed.is_tree_optimal)
.map(|(_, ed)| ed.extractor)
.collect();

let mut count = 0;
for egraph in egraphs {
count += 1;
println!("{count}");

let mut optimal_dag_cost: Option<Cost> = None;

for e in &optimal_dag {
let extract = e.extract(&egraph, &egraph.root_eclasses);
extract.check(&egraph);
let dag_cost = extract.dag_cost(&egraph, &egraph.root_eclasses);
let tree_cost = extract.tree_cost(&egraph, &egraph.root_eclasses);
if optimal_dag_cost.is_none() {
optimal_dag_cost = Some(dag_cost);
continue;
}

assert!(
(dag_cost.into_inner() - optimal_dag_cost.unwrap().into_inner()).abs()
< EPSILON_ALLOWANCE
);

assert!(
tree_cost.into_inner() + EPSILON_ALLOWANCE > optimal_dag_cost.unwrap().into_inner()
);
}

let mut optimal_tree_cost: Option<Cost> = None;

for e in &optimal_tree {
let extract = e.extract(&egraph, &egraph.root_eclasses);
extract.check(&egraph);
let tree_cost = extract.tree_cost(&egraph, &egraph.root_eclasses);
if optimal_tree_cost.is_none() {
optimal_tree_cost = Some(tree_cost);
continue;
}

assert!(
(tree_cost.into_inner() - optimal_tree_cost.unwrap().into_inner()).abs()
< EPSILON_ALLOWANCE
);
}

if optimal_dag_cost.is_some() {
assert!(optimal_dag_cost.unwrap() < optimal_tree_cost.unwrap() + EPSILON_ALLOWANCE);
}

for e in &others {
let extract = e.extract(&egraph, &egraph.root_eclasses);
extract.check(&egraph);
let tree_cost = extract.tree_cost(&egraph, &egraph.root_eclasses);
let dag_cost = extract.dag_cost(&egraph, &egraph.root_eclasses);

// The optimal tree cost should be <= any extractor's tree cost.
assert!(optimal_tree_cost.unwrap() <= tree_cost + EPSILON_ALLOWANCE);

if optimal_dag_cost.is_some() {
// The optimal dag should be less <= any extractor's dag cost
assert!(optimal_dag_cost.unwrap() <= dag_cost + EPSILON_ALLOWANCE);
}
}
}
}

// Run on all the .json files in the data directory
#[test]
#[ignore = "too slow to run all the time"]
fn run_files() {
use walkdir::WalkDir;

let egraphs = WalkDir::new("./data")
.into_iter()
.filter_map(Result::ok)
.filter(|e| {
e.file_type().is_file()
&& e.path().extension().and_then(std::ffi::OsStr::to_str) == Some("json")
})
.map(|e| e.path().to_string_lossy().into_owned())
.map(|e| EGraph::from_json_file(&e).unwrap());
check_optimal_results2(egraphs);
}

#[test]
#[should_panic]
fn check_assert_enabled() {
assert!(false);
}

// Make several identical functions so they'll be run in parallel
#[test]
fn check0() {
check_optimal_results();
}

#[test]
fn check1() {
check_optimal_results();
}

#[test]
fn check2() {
check_optimal_results();
}

#[test]
fn check3() {
check_optimal_results();
}

#[test]
fn check4() {
check_optimal_results();
}
#[test]
fn check5() {
check_optimal_results();
}

#[test]
fn check6() {
check_optimal_results();
}

#[test]
fn check7() {
check_optimal_results();
}