Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(core,repr): expose group type in core #221

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 51 additions & 67 deletions optd-core/src/cascades/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tracing::trace;
use crate::{
cost::{Cost, Statistics},
property::PropertyBuilderAny,
rel_node::{ArcPredNode, RelNode, RelNodeRef, RelNodeTyp, Value},
rel_node::{ArcPredNode, MaybeRelNode, RelNode, RelNodeRef, RelNodeTyp, Value},
};

use super::optimizer::{ExprId, GroupId, PredId};
Expand All @@ -34,7 +34,7 @@ impl<T: RelNodeTyp> RelMemoNode<T> {
children: self
.children
.into_iter()
.map(|x| Arc::new(RelNode::new_group(x)))
.map(|x| MaybeRelNode::Group(x))
.collect(),
data: self.data,
predicates: Vec::new(), /* TODO: refactor */
Expand Down Expand Up @@ -117,7 +117,8 @@ pub trait Memo<T: RelNodeTyp>: 'static + Send + Sync {

/// Add a new expression to an existing gruop. If the expression is a group, it will merge the two groups. Otherwise,
/// it will add the expression to the group. Returns the expr id if the expression is not a group.
fn add_expr_to_group(&mut self, rel_node: RelNodeRef<T>, group_id: GroupId) -> Option<ExprId>;
fn add_expr_to_group(&mut self, rel_node: MaybeRelNode<T>, group_id: GroupId)
-> Option<ExprId>;

/// Add a new predicate into the memo table.
fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId;
Expand Down Expand Up @@ -194,10 +195,10 @@ fn get_best_group_binding_inner<M: Memo<T> + ?Sized, T: RelNodeTyp>(
let expr = this.get_expr_memoed(*expr_id);
let mut children = Vec::with_capacity(expr.children.len());
for child in &expr.children {
children.push(
children.push(MaybeRelNode::RelNode(
get_best_group_binding_inner(this, *child, post_process)
.with_context(|| format!("when processing expr {}", expr_id))?,
);
));
}
let node = Arc::new(RelNode {
typ: expr.typ.clone(),
Expand All @@ -221,7 +222,7 @@ fn get_predicate_binding_expr_inner<M: Memo<T> + ?Sized, T: RelNodeTyp>(
for child in expr.children.iter() {
if let Some(child) = get_predicate_binding_group_inner(this, *child, panic_on_invalid_group)
{
children.push(child);
children.push(MaybeRelNode::RelNode(child));
} else {
return None;
}
Expand Down Expand Up @@ -291,20 +292,28 @@ impl<T: RelNodeTyp> Memo<T> for NaiveMemo<T> {
(group_id, expr_id)
}

fn add_expr_to_group(&mut self, rel_node: RelNodeRef<T>, group_id: GroupId) -> Option<ExprId> {
if let Some(input_group) = rel_node.typ.extract_group() {
let input_group = self.reduce_group(input_group);
let group_id = self.reduce_group(group_id);
self.merge_group_inner(input_group, group_id);
return None;
fn add_expr_to_group(
&mut self,
rel_node: MaybeRelNode<T>,
group_id: GroupId,
) -> Option<ExprId> {
match rel_node {
MaybeRelNode::Group(input_group) => {
let input_group = self.reduce_group(input_group);
let group_id = self.reduce_group(group_id);
self.merge_group_inner(input_group, group_id);
return None;
}
MaybeRelNode::RelNode(rel_node) => {
let reduced_group_id = self.reduce_group(group_id);
let (returned_group_id, expr_id) = self
.add_new_group_expr_inner(rel_node, Some(reduced_group_id))
.unwrap();
assert_eq!(returned_group_id, reduced_group_id);
self.verify_integrity();
Some(expr_id)
}
}
let reduced_group_id = self.reduce_group(group_id);
let (returned_group_id, expr_id) = self
.add_new_group_expr_inner(rel_node, Some(reduced_group_id))
.unwrap();
assert_eq!(returned_group_id, reduced_group_id);
self.verify_integrity();
Some(expr_id)
}

fn add_new_pred(&mut self, pred_node: ArcPredNode<T>) -> PredId {
Expand Down Expand Up @@ -525,19 +534,19 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
rel_node: RelNodeRef<T>,
add_to_group_id: Option<GroupId>,
) -> anyhow::Result<(GroupId, ExprId)> {
assert!(rel_node.typ.extract_group().is_none());
let children_group_ids = rel_node
.children
.iter()
.map(|child| {
if let Some(group) = child.typ.extract_group() {
self.reduce_group(group) // TODO: can I remove?
} else {
// No merge / modification to the memo should occur for the following operation
let (group, _) = self
.add_new_group_expr_inner(child.clone(), None)
.expect("should not trigger merge group");
self.reduce_group(group) // TODO: can I remove?
match child {
MaybeRelNode::Group(group) => self.reduce_group(*group), // TODO: can I remove reduce?
MaybeRelNode::RelNode(child) => {
// No merge / modification to the memo should occur for the following operation
let (group, _) = self
.add_new_group_expr_inner(child.clone(), None)
.expect("should not trigger merge group");
self.reduce_group(group) // TODO: can I remove?
}
}
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -581,12 +590,9 @@ impl<T: RelNodeTyp> NaiveMemo<T> {
let children_group_ids = rel_node
.children
.iter()
.map(|child| {
if let Some(group) = child.typ.extract_group() {
group
} else {
self.get_expr_info(child.clone()).0
}
.map(|child| match child {
MaybeRelNode::Group(group) => *group,
MaybeRelNode::RelNode(child) => self.get_expr_info(child.clone()).0,
})
.collect::<Vec<_>>();
let memo_node = RelMemoNode {
Expand Down Expand Up @@ -667,7 +673,6 @@ mod tests {

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum MemoTestRelTyp {
Group(GroupId),
List,
Join,
Project,
Expand All @@ -683,10 +688,7 @@ mod tests {

impl std::fmt::Display for MemoTestRelTyp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Group(x) => write!(f, "{}", x),
other => write!(f, "{:?}", other),
}
write!(f, "{:?}", self)
}
}

Expand All @@ -703,30 +705,17 @@ mod tests {
matches!(self, Self::Project | Self::Scan | Self::Join)
}

fn group_typ(group_id: GroupId) -> Self {
Self::Group(group_id)
}

fn list_typ() -> Self {
Self::List
}

fn extract_group(&self) -> Option<GroupId> {
if let Self::Group(group_id) = self {
Some(*group_id)
} else {
None
}
}
}

type MemoTestRelNode = RelNode<MemoTestRelTyp>;
type MemoTestRelNodeRef = RelNodeRef<MemoTestRelTyp>;

fn join(
left: impl Into<MemoTestRelNodeRef>,
right: impl Into<MemoTestRelNodeRef>,
cond: impl Into<MemoTestRelNodeRef>,
left: impl Into<MaybeRelNode<MemoTestRelTyp>>,
right: impl Into<MaybeRelNode<MemoTestRelTyp>>,
cond: impl Into<MaybeRelNode<MemoTestRelTyp>>,
) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::Join,
Expand All @@ -746,8 +735,8 @@ mod tests {
}

fn project(
input: impl Into<MemoTestRelNodeRef>,
expr_list: impl Into<MemoTestRelNodeRef>,
input: impl Into<MaybeRelNode<MemoTestRelTyp>>,
expr_list: impl Into<MaybeRelNode<MemoTestRelTyp>>,
) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::Project,
Expand All @@ -757,7 +746,7 @@ mod tests {
}
}

fn list(items: Vec<impl Into<MemoTestRelNodeRef>>) -> MemoTestRelNode {
fn list(items: Vec<impl Into<MaybeRelNode<MemoTestRelTyp>>>) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::List,
children: items.into_iter().map(|x| x.into()).collect(),
Expand All @@ -775,13 +764,8 @@ mod tests {
}
}

fn group(group_id: GroupId) -> MemoTestRelNode {
RelNode {
typ: MemoTestRelTyp::Group(group_id),
children: vec![],
data: None,
predicates: Vec::new(), /* TODO: refactor */
}
fn group(group_id: GroupId) -> MaybeRelNode<MemoTestRelTyp> {
MaybeRelNode::Group(group_id)
}

#[test]
Expand Down Expand Up @@ -865,8 +849,8 @@ mod tests {
let (group_1, _) = memo.get_expr_info(expr1.clone());
let (group_2, _) = memo.get_expr_info(expr2.clone());
assert_eq!(group_1, group_2);
let (group_1, _) = memo.get_expr_info(expr1.child(0));
let (group_2, _) = memo.get_expr_info(expr2.child(0));
let (group_1, _) = memo.get_expr_info(expr1.child_rel(0));
let (group_2, _) = memo.get_expr_info(expr2.child_rel(0));
assert_eq!(group_1, group_2);
}

Expand Down
18 changes: 10 additions & 8 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
cost::CostModel,
optimizer::Optimizer,
property::{PropertyBuilder, PropertyBuilderAny},
rel_node::{ArcPredNode, RelNodeMeta, RelNodeMetaMap, RelNodeRef, RelNodeTyp},
rel_node::{ArcPredNode, MaybeRelNode, RelNodeMeta, RelNodeMetaMap, RelNodeRef, RelNodeTyp},
rules::RuleWrapper,
};

Expand Down Expand Up @@ -281,11 +281,9 @@ impl<T: RelNodeTyp, M: Memo<T>> CascadesOptimizer<T, M> {
self.memo.get_best_group_binding(group_id, |_, _, _| {})
}

pub fn resolve_group_id(&self, root_rel: RelNodeRef<T>) -> GroupId {
if let Some(group_id) = T::extract_group(&root_rel.typ) {
return group_id;
}
panic!("This function is deprecated -- you should only pass group id instead of a full expression to this function.")
pub fn resolve_group_id(&self, root_rel: MaybeRelNode<T>) -> GroupId {
root_rel.unwrap_group()
// panic!("This function is deprecated -- you should only pass group id instead of a full expression to this function.")
}

pub(super) fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
Expand All @@ -298,7 +296,7 @@ impl<T: RelNodeTyp, M: Memo<T>> CascadesOptimizer<T, M> {

pub fn add_expr_to_group(
&mut self,
rel_node: RelNodeRef<T>,
rel_node: MaybeRelNode<T>,
group_id: GroupId,
) -> Option<ExprId> {
self.memo.add_expr_to_group(rel_node, group_id)
Expand Down Expand Up @@ -387,7 +385,11 @@ impl<T: RelNodeTyp, M: Memo<T>> Optimizer<T> for CascadesOptimizer<T, M> {
self.optimize_inner(root_rel)
}

fn get_property<P: PropertyBuilder<T>>(&self, root_rel: RelNodeRef<T>, idx: usize) -> P::Prop {
fn get_property<P: PropertyBuilder<T>>(
&self,
root_rel: MaybeRelNode<T>,
idx: usize,
) -> P::Prop {
self.get_property_by_group::<P>(self.resolve_group_id(root_rel), idx)
}
}
Loading
Loading