Skip to content

Commit

Permalink
refactor(core): add predicate into relnode (#219)
Browse files Browse the repository at this point in the history
Extracted from #193, this patch adds
predicates to the RelNode and refactors the memo table to store the
predicates. Minimal changes are made to the df-repr to ensure it still
works.

Follow-ups:
* Rename RelNode -> PlanNode
* Rename RelNodeTyp -> PlanNodeTyp
* Refactor RelNode to store RelNodeOrGroup
* Remove `data` from RelNode

Signed-off-by: Alex Chi <iskyzh@gmail.com>
Co-authored-by: Benjamin O <jeep70cp@gmail.com>
  • Loading branch information
skyzh and jurplel authored Nov 5, 2024
1 parent c8e4765 commit 7045f09
Show file tree
Hide file tree
Showing 17 changed files with 181 additions and 4 deletions.
87 changes: 84 additions & 3 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ use tracing::trace;
use crate::{
cost::{Cost, Statistics},
property::PropertyBuilderAny,
rel_node::{RelNode, RelNodeRef, RelNodeTyp, Value},
rel_node::{ArcPredNode, RelNode, RelNodeRef, RelNodeTyp, Value},
};

use super::optimizer::{ExprId, GroupId};
use super::optimizer::{ExprId, GroupId, PredId};

pub type RelMemoNodeRef<T> = Arc<RelMemoNode<T>>;

Expand All @@ -24,6 +24,7 @@ pub struct RelMemoNode<T: RelNodeTyp> {
pub typ: T,
pub children: Vec<GroupId>,
pub data: Option<Value>,
pub predicates: Vec<PredId>,
}

impl<T: RelNodeTyp> RelMemoNode<T> {
Expand All @@ -36,6 +37,7 @@ impl<T: RelNodeTyp> RelMemoNode<T> {
.map(|x| Arc::new(RelNode::new_group(x)))
.collect(),
data: self.data,
predicates: Vec::new(), /* TODO: refactor */
}
}
}
Expand All @@ -49,6 +51,9 @@ impl<T: RelNodeTyp> std::fmt::Display for RelMemoNode<T> {
for child in &self.children {
write!(f, " {}", child)?;
}
for pred in &self.predicates {
write!(f, " {}", pred)?;
}
write!(f, ")")
}
}
Expand Down Expand Up @@ -114,6 +119,9 @@ pub trait Memo<T: RelNodeTyp>: 'static + Send + Sync {
/// it will add the expression to the group. Returns the expr id if the expression is not a group.
fn add_expr_to_group(&mut self, rel_node: RelNodeRef<T>, group_id: GroupId) -> Option<ExprId>;

/// Add a new predicate into the memo table.
fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId;

/// Get the group id of an expression.
/// The group id is volatile, depending on whether the groups are merged.
fn get_group_id(&self, expr_id: ExprId) -> GroupId;
Expand All @@ -127,6 +135,9 @@ pub trait Memo<T: RelNodeTyp>: 'static + Send + Sync {
/// Get a group by ID
fn get_group(&self, group_id: GroupId) -> &Group;

/// Get a predicate by ID
fn get_pred(&self, pred_id: PredId) -> ArcPredNode<T>;

/// Update the group info.
fn update_group_info(&mut self, group_id: GroupId, group_info: GroupInfo);

Expand Down Expand Up @@ -159,7 +170,10 @@ pub trait Memo<T: RelNodeTyp>: 'static + Send + Sync {
get_best_group_binding_inner(self, group_id, &mut post_process)
}

/// Get all bindings of a predicate group. Will panic if the group contains more than one bindings.
/// Get all bindings of a predicate group. Will panic if the group contains more than one binding. Note that we
/// are currently in the process of refactoring predicates into a separate entity. If the representation stores
/// predicates in the rel node children, the repr should use this function to get the predicate binding. Otherwise,
/// use `get_pred` for those predicates stored within the `predicates` field.
fn get_predicate_binding(&self, group_id: GroupId) -> Option<RelNodeRef<T>> {
get_predicate_binding_group_inner(self, group_id, true)
}
Expand Down Expand Up @@ -189,6 +203,7 @@ fn get_best_group_binding_inner<M: Memo<T> + ?Sized, T: RelNodeTyp>(
typ: expr.typ.clone(),
children,
data: expr.data.clone(),
predicates: expr.predicates.iter().map(|x| this.get_pred(*x)).collect(),
});
post_process(node.clone(), group_id, info);
return Ok(node);
Expand All @@ -215,6 +230,7 @@ fn get_predicate_binding_expr_inner<M: Memo<T> + ?Sized, T: RelNodeTyp>(
typ: expr.typ.clone(),
data: expr.data.clone(),
children,
predicates: expr.predicates.iter().map(|x| this.get_pred(*x)).collect(),
}))
}

Expand Down Expand Up @@ -247,6 +263,10 @@ pub struct NaiveMemo<T: RelNodeTyp> {
groups: HashMap<GroupId, Group>,
expr_id_to_expr_node: HashMap<ExprId, RelMemoNodeRef<T>>,

// Predicate stuff. We don't find logically equivalent predicates. Duplicate predicates
// will have different IDs.
pred_id_to_pred_node: HashMap<PredId, ArcPredNode<T>>,

// Internal states.
group_expr_counter: usize,
property_builders: Arc<[Box<dyn PropertyBuilderAny<T>>]>,
Expand Down Expand Up @@ -287,6 +307,16 @@ impl<T: RelNodeTyp> Memo<T> for NaiveMemo<T> {
Some(expr_id)
}

/// Stores `pred_node` in the memo table under a newly allocated id and returns that id.
///
/// No deduplication is attempted: every call allocates a fresh `PredId`.
fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId {
    let fresh_id = self.next_pred_id();
    let table = &mut self.pred_id_to_pred_node;
    table.insert(fresh_id, pred_node);
    fresh_id
}

/// Get a predicate by ID.
///
/// # Panics
///
/// Panics if `pred_id` has no entry in `pred_id_to_pred_node`; the panic message names the
/// missing id.
fn get_pred(&self, pred_id: PredId) -> ArcPredNode<T> {
    match self.pred_id_to_pred_node.get(&pred_id) {
        Some(pred_node) => pred_node.clone(),
        // An unknown id is an invariant violation; report which id was requested instead of
        // relying on the generic map-indexing panic.
        None => panic!("predicate {} not found in memo table", pred_id),
    }
}

fn get_group_id(&self, mut expr_id: ExprId) -> GroupId {
while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) {
expr_id = *new_expr_id;
Expand Down Expand Up @@ -346,6 +376,7 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
expr_id_to_group_id: HashMap::new(),
expr_id_to_expr_node: HashMap::new(),
expr_node_to_expr_id: HashMap::new(),
pred_id_to_pred_node: HashMap::new(),
groups: HashMap::new(),
group_expr_counter: 0,
merged_group_mapping: HashMap::new(),
Expand All @@ -368,6 +399,13 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
ExprId(id)
}

/// Allocate the next pred id. Pred ids are drawn from `group_expr_counter`, the same counter
/// used for group ids and expr ids, so as to make it easier to debug...
fn next_pred_id(&mut self) -> PredId {
    let allocated = PredId(self.group_expr_counter);
    self.group_expr_counter += 1;
    allocated
}

fn verify_integrity(&self) {
if cfg!(debug_assertions) {
let num_of_exprs = self.expr_id_to_expr_node.len();
Expand Down Expand Up @@ -507,6 +545,11 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
typ: rel_node.typ.clone(),
children: children_group_ids,
data: rel_node.data.clone(),
predicates: rel_node
.predicates
.iter()
.map(|x| self.add_new_pred(x.clone()))
.collect(),
};
if let Some(&expr_id) = self.expr_node_to_expr_id.get(&memo_node) {
let group_id = self.expr_id_to_group_id[&expr_id];
Expand Down Expand Up @@ -550,6 +593,7 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
typ: rel_node.typ.clone(),
children: children_group_ids,
data: rel_node.data.clone(),
predicates: Vec::new(), /* TODO: refactor */
};
let Some(&expr_id) = self.expr_node_to_expr_id.get(&memo_node) else {
unreachable!("not found {}", memo_node)
Expand Down Expand Up @@ -617,6 +661,8 @@ impl<T: RelNodeTyp> NaiveMemo<T> {

#[cfg(test)]
mod tests {
use crate::rel_node::PredNode;

use super::*;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand All @@ -629,6 +675,12 @@ mod tests {
Expr,
}

/// Predicate node types used by the memo tests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum MemoTestPredTyp {
Add,
Minus,
}

impl std::fmt::Display for MemoTestRelTyp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -638,7 +690,15 @@ mod tests {
}
}

// Display output is exactly the `Debug` rendering of the variant.
impl std::fmt::Display for MemoTestPredTyp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}

impl RelNodeTyp for MemoTestRelTyp {
type PredType = MemoTestPredTyp;

fn is_logical(&self) -> bool {
matches!(self, Self::Project | Self::Scan | Self::Join)
}
Expand Down Expand Up @@ -672,6 +732,7 @@ mod tests {
typ: MemoTestRelTyp::Join,
children: vec![left.into(), right.into(), cond.into()],
data: None,
predicates: Vec::new(), /* TODO: refactor */
}
}

Expand All @@ -680,6 +741,7 @@ mod tests {
typ: MemoTestRelTyp::Scan,
children: vec![],
data: Some(Value::String(table.to_string().into())),
predicates: Vec::new(), /* TODO: refactor */
}
}

Expand All @@ -691,6 +753,7 @@ mod tests {
typ: MemoTestRelTyp::Project,
children: vec![input.into(), expr_list.into()],
data: None,
predicates: Vec::new(), /* TODO: refactor */
}
}

Expand All @@ -699,6 +762,7 @@ mod tests {
typ: MemoTestRelTyp::List,
children: items.into_iter().map(|x| x.into()).collect(),
data: None,
predicates: Vec::new(), /* TODO: refactor */
}
}

Expand All @@ -707,6 +771,7 @@ mod tests {
typ: MemoTestRelTyp::Expr,
children: vec![],
data: Some(data),
predicates: Vec::new(), /* TODO: refactor */
}
}

Expand All @@ -715,9 +780,25 @@ mod tests {
typ: MemoTestRelTyp::Group(group_id),
children: vec![],
data: None,
predicates: Vec::new(), /* TODO: refactor */
}
}

/// Adding a predicate must return an id that resolves back to the very same node, and adding
/// the same node twice must yield two distinct ids (the memo does not deduplicate predicates).
#[test]
fn add_predicate() {
    let mut memo = NaiveMemo::<MemoTestRelTyp>::new(Arc::new([]));
    let pred_node = Arc::new(PredNode {
        typ: MemoTestPredTyp::Add,
        children: vec![Arc::new(PredNode {
            typ: MemoTestPredTyp::Minus,
            children: vec![],
            data: None,
        })],
        data: None,
    });

    let first_id = memo.add_new_pred(Arc::clone(&pred_node));
    assert!(Arc::ptr_eq(&memo.get_pred(first_id), &pred_node));

    // Duplicate predicates are stored separately under different ids.
    let second_id = memo.add_new_pred(Arc::clone(&pred_node));
    assert_ne!(first_id, second_id);
    assert!(Arc::ptr_eq(&memo.get_pred(second_id), &pred_node));
}

#[test]
fn group_merge_1() {
let mut memo = NaiveMemo::new(Arc::new([]));
Expand Down
15 changes: 14 additions & 1 deletion optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
cost::CostModel,
optimizer::Optimizer,
property::{PropertyBuilder, PropertyBuilderAny},
rel_node::{RelNodeMeta, RelNodeMetaMap, RelNodeRef, RelNodeTyp},
rel_node::{ArcPredNode, RelNodeMeta, RelNodeMetaMap, RelNodeRef, RelNodeTyp},
rules::RuleWrapper,
};

Expand Down Expand Up @@ -68,6 +68,9 @@ pub struct GroupId(pub(super) usize);
/// Identifier of an expression in the memo table.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)]
pub struct ExprId(pub usize);

/// Identifier of a predicate stored in the memo table.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)]
pub struct PredId(pub usize);

impl Display for GroupId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "!{}", self.0)
Expand All @@ -80,6 +83,12 @@ impl Display for ExprId {
}
}

impl Display for PredId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "P{}", self.0)
}
}

impl<T: RelNodeTyp> CascadesOptimizer<T, NaiveMemo<T>> {
pub fn new(
rules: Vec<Arc<RuleWrapper<T, Self>>>,
Expand Down Expand Up @@ -330,6 +339,10 @@ impl<T: RelNodeTyp, M: Memo<T>> CascadesOptimizer<T, M> {
self.memo.get_predicate_binding(group_id)
}

/// Look up the predicate node registered under `pred_id` by delegating to the memo table's
/// `get_pred`.
pub fn get_predicate(&self, pred_id: PredId) -> ArcPredNode<T> {
self.memo.get_pred(pred_id)
}

/// Returns whether `group_id` is contained in `explored_group`.
pub(super) fn is_group_explored(&self, group_id: GroupId) -> bool {
self.explored_group.contains(&group_id)
}
Expand Down
6 changes: 6 additions & 0 deletions optd-core/src/cascades/tasks/apply_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ fn match_node<T: RelNodeTyp, M: Memo<T>>(
.map(|x| RelNode::new_group(*x).into())
.collect_vec(),
data: node.data.clone(),
// rule engine by default captures all predicates
predicates: node
.predicates
.iter()
.map(|x| optimizer.get_predicate(*x))
.collect(),
},
);
assert!(res.is_none(), "dup pick");
Expand Down
3 changes: 3 additions & 0 deletions optd-core/src/heuristics/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn match_node<T: RelNodeTyp>(
typ: typ.clone(),
children: node.children.clone(),
data: node.data.clone(),
predicates: node.predicates.clone(),
},
);
assert!(res.is_none(), "dup pick");
Expand Down Expand Up @@ -153,6 +154,7 @@ impl<T: RelNodeTyp> HeuristicsOptimizer<T> {
typ: root_rel.typ.clone(),
children: optimized_children,
data: root_rel.data.clone(),
predicates: root_rel.predicates.clone(),
}
.into(),
)?;
Expand All @@ -165,6 +167,7 @@ impl<T: RelNodeTyp> HeuristicsOptimizer<T> {
typ: root_rel.typ.clone(),
children: optimized_children,
data: root_rel.data.clone(),
predicates: root_rel.predicates.clone(),
}
.into();
Ok(node)
Expand Down
Loading

0 comments on commit 7045f09

Please sign in to comment.