Skip to content

Commit

Permalink
refactor(core,repr): expose group type in core
Browse files Browse the repository at this point in the history
Co-Authored-By: Benjamin O <jeep70cp@gmail.com>
Signed-off-by: Alex Chi <iskyzh@gmail.com>
  • Loading branch information
skyzh and jurplel committed Nov 5, 2024
1 parent f81649c commit 1cd9e10
Show file tree
Hide file tree
Showing 40 changed files with 754 additions and 693 deletions.
118 changes: 51 additions & 67 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tracing::trace;
use crate::{
cost::{Cost, Statistics},
property::PropertyBuilderAny,
rel_node::{ArcPredNode, RelNode, RelNodeRef, RelNodeTyp, Value},
rel_node::{ArcPredNode, MaybeRelNode, RelNode, RelNodeRef, RelNodeTyp, Value},
};

use super::optimizer::{ExprId, GroupId, PredId};
Expand All @@ -34,7 +34,7 @@ impl<T: RelNodeTyp> RelMemoNode<T> {
children: self
.children
.into_iter()
.map(|x| Arc::new(RelNode::new_group(x)))
.map(|x| MaybeRelNode::Group(x))
.collect(),
data: self.data,
predicates: Vec::new(), /* TODO: refactor */
Expand Down Expand Up @@ -117,7 +117,8 @@ pub trait Memo<T: RelNodeTyp>: 'static + Send + Sync {

/// Add a new expression to an existing gruop. If the expression is a group, it will merge the two groups. Otherwise,
/// it will add the expression to the group. Returns the expr id if the expression is not a group.
fn add_expr_to_group(&mut self, rel_node: RelNodeRef<T>, group_id: GroupId) -> Option<ExprId>;
fn add_expr_to_group(&mut self, rel_node: MaybeRelNode<T>, group_id: GroupId)
-> Option<ExprId>;

/// Add a new predicate into the memo table.
fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId;
Expand Down Expand Up @@ -194,10 +195,10 @@ fn get_best_group_binding_inner<M: Memo<T> + ?Sized, T: RelNodeTyp>(
let expr = this.get_expr_memoed(*expr_id);
let mut children = Vec::with_capacity(expr.children.len());
for child in &expr.children {
children.push(
children.push(MaybeRelNode::RelNode(
get_best_group_binding_inner(this, *child, post_process)
.with_context(|| format!("when processing expr {}", expr_id))?,
);
));
}
let node = Arc::new(RelNode {
typ: expr.typ.clone(),
Expand All @@ -221,7 +222,7 @@ fn get_predicate_binding_expr_inner<M: Memo<T> + ?Sized, T: RelNodeTyp>(
for child in expr.children.iter() {
if let Some(child) = get_predicate_binding_group_inner(this, *child, panic_on_invalid_group)
{
children.push(child);
children.push(MaybeRelNode::RelNode(child));
} else {
return None;
}
Expand Down Expand Up @@ -291,20 +292,28 @@ impl<T: RelNodeTyp> Memo<T> for NaiveMemo<T> {
(group_id, expr_id)
}

fn add_expr_to_group(&mut self, rel_node: RelNodeRef<T>, group_id: GroupId) -> Option<ExprId> {
if let Some(input_group) = rel_node.typ.extract_group() {
let input_group = self.reduce_group(input_group);
let group_id = self.reduce_group(group_id);
self.merge_group_inner(input_group, group_id);
return None;
fn add_expr_to_group(
&mut self,
rel_node: MaybeRelNode<T>,
group_id: GroupId,
) -> Option<ExprId> {
match rel_node {
MaybeRelNode::Group(input_group) => {
let input_group = self.reduce_group(input_group);
let group_id = self.reduce_group(group_id);
self.merge_group_inner(input_group, group_id);
return None;
}
MaybeRelNode::RelNode(rel_node) => {
let reduced_group_id = self.reduce_group(group_id);
let (returned_group_id, expr_id) = self
.add_new_group_expr_inner(rel_node, Some(reduced_group_id))
.unwrap();
assert_eq!(returned_group_id, reduced_group_id);
self.verify_integrity();
Some(expr_id)
}
}
let reduced_group_id = self.reduce_group(group_id);
let (returned_group_id, expr_id) = self
.add_new_group_expr_inner(rel_node, Some(reduced_group_id))
.unwrap();
assert_eq!(returned_group_id, reduced_group_id);
self.verify_integrity();
Some(expr_id)
}

fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId {
Expand Down Expand Up @@ -525,19 +534,19 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
rel_node: RelNodeRef<T>,
add_to_group_id: Option<GroupId>,
) -> anyhow::Result<(GroupId, ExprId)> {
assert!(rel_node.typ.extract_group().is_none());
let children_group_ids = rel_node
.children
.iter()
.map(|child| {
if let Some(group) = child.typ.extract_group() {
self.reduce_group(group) // TODO: can I remove?
} else {
// No merge / modification to the memo should occur for the following operation
let (group, _) = self
.add_new_group_expr_inner(child.clone(), None)
.expect("should not trigger merge group");
self.reduce_group(group) // TODO: can I remove?
match child {
MaybeRelNode::Group(group) => self.reduce_group(*group), // TODO: can I remove reduce?
MaybeRelNode::RelNode(child) => {
// No merge / modification to the memo should occur for the following operation
let (group, _) = self
.add_new_group_expr_inner(child.clone(), None)
.expect("should not trigger merge group");
self.reduce_group(group) // TODO: can I remove?
}
}
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -581,12 +590,9 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
let children_group_ids = rel_node
.children
.iter()
.map(|child| {
if let Some(group) = child.typ.extract_group() {
group
} else {
self.get_expr_info(child.clone()).0
}
.map(|child| match child {
MaybeRelNode::Group(group) => *group,
MaybeRelNode::RelNode(child) => self.get_expr_info(child.clone()).0,
})
.collect::<Vec<_>>();
let memo_node = RelMemoNode {
Expand Down Expand Up @@ -667,7 +673,6 @@ mod tests {

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum MemoTestRelTyp {
Group(GroupId),
List,
Join,
Project,
Expand All @@ -683,10 +688,7 @@ mod tests {

impl std::fmt::Display for MemoTestRelTyp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Group(x) => write!(f, "{}", x),
other => write!(f, "{:?}", other),
}
write!(f, "{:?}", self)
}
}

Expand All @@ -703,30 +705,17 @@ mod tests {
matches!(self, Self::Project | Self::Scan | Self::Join)
}

fn group_typ(group_id: GroupId) -> Self {
Self::Group(group_id)
}

fn list_typ() -> Self {
Self::List
}

fn extract_group(&self) -> Option<GroupId> {
if let Self::Group(group_id) = self {
Some(*group_id)
} else {
None
}
}
}

type MemoTestRelNode = RelNode<MemoTestRelTyp>;
type MemoTestRelNodeRef = RelNodeRef<MemoTestRelTyp>;

fn join(
left: impl Into<MemoTestRelNodeRef>,
right: impl Into<MemoTestRelNodeRef>,
cond: impl Into<MemoTestRelNodeRef>,
left: impl Into<MaybeRelNode<MemoTestRelTyp>>,
right: impl Into<MaybeRelNode<MemoTestRelTyp>>,
cond: impl Into<MaybeRelNode<MemoTestRelTyp>>,
) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::Join,
Expand All @@ -746,8 +735,8 @@ mod tests {
}

fn project(
input: impl Into<MemoTestRelNodeRef>,
expr_list: impl Into<MemoTestRelNodeRef>,
input: impl Into<MaybeRelNode<MemoTestRelTyp>>,
expr_list: impl Into<MaybeRelNode<MemoTestRelTyp>>,
) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::Project,
Expand All @@ -757,7 +746,7 @@ mod tests {
}
}

fn list(items: Vec<impl Into<MemoTestRelNodeRef>>) -> MemoTestRelNode {
fn list(items: Vec<impl Into<MaybeRelNode<MemoTestRelTyp>>>) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::List,
children: items.into_iter().map(|x| x.into()).collect(),
Expand All @@ -775,13 +764,8 @@ mod tests {
}
}

fn group(group_id: GroupId) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::Group(group_id),
children: vec![],
data: None,
predicates: Vec::new(), /* TODO: refactor */
}
fn group(group_id: GroupId) -> MaybeRelNode<MemoTestRelTyp> {
MaybeRelNode::Group(group_id)
}

#[test]
Expand Down Expand Up @@ -865,8 +849,8 @@ mod tests {
let (group_1, _) = memo.get_expr_info(expr1.clone());
let (group_2, _) = memo.get_expr_info(expr2.clone());
assert_eq!(group_1, group_2);
let (group_1, _) = memo.get_expr_info(expr1.child(0));
let (group_2, _) = memo.get_expr_info(expr2.child(0));
let (group_1, _) = memo.get_expr_info(expr1.child_rel(0));
let (group_2, _) = memo.get_expr_info(expr2.child_rel(0));
assert_eq!(group_1, group_2);
}

Expand Down
18 changes: 10 additions & 8 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
cost::CostModel,
optimizer::Optimizer,
property::{PropertyBuilder, PropertyBuilderAny},
rel_node::{ArcPredNode, RelNodeMeta, RelNodeMetaMap, RelNodeRef, RelNodeTyp},
rel_node::{ArcPredNode, MaybeRelNode, RelNodeMeta, RelNodeMetaMap, RelNodeRef, RelNodeTyp},
rules::RuleWrapper,
};

Expand Down Expand Up @@ -281,11 +281,9 @@ impl<T: RelNodeTyp, M: Memo<T>> CascadesOptimizer<T, M> {
self.memo.get_best_group_binding(group_id, |_, _, _| {})
}

pub fn resolve_group_id(&self, root_rel: RelNodeRef<T>) -> GroupId {
if let Some(group_id) = T::extract_group(&root_rel.typ) {
return group_id;
}
panic!("This function is deprecated -- you should only pass group id instead of a full expression to this function.")
pub fn resolve_group_id(&self, root_rel: MaybeRelNode<T>) -> GroupId {
root_rel.unwrap_group()
// panic!("This function is deprecated -- you should only pass group id instead of a full expression to this function.")
}

pub(super) fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
Expand All @@ -298,7 +296,7 @@ impl<T: RelNodeTyp, M: Memo<T>> CascadesOptimizer<T, M> {

pub fn add_expr_to_group(
&mut self,
rel_node: RelNodeRef<T>,
rel_node: MaybeRelNode<T>,
group_id: GroupId,
) -> Option<ExprId> {
self.memo.add_expr_to_group(rel_node, group_id)
Expand Down Expand Up @@ -387,7 +385,11 @@ impl<T: RelNodeTyp, M: Memo<T>> Optimizer<T> for CascadesOptimizer<T, M> {
self.optimize_inner(root_rel)
}

fn get_property<P: PropertyBuilder<T>>(&self, root_rel: RelNodeRef<T>, idx: usize) -> P::Prop {
fn get_property<P: PropertyBuilder<T>>(
&self,
root_rel: MaybeRelNode<T>,
idx: usize,
) -> P::Prop {
self.get_property_by_group::<P>(self.resolve_group_id(root_rel), idx)
}
}
Loading

0 comments on commit 1cd9e10

Please sign in to comment.