From 90ecca062a086a292a43c2baaa2ce607d9a9e11e Mon Sep 17 00:00:00 2001 From: rubenada Date: Wed, 31 Jan 2024 16:40:41 +0000 Subject: [PATCH] [CALCITE-6236] EnumerableBatchNestedLoopJoin uses wrong row count for cost calculation --- .../EnumerableBatchNestedLoopJoin.java | 17 +++++++--- .../EnumerableBatchNestedLoopJoinRule.java | 3 +- .../calcite/rel/metadata/RelMdRowCount.java | 6 ++++ .../apache/calcite/test/RelMetadataTest.java | 32 +++++++++++++++++++ .../GeneratedMetadata_RowCountHandler.java | 4 ++- 5 files changed, 56 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java index 94ef9c32133..cff5babb7da 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java @@ -55,6 +55,7 @@ public class EnumerableBatchNestedLoopJoin extends Join implements EnumerableRel { private final ImmutableBitSet requiredColumns; + private final Join originalJoin; protected EnumerableBatchNestedLoopJoin( RelOptCluster cluster, RelTraitSet traits, @@ -63,9 +64,11 @@ protected EnumerableBatchNestedLoopJoin( RexNode condition, Set variablesSet, ImmutableBitSet requiredColumns, - JoinRelType joinType) { + JoinRelType joinType, + Join originalJoin) { super(cluster, traits, ImmutableList.of(), left, right, condition, variablesSet, joinType); this.requiredColumns = requiredColumns; + this.originalJoin = originalJoin; } public static EnumerableBatchNestedLoopJoin create( @@ -74,7 +77,8 @@ public static EnumerableBatchNestedLoopJoin create( RexNode condition, ImmutableBitSet requiredColumns, Set variablesSet, - JoinRelType joinType) { + JoinRelType joinType, + Join originalJoin) { final RelOptCluster cluster = left.getCluster(); final RelMetadataQuery mq = cluster.getMetadataQuery(); final RelTraitSet traitSet = @@ -89,7 +93,12 @@ public static EnumerableBatchNestedLoopJoin create( condition, variablesSet, requiredColumns, - joinType); + joinType, + originalJoin); + } + + public Join getOriginalJoin() { + return originalJoin; } @Override public @Nullable Pair> passThroughTraits( @@ -116,7 +125,7 @@ public static EnumerableBatchNestedLoopJoin create( RexNode condition, RelNode left, RelNode right, JoinRelType joinType, boolean semiJoinDone) { return new EnumerableBatchNestedLoopJoin(getCluster(), traitSet, - left, right, condition, variablesSet, requiredColumns, joinType); + left, right, condition, variablesSet, requiredColumns, joinType, originalJoin); } @Override public @Nullable RelOptCost computeSelfCost( diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java index d7accbd7058..5c326dd7dff 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java @@ -147,7 +147,8 @@ public EnumerableBatchNestedLoopJoinRule(RelBuilderFactory relBuilderFactory, join.getCondition(), requiredColumns.build(), correlationIds, - join.getJoinType())); + join.getJoinType(), + join)); } /** Rule configuration. */ diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdRowCount.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdRowCount.java index b7393823ba5..9c5995bd203 100644 --- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdRowCount.java +++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdRowCount.java @@ -16,6 +16,7 @@ */ package org.apache.calcite.rel.metadata; +import org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin; import org.apache.calcite.adapter.enumerable.EnumerableLimit; import org.apache.calcite.plan.volcano.RelSubset; import org.apache.calcite.rel.RelNode; @@ -181,6 +182,11 @@ public Double getRowCount(Calc rel, RelMetadataQuery mq) { return mq.getRowCount(rel.getInput()); } + // Ensures that EnumerableBatchNestedLoopJoin has the same rowCount as the join that originated it + public @Nullable Double getRowCount(EnumerableBatchNestedLoopJoin join, RelMetadataQuery mq) { + return mq.getRowCount(join.getOriginalJoin()); + } + public @Nullable Double getRowCount(Join rel, RelMetadataQuery mq) { return RelMdUtil.getJoinRowCount(mq, rel, rel.getCondition()); } diff --git a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java index 24155414e69..3715889c701 100644 --- a/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelMetadataTest.java @@ -15,6 +15,8 @@ * limitations under the License. */ package org.apache.calcite.test; + +import org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin; import org.apache.calcite.adapter.enumerable.EnumerableConvention; import org.apache.calcite.adapter.enumerable.EnumerableLimit; import org.apache.calcite.adapter.enumerable.EnumerableMergeJoin; @@ -796,6 +798,36 @@ final RelMetadataFixture sql(String sql) { sql(sql).assertThatRowCount(is(1D), is(1D), is(1D)); } + @Test void testRowCountEnumerableBatchNestedLoopJoin() { + final RelBuilder builder = RelBuilder.create(RelBuilderTest.config().build()); + final RelNode relNode1 = builder + .scan("EMP") + .project(builder.field("DEPTNO")) + .scan("DEPT") + .project(builder.field("DEPTNO")) + .join( + JoinRelType.INNER, + builder.equals( + builder.field(2, 0, 0), + builder.field(2, 1, 0))) + .build(); + + final RelMetadataQuery mq = relNode1.getCluster().getMetadataQuery(); + assertThat(relNode1, instanceOf(LogicalJoin.class)); + final Double rowCount1 = mq.getRowCount(relNode1); + + // Program to convert LogicalJoin into EnumerableBatchNestedLoopJoin + final HepProgram program = new HepProgramBuilder() + .addRuleInstance(EnumerableRules.ENUMERABLE_BATCH_NESTED_LOOP_JOIN_RULE) + .build(); + final HepPlanner hepPlanner = new HepPlanner(program); + hepPlanner.setRoot(relNode1); + final RelNode relNode2 = hepPlanner.findBestExp(); + assertThat(relNode2, instanceOf(EnumerableBatchNestedLoopJoin.class)); + final Double rowCount2 = mq.getRowCount(relNode2); + assertThat(rowCount2, equalTo(rowCount1)); + } + // ---------------------------------------------------------------------- // Tests for computeSelfCost.cpu // ---------------------------------------------------------------------- diff --git a/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_RowCountHandler.java b/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_RowCountHandler.java index cef8019f14c..cec12251e9b 100644 --- a/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_RowCountHandler.java +++ b/core/src/test/resources/org/apache/calcite/rel/metadata/janino/GeneratedMetadata_RowCountHandler.java @@ -60,7 +60,9 @@ public java.lang.Double getRowCount( private java.lang.Double getRowCount_( org.apache.calcite.rel.RelNode r, org.apache.calcite.rel.metadata.RelMetadataQuery mq) { - if (r instanceof org.apache.calcite.adapter.enumerable.EnumerableLimit) { + if (r instanceof org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin) { + return provider0.getRowCount((org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin) r, mq); + } else if (r instanceof org.apache.calcite.adapter.enumerable.EnumerableLimit) { return provider0.getRowCount((org.apache.calcite.adapter.enumerable.EnumerableLimit) r, mq); } else if (r instanceof org.apache.calcite.plan.volcano.RelSubset) { return provider0.getRowCount((org.apache.calcite.plan.volcano.RelSubset) r, mq);