Skip to content

Commit

Permalink
[CALCITE-6236] EnumerableBatchNestedLoopJoin uses wrong row count for cost calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
rubenada committed Jan 31, 2024
1 parent e17098d commit 90ecca0
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
public class EnumerableBatchNestedLoopJoin extends Join implements EnumerableRel {

private final ImmutableBitSet requiredColumns;
private final Join originalJoin;
protected EnumerableBatchNestedLoopJoin(
RelOptCluster cluster,
RelTraitSet traits,
Expand All @@ -63,9 +64,11 @@ protected EnumerableBatchNestedLoopJoin(
RexNode condition,
Set<CorrelationId> variablesSet,
ImmutableBitSet requiredColumns,
JoinRelType joinType) {
JoinRelType joinType,
Join originalJoin) {
super(cluster, traits, ImmutableList.of(), left, right, condition, variablesSet, joinType);
this.requiredColumns = requiredColumns;
this.originalJoin = originalJoin;
}

public static EnumerableBatchNestedLoopJoin create(
Expand All @@ -74,7 +77,8 @@ public static EnumerableBatchNestedLoopJoin create(
RexNode condition,
ImmutableBitSet requiredColumns,
Set<CorrelationId> variablesSet,
JoinRelType joinType) {
JoinRelType joinType,
Join originalJoin) {
final RelOptCluster cluster = left.getCluster();
final RelMetadataQuery mq = cluster.getMetadataQuery();
final RelTraitSet traitSet =
Expand All @@ -89,7 +93,12 @@ public static EnumerableBatchNestedLoopJoin create(
condition,
variablesSet,
requiredColumns,
joinType);
joinType,
originalJoin);
}

/**
 * Returns the join that was supplied as {@code originalJoin} when this node was constructed.
 *
 * @return the stored original join
 */
public Join getOriginalJoin() {
  return this.originalJoin;
}

@Override public @Nullable Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(
Expand All @@ -116,7 +125,7 @@ public static EnumerableBatchNestedLoopJoin create(
RexNode condition, RelNode left, RelNode right, JoinRelType joinType,
boolean semiJoinDone) {
return new EnumerableBatchNestedLoopJoin(getCluster(), traitSet,
left, right, condition, variablesSet, requiredColumns, joinType);
left, right, condition, variablesSet, requiredColumns, joinType, originalJoin);
}

@Override public @Nullable RelOptCost computeSelfCost(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ public EnumerableBatchNestedLoopJoinRule(RelBuilderFactory relBuilderFactory,
join.getCondition(),
requiredColumns.build(),
correlationIds,
join.getJoinType()));
join.getJoinType(),
join));
}

/** Rule configuration. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.calcite.rel.metadata;

import org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin;
import org.apache.calcite.adapter.enumerable.EnumerableLimit;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
Expand Down Expand Up @@ -181,6 +182,11 @@ public Double getRowCount(Calc rel, RelMetadataQuery mq) {
return mq.getRowCount(rel.getInput());
}

/**
 * Row count of an {@link EnumerableBatchNestedLoopJoin}.
 *
 * <p>Delegates to the row count of the join that originated it, so that both
 * report the same value.
 */
public @Nullable Double getRowCount(EnumerableBatchNestedLoopJoin join, RelMetadataQuery mq) {
  final Join originatingJoin = join.getOriginalJoin();
  return mq.getRowCount(originatingJoin);
}

/**
 * Row count of a {@link Join}.
 *
 * <p>Returns the result of {@code RelMdUtil.getJoinRowCount} for the join and
 * its own condition.
 */
public @Nullable Double getRowCount(Join rel, RelMetadataQuery mq) {
return RelMdUtil.getJoinRowCount(mq, rel, rel.getCondition());
}
Expand Down
32 changes: 32 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelMetadataTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
* limitations under the License.
*/
package org.apache.calcite.test;

import org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.EnumerableLimit;
import org.apache.calcite.adapter.enumerable.EnumerableMergeJoin;
Expand Down Expand Up @@ -796,6 +798,36 @@ final RelMetadataFixture sql(String sql) {
sql(sql).assertThatRowCount(is(1D), is(1D), is(1D));
}

/** Checks that an EnumerableBatchNestedLoopJoin reports the same row count as the
 * LogicalJoin it was converted from. */
@Test void testRowCountEnumerableBatchNestedLoopJoin() {
  // Logical plan: EMP(DEPTNO) inner-joined with DEPT(DEPTNO) on equal DEPTNO.
  final RelBuilder relBuilder = RelBuilder.create(RelBuilderTest.config().build());
  final RelNode logicalPlan = relBuilder
      .scan("EMP")
      .project(relBuilder.field("DEPTNO"))
      .scan("DEPT")
      .project(relBuilder.field("DEPTNO"))
      .join(
          JoinRelType.INNER,
          relBuilder.equals(
              relBuilder.field(2, 0, 0),
              relBuilder.field(2, 1, 0)))
      .build();

  final RelMetadataQuery metadataQuery = logicalPlan.getCluster().getMetadataQuery();
  assertThat(logicalPlan, instanceOf(LogicalJoin.class));
  final Double expectedRowCount = metadataQuery.getRowCount(logicalPlan);

  // Apply only the rule that converts LogicalJoin into EnumerableBatchNestedLoopJoin
  final HepPlanner planner =
      new HepPlanner(
          new HepProgramBuilder()
              .addRuleInstance(EnumerableRules.ENUMERABLE_BATCH_NESTED_LOOP_JOIN_RULE)
              .build());
  planner.setRoot(logicalPlan);
  final RelNode convertedPlan = planner.findBestExp();
  assertThat(convertedPlan, instanceOf(EnumerableBatchNestedLoopJoin.class));

  // The converted join must keep the row count of the logical join
  assertThat(metadataQuery.getRowCount(convertedPlan), equalTo(expectedRowCount));
}

// ----------------------------------------------------------------------
// Tests for computeSelfCost.cpu
// ----------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ public java.lang.Double getRowCount(
private java.lang.Double getRowCount_(
org.apache.calcite.rel.RelNode r,
org.apache.calcite.rel.metadata.RelMetadataQuery mq) {
if (r instanceof org.apache.calcite.adapter.enumerable.EnumerableLimit) {
if (r instanceof org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin) {
return provider0.getRowCount((org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin) r, mq);
} else if (r instanceof org.apache.calcite.adapter.enumerable.EnumerableLimit) {
return provider0.getRowCount((org.apache.calcite.adapter.enumerable.EnumerableLimit) r, mq);
} else if (r instanceof org.apache.calcite.plan.volcano.RelSubset) {
return provider0.getRowCount((org.apache.calcite.plan.volcano.RelSubset) r, mq);
Expand Down

0 comments on commit 90ecca0

Please sign in to comment.