Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip rule evaluation if any of the body relations is empty #62

Merged
merged 1 commit into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ascent/src/c_lat_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> RelIndexRead<'a>
Some(res)
}

fn len(&self) -> usize { self.unwrap_frozen().len() }
fn len_estimate(&self) -> usize { self.unwrap_frozen().len() }

fn is_empty(&'a self) -> bool { self.unwrap_frozen().len() == 0 }
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq + Sync> CRelIndexRead<'a> for CLatIndex<K, V> {
Expand Down
4 changes: 3 additions & 1 deletion ascent/src/c_rel_full_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelFullIndex<K,
Some(res)
}

fn len(&self) -> usize { self.unwrap_frozen().len() }
fn len_estimate(&self) -> usize { self.unwrap_frozen().len() }

fn is_empty(&'a self) -> bool { self.unwrap_frozen().len() == 0 }
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelFullIndex<K, V> {
Expand Down
8 changes: 6 additions & 2 deletions ascent/src/c_rel_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,17 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelIndex<K, V>
Some(res)
}

fn len(&self) -> usize {
// approximate len
fn len_estimate(&self) -> usize {
let sample_size = 4;
let shards = self.unwrap_frozen().shards();
let (count, sum) = shards.iter().take(sample_size).fold((0, 0), |(c, s), shard| (c + 1, s + shard.read().len()));
sum * shards.len() / count
}

fn is_empty(&'a self) -> bool {
let shards = self.unwrap_frozen().shards();
shards.iter().all(|s| s.read().is_empty())
}
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelIndex<K, V> {
Expand Down
6 changes: 4 additions & 2 deletions ascent/src/c_rel_no_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ impl<'a, V: 'a> RelIndexRead<'a> for CRelNoIndex<V> {
}

#[inline(always)]
fn len(&self) -> usize { 1 }
fn len_estimate(&self) -> usize { 1 }

fn is_empty(&'a self) -> bool { false }
}

impl<'a, V: 'a + Sync + Send> CRelIndexRead<'a> for CRelNoIndex<V> {
Expand Down Expand Up @@ -92,7 +94,7 @@ impl<'a, V: 'a> RelIndexWrite for CRelNoIndex<V> {
impl<'a, V: 'a> RelIndexMerge for CRelNoIndex<V> {
fn move_index_contents(from: &mut Self, to: &mut Self) {
let before = Instant::now();
assert_eq!(from.len(), to.len());
assert_eq!(from.len_estimate(), to.len_estimate());
// not necessary because we have a mut reference
// assert!(!from.frozen);
// assert!(!to.frozen);
Expand Down
5 changes: 4 additions & 1 deletion ascent/src/rel_index_boilerplate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ where T: RelIndexRead<'a>
fn index_get(&'a self, key: &Self::Key) -> Option<Self::IteratorType> { (**self).index_get(key) }

#[inline(always)]
fn len(&self) -> usize { (**self).len() }
fn len_estimate(&self) -> usize { (**self).len_estimate() }

#[inline(always)]
fn is_empty(&'a self) -> bool { (**self).is_empty() }
}

impl<'a, T> RelIndexReadAll<'a> for &'a T
Expand Down
27 changes: 22 additions & 5 deletions ascent/src/rel_index_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ pub trait RelIndexRead<'a> {
type Value;
type IteratorType: Iterator<Item = Self::Value> + Clone + 'a;
fn index_get(&'a self, key: &Self::Key) -> Option<Self::IteratorType>;
fn len(&'a self) -> usize;
fn len_estimate(&'a self) -> usize;

/// Is the relation **definitely** empty?
///
/// It is OK for implementations to return `false` even if the relation may be empty,
/// as this is used to enable certain optimizations.
fn is_empty(&'a self) -> bool { false }
}

pub trait RelIndexReadAll<'a> {
Expand All @@ -36,7 +42,10 @@ impl<'a, K: Eq + std::hash::Hash + 'a, V: Clone + 'a> RelIndexRead<'a> for RelIn
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }

#[inline(always)]
fn is_empty(&'a self) -> bool { Self::is_empty(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for RelIndexType1<K, V> {
Expand Down Expand Up @@ -67,7 +76,10 @@ impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for HashBrownR
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }

#[inline(always)]
fn is_empty(&'a self) -> bool { Self::is_empty(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for HashBrownRelFullIndexType<K, V> {
Expand Down Expand Up @@ -98,7 +110,9 @@ impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for LatticeInd
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }
#[inline(always)]
fn is_empty(&'a self) -> bool { Self::is_empty(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for LatticeIndexType<K, V> {
Expand Down Expand Up @@ -154,7 +168,10 @@ where
}

#[inline(always)]
fn len(&self) -> usize { self.ind1.len() + self.ind2.len() }
fn len_estimate(&self) -> usize { self.ind1.len_estimate() + self.ind2.len_estimate() }

#[inline]
fn is_empty(&'a self) -> bool { self.ind1.is_empty() && self.ind2.is_empty() }
}

// impl <'a, Ind> RelIndexRead<'a> for RelIndexCombined<'a, Ind, Ind>
Expand Down
2 changes: 1 addition & 1 deletion ascent_base/src/lattice/ord_lattice.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt::{Debug, Display, Formatter};
use std::fmt::{Debug, Formatter};

use crate::Lattice;

Expand Down
52 changes: 38 additions & 14 deletions ascent_macro/src/ascent_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,32 @@ fn compile_mir_rule(rule: &MirRule, scc: &MirScc, mir: &AscentMir) -> proc_macro
} else {
0
};
compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0)
let rule_body_clauses = rule.body_items.iter().filter_map(|bi| bi.clause()).collect_vec();
let check_any_empty_rel_can_help =
rule_body_clauses.len() > 1 && !(rule.simple_join_start_index.is_some() && rule_body_clauses.len() == 2);
let any_empty_rel_code = check_any_empty_rel_can_help
.then(|| {
rule_body_clauses
.iter()
.map(|bclause| {
let rel_expr = expr_for_rel(&bclause.rel, mir);
quote_spanned! { bclause.rel_args_span=> #rel_expr.is_empty() }
})
.reduce(|l, r| quote! { #l || #r})
})
.flatten();

let rule_compiled = compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0);
if let Some(any_empty_rel_code) = any_empty_rel_code {
quote! {
let any_rel_empty = #any_empty_rel_code;
if !any_rel_empty {
#rule_compiled
}
}
} else {
rule_compiled
}
}

fn compile_mir_rule_inner(
Expand All @@ -844,20 +869,19 @@ fn compile_mir_rule_inner(
let rule_cp2_compiled =
compile_mir_rule_inner(&rule_cp2, _scc, mir, par_iter_to_ind, head_update_code, clause_ind);

if let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind + 2] {
let rel1_var_name = expr_for_rel(&bcl1.rel, mir);
let rel2_var_name = expr_for_rel(&bcl2.rel, mir);

return quote_spanned! {bcl1.rel_args_span=>
if #rel1_var_name.len() <= #rel2_var_name.len() {
#rule_cp1_compiled
} else {
#rule_cp2_compiled
}
};
} else {
let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind + 2] else {
panic!("unexpected body items in reorderable rule")
}
};
let rel1_var_name = expr_for_rel(&bcl1.rel, mir);
let rel2_var_name = expr_for_rel(&bcl2.rel, mir);

return quote_spanned! {bcl1.rel_args_span=>
if #rel1_var_name.len_estimate() <= #rel2_var_name.len_estimate() {
#rule_cp1_compiled
} else {
#rule_cp2_compiled
}
};
}
if clause_ind < rule.body_items.len() {
let bitem = &rule.body_items[clause_ind];
Expand Down
7 changes: 7 additions & 0 deletions ascent_macro/src/ascent_mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ impl MirBodyItem {
}
}

pub fn clause(&self) -> Option<&MirBodyClause> {
match self {
MirBodyItem::Clause(mir_body_clause) => Some(mir_body_clause),
_ => None,
}
}

pub fn bound_vars(&self) -> Vec<Ident> {
match self {
MirBodyItem::Clause(cl) => {
Expand Down
4 changes: 2 additions & 2 deletions ascent_macro/src/ascent_syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,10 @@ pub(crate) struct DsAttributeContents {

impl Parse for DsAttributeContents {
fn parse(input: ParseStream) -> Result<Self> {
let path = syn::Path::parse_mod_style(&input)?;
let path = syn::Path::parse_mod_style(input)?;
let args = if input.peek(Token![:]) {
input.parse::<Token![:]>()?;
TokenStream::parse(&input)?
TokenStream::parse(input)?
} else {
TokenStream::default()
};
Expand Down
20 changes: 20 additions & 0 deletions ascent_tests/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,3 +972,23 @@ fn test_ds_attr() {

assert_rels_eq!(res.bar, [(0, 1)]);
}

#[test]
fn test_rel_empty_check() {
let res = ascent_run_m_par! {
relation edge(i32, i32);
relation path(i32, i32);
relation legit(i32);

path(x, z) <-- edge(x, y), path(y, z), legit(x);
path(x, y) <-- edge(x, y), legit(*&x);

legit(0);
legit(y) <-- legit(x), path(x, y);

edge(x, x + 1) <-- for x in 0..9;
};

println!("{:?}", res.path);
assert_eq!(res.path.len(), 9 * 10 / 2);
}
11 changes: 7 additions & 4 deletions byods/ascent-byods-rels/src/adaptor/bin_rel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0<'a, TBinRel>
Some(res)
}

fn len(&'a self) -> usize { self.0.ind0_len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.ind0_len_estimate() }
fn is_empty(&'a self) -> bool { self.0.is_empty() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0<'a, TBinRel> {
Expand Down Expand Up @@ -99,7 +100,8 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd1<'a, TBinRel>
Some(res)
}

fn len(&'a self) -> usize { self.0.ind1_len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.ind1_len_estimate() }
fn is_empty(&'a self) -> bool { self.0.is_empty() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd1<'a, TBinRel> {
Expand Down Expand Up @@ -136,7 +138,8 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0_1<'a, TBinRe
if self.0.contains(&key.0, &key.1) { Some(once(())) } else { None }
}

fn len(&'a self) -> usize { self.0.len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.len_estimate() }
fn is_empty(&'a self) -> bool { self.0.is_empty() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0_1<'a, TBinRel> {
Expand Down Expand Up @@ -200,7 +203,7 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelIndNone<'a, TBinR
Some(IteratorFromDyn::new(res))
}

fn len(&'a self) -> usize { 1 }
fn len_estimate(&'a self) -> usize { 1 }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelIndNone<'a, TBinRel> {
Expand Down
22 changes: 14 additions & 8 deletions byods/ascent-byods-rels/src/adaptor/bin_rel_to_ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ where
Some(IteratorFromDyn::new(|| trrel.iter_all()))
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 4;
let sum = self.0.map.values().map(|x| x.len_estimate()).sum::<usize>();
sum * self.0.map.len() / sample_size.min(self.0.map.len()).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -189,12 +190,13 @@ where
Some(res)
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind0_len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
sum * map_len / sample_size.min(map_len).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -248,12 +250,13 @@ where
Some(res)
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind1_len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
sum * map_len / sample_size.min(map_len).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -314,7 +317,8 @@ where

fn index_get(&'a self, (x1,): &Self::Key) -> Option<Self::IteratorType> { self.get(x1) }

fn len(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() }
fn len_estimate(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() }
fn is_empty(&'a self) -> bool { self.0.reverse_map1.as_ref().unwrap().is_empty() }
}

pub struct BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -374,7 +378,8 @@ where

fn index_get(&'a self, (x2,): &Self::Key) -> Option<Self::IteratorType> { self.get(x2) }

fn len(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() }
fn len_estimate(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() }
fn is_empty(&'a self) -> bool { self.0.reverse_map2.as_ref().unwrap().is_empty() }
}

pub struct BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -433,7 +438,7 @@ where
Some(IteratorFromDyn::new(res))
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
// TODO random estimate, could be very wrong
self.0.reverse_map1.as_ref().unwrap().len() * self.0.reverse_map2.as_ref().unwrap().len()
/ ((self.0.map.len() as f32).sqrt() as usize)
Expand Down Expand Up @@ -480,7 +485,7 @@ where
Some(IteratorFromDyn::new(res))
}

fn len(&self) -> usize { 1 }
fn len_estimate(&self) -> usize { 1 }
}

pub struct BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -544,12 +549,13 @@ where
if self.0.map.get(x0)?.contains(x1, x2) { Some(once(())) } else { None }
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|rel| rel.len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
sum * map_len / sample_size.min(map_len).max(1)
}
fn is_empty(&'a self) -> bool { self.0.map.is_empty() }
}

pub struct BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel>(&'a mut BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down
Loading
Loading