Skip to content

Commit

Permalink
Skip rule evaluation if any body relation is empty (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
s-arash authored Jan 19, 2025
1 parent 3d7330a commit 11a7f17
Show file tree
Hide file tree
Showing 20 changed files with 176 additions and 74 deletions.
4 changes: 3 additions & 1 deletion ascent/src/c_lat_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> RelIndexRead<'a>
Some(res)
}

fn len(&self) -> usize { self.unwrap_frozen().len() }
fn len_estimate(&self) -> usize { self.unwrap_frozen().len() }

fn is_empty(&'a self) -> bool { self.unwrap_frozen().len() == 0 }
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq + Sync> CRelIndexRead<'a> for CLatIndex<K, V> {
Expand Down
4 changes: 3 additions & 1 deletion ascent/src/c_rel_full_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelFullIndex<K,
Some(res)
}

fn len(&self) -> usize { self.unwrap_frozen().len() }
fn len_estimate(&self) -> usize { self.unwrap_frozen().len() }

fn is_empty(&'a self) -> bool { self.unwrap_frozen().len() == 0 }
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelFullIndex<K, V> {
Expand Down
8 changes: 6 additions & 2 deletions ascent/src/c_rel_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,17 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelIndex<K, V>
Some(res)
}

fn len(&self) -> usize {
// approximate len
fn len_estimate(&self) -> usize {
let sample_size = 4;
let shards = self.unwrap_frozen().shards();
let (count, sum) = shards.iter().take(sample_size).fold((0, 0), |(c, s), shard| (c + 1, s + shard.read().len()));
sum * shards.len() / count
}

fn is_empty(&'a self) -> bool {
let shards = self.unwrap_frozen().shards();
shards.iter().all(|s| s.read().is_empty())
}
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelIndex<K, V> {
Expand Down
6 changes: 4 additions & 2 deletions ascent/src/c_rel_no_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ impl<'a, V: 'a> RelIndexRead<'a> for CRelNoIndex<V> {
}

#[inline(always)]
fn len(&self) -> usize { 1 }
fn len_estimate(&self) -> usize { 1 }

fn is_empty(&'a self) -> bool { false }
}

impl<'a, V: 'a + Sync + Send> CRelIndexRead<'a> for CRelNoIndex<V> {
Expand Down Expand Up @@ -92,7 +94,7 @@ impl<'a, V: 'a> RelIndexWrite for CRelNoIndex<V> {
impl<'a, V: 'a> RelIndexMerge for CRelNoIndex<V> {
fn move_index_contents(from: &mut Self, to: &mut Self) {
let before = Instant::now();
assert_eq!(from.len(), to.len());
assert_eq!(from.len_estimate(), to.len_estimate());
// not necessary because we have a mut reference
// assert!(!from.frozen);
// assert!(!to.frozen);
Expand Down
5 changes: 4 additions & 1 deletion ascent/src/rel_index_boilerplate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ where T: RelIndexRead<'a>
fn index_get(&'a self, key: &Self::Key) -> Option<Self::IteratorType> { (**self).index_get(key) }

#[inline(always)]
fn len(&self) -> usize { (**self).len() }
fn len_estimate(&self) -> usize { (**self).len_estimate() }

#[inline(always)]
fn is_empty(&'a self) -> bool { (**self).is_empty() }
}

impl<'a, T> RelIndexReadAll<'a> for &'a T
Expand Down
27 changes: 22 additions & 5 deletions ascent/src/rel_index_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ pub trait RelIndexRead<'a> {
type Value;
type IteratorType: Iterator<Item = Self::Value> + Clone + 'a;
fn index_get(&'a self, key: &Self::Key) -> Option<Self::IteratorType>;
fn len(&'a self) -> usize;
fn len_estimate(&'a self) -> usize;

/// Is the relation **definitely** empty?
///
/// It is OK for implementations to return `false` even if the relation may be empty,
/// as this is used to enable certain optimizations.
fn is_empty(&'a self) -> bool { false }
}

pub trait RelIndexReadAll<'a> {
Expand All @@ -36,7 +42,10 @@ impl<'a, K: Eq + std::hash::Hash + 'a, V: Clone + 'a> RelIndexRead<'a> for RelIn
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }

#[inline(always)]
fn is_empty(&'a self) -> bool { Self::is_empty(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for RelIndexType1<K, V> {
Expand Down Expand Up @@ -67,7 +76,10 @@ impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for HashBrownR
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }

#[inline(always)]
fn is_empty(&'a self) -> bool { Self::is_empty(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for HashBrownRelFullIndexType<K, V> {
Expand Down Expand Up @@ -98,7 +110,9 @@ impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for LatticeInd
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }
#[inline(always)]
fn is_empty(&'a self) -> bool { Self::is_empty(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for LatticeIndexType<K, V> {
Expand Down Expand Up @@ -154,7 +168,10 @@ where
}

#[inline(always)]
fn len(&self) -> usize { self.ind1.len() + self.ind2.len() }
fn len_estimate(&self) -> usize { self.ind1.len_estimate() + self.ind2.len_estimate() }

#[inline]
fn is_empty(&'a self) -> bool { self.ind1.is_empty() && self.ind2.is_empty() }
}

// impl <'a, Ind> RelIndexRead<'a> for RelIndexCombined<'a, Ind, Ind>
Expand Down
2 changes: 1 addition & 1 deletion ascent_base/src/lattice/ord_lattice.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt::{Debug, Display, Formatter};
use std::fmt::{Debug, Formatter};

use crate::Lattice;

Expand Down
52 changes: 38 additions & 14 deletions ascent_macro/src/ascent_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,32 @@ fn compile_mir_rule(rule: &MirRule, scc: &MirScc, mir: &AscentMir) -> proc_macro
} else {
0
};
compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0)
let rule_body_clauses = rule.body_items.iter().filter_map(|bi| bi.clause()).collect_vec();
let check_any_empty_rel_can_help =
rule_body_clauses.len() > 1 && !(rule.simple_join_start_index.is_some() && rule_body_clauses.len() == 2);
let any_empty_rel_code = check_any_empty_rel_can_help
.then(|| {
rule_body_clauses
.iter()
.map(|bclause| {
let rel_expr = expr_for_rel(&bclause.rel, mir);
quote_spanned! { bclause.rel_args_span=> #rel_expr.is_empty() }
})
.reduce(|l, r| quote! { #l || #r})
})
.flatten();

let rule_compiled = compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0);
if let Some(any_empty_rel_code) = any_empty_rel_code {
quote! {
let any_rel_empty = #any_empty_rel_code;
if !any_rel_empty {
#rule_compiled
}
}
} else {
rule_compiled
}
}

fn compile_mir_rule_inner(
Expand All @@ -844,20 +869,19 @@ fn compile_mir_rule_inner(
let rule_cp2_compiled =
compile_mir_rule_inner(&rule_cp2, _scc, mir, par_iter_to_ind, head_update_code, clause_ind);

if let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind + 2] {
let rel1_var_name = expr_for_rel(&bcl1.rel, mir);
let rel2_var_name = expr_for_rel(&bcl2.rel, mir);

return quote_spanned! {bcl1.rel_args_span=>
if #rel1_var_name.len() <= #rel2_var_name.len() {
#rule_cp1_compiled
} else {
#rule_cp2_compiled
}
};
} else {
let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind + 2] else {
panic!("unexpected body items in reorderable rule")
}
};
let rel1_var_name = expr_for_rel(&bcl1.rel, mir);
let rel2_var_name = expr_for_rel(&bcl2.rel, mir);

return quote_spanned! {bcl1.rel_args_span=>
if #rel1_var_name.len_estimate() <= #rel2_var_name.len_estimate() {
#rule_cp1_compiled
} else {
#rule_cp2_compiled
}
};
}
if clause_ind < rule.body_items.len() {
let bitem = &rule.body_items[clause_ind];
Expand Down
7 changes: 7 additions & 0 deletions ascent_macro/src/ascent_mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ impl MirBodyItem {
}
}

pub fn clause(&self) -> Option<&MirBodyClause> {
match self {
MirBodyItem::Clause(mir_body_clause) => Some(mir_body_clause),
_ => None,
}
}

pub fn bound_vars(&self) -> Vec<Ident> {
match self {
MirBodyItem::Clause(cl) => {
Expand Down
4 changes: 2 additions & 2 deletions ascent_macro/src/ascent_syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,10 @@ pub(crate) struct DsAttributeContents {

impl Parse for DsAttributeContents {
fn parse(input: ParseStream) -> Result<Self> {
let path = syn::Path::parse_mod_style(&input)?;
let path = syn::Path::parse_mod_style(input)?;
let args = if input.peek(Token![:]) {
input.parse::<Token![:]>()?;
TokenStream::parse(&input)?
TokenStream::parse(input)?
} else {
TokenStream::default()
};
Expand Down
20 changes: 20 additions & 0 deletions ascent_tests/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,3 +972,23 @@ fn test_ds_attr() {

assert_rels_eq!(res.bar, [(0, 1)]);
}

#[test]
fn test_rel_empty_check() {
let res = ascent_run_m_par! {
relation edge(i32, i32);
relation path(i32, i32);
relation legit(i32);

path(x, z) <-- edge(x, y), path(y, z), legit(x);
path(x, y) <-- edge(x, y), legit(*&x);

legit(0);
legit(y) <-- legit(x), path(x, y);

edge(x, x + 1) <-- for x in 0..9;
};

println!("{:?}", res.path);
assert_eq!(res.path.len(), 9 * 10 / 2);
}
11 changes: 7 additions & 4 deletions byods/ascent-byods-rels/src/adaptor/bin_rel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0<'a, TBinRel>
Some(res)
}

fn len(&'a self) -> usize { self.0.ind0_len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.ind0_len_estimate() }
fn is_empty(&'a self) -> bool { self.0.is_empty() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0<'a, TBinRel> {
Expand Down Expand Up @@ -99,7 +100,8 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd1<'a, TBinRel>
Some(res)
}

fn len(&'a self) -> usize { self.0.ind1_len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.ind1_len_estimate() }
fn is_empty(&'a self) -> bool { self.0.is_empty() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd1<'a, TBinRel> {
Expand Down Expand Up @@ -136,7 +138,8 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0_1<'a, TBinRe
if self.0.contains(&key.0, &key.1) { Some(once(())) } else { None }
}

fn len(&'a self) -> usize { self.0.len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.len_estimate() }
fn is_empty(&'a self) -> bool { self.0.is_empty() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0_1<'a, TBinRel> {
Expand Down Expand Up @@ -200,7 +203,7 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelIndNone<'a, TBinR
Some(IteratorFromDyn::new(res))
}

fn len(&'a self) -> usize { 1 }
fn len_estimate(&'a self) -> usize { 1 }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelIndNone<'a, TBinRel> {
Expand Down
22 changes: 14 additions & 8 deletions byods/ascent-byods-rels/src/adaptor/bin_rel_to_ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ where
Some(IteratorFromDyn::new(|| trrel.iter_all()))
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 4;
let sum = self.0.map.values().map(|x| x.len_estimate()).sum::<usize>();
sum * self.0.map.len() / sample_size.min(self.0.map.len()).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -189,12 +190,13 @@ where
Some(res)
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind0_len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
sum * map_len / sample_size.min(map_len).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -248,12 +250,13 @@ where
Some(res)
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind1_len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
sum * map_len / sample_size.min(map_len).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -314,7 +317,8 @@ where

fn index_get(&'a self, (x1,): &Self::Key) -> Option<Self::IteratorType> { self.get(x1) }

fn len(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() }
fn len_estimate(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() }
fn is_empty(&'a self) -> bool { self.0.reverse_map1.as_ref().unwrap().is_empty() }
}

pub struct BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -374,7 +378,8 @@ where

fn index_get(&'a self, (x2,): &Self::Key) -> Option<Self::IteratorType> { self.get(x2) }

fn len(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() }
fn len_estimate(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() }
fn is_empty(&'a self) -> bool { self.0.reverse_map2.as_ref().unwrap().is_empty() }
}

pub struct BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -433,7 +438,7 @@ where
Some(IteratorFromDyn::new(res))
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
// TODO random estimate, could be very wrong
self.0.reverse_map1.as_ref().unwrap().len() * self.0.reverse_map2.as_ref().unwrap().len()
/ ((self.0.map.len() as f32).sqrt() as usize)
Expand Down Expand Up @@ -480,7 +485,7 @@ where
Some(IteratorFromDyn::new(res))
}

fn len(&self) -> usize { 1 }
fn len_estimate(&self) -> usize { 1 }
}

pub struct BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -544,12 +549,13 @@ where
if self.0.map.get(x0)?.contains(x1, x2) { Some(once(())) } else { None }
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|rel| rel.len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
sum * map_len / sample_size.min(map_len).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel>(&'a mut BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down
Loading

0 comments on commit 11a7f17

Please sign in to comment.