From 5d281fcd3473f310ab78312e60837a558558240b Mon Sep 17 00:00:00 2001 From: Ruben Quesada Lopez Date: Fri, 12 Jul 2024 20:57:22 +0100 Subject: [PATCH] Minor: address review comments --- .../calcite/sql2rel/RelDecorrelator.java | 2 + .../calcite/sql2rel/RelDecorrelatorTest.java | 109 +++++++++--------- 2 files changed, 58 insertions(+), 53 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java index 33a914ac04f4..79fde87ee2f8 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java @@ -560,6 +560,8 @@ protected RexNode removeCorrelationExpr( // Now add the corVars from the input, starting from // position oldGroupKeyCount. for (Map.Entry entry : frame.corDefOutputs.entrySet()) { + // Verify if the CorDef position was already added to the mapNewInputToProjOutputs + // during the previous group key processing final Integer pos = mapNewInputToProjOutputs.get(entry.getValue()); if (pos == null) { RexInputRef.add2(projects, entry.getValue(), newInputOutput); diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java index 5e365bcc5aa1..dbbf69d87205 100644 --- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java +++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java @@ -100,9 +100,11 @@ public static Frameworks.ConfigBuilder config() { assertThat(after, hasTree(planAfter)); } - /** Test case for - * [CALCITE-2744] - * RelDecorrelator use wrong output map for LogicalAggregate decorrelate. */ + /** + * Test case for + * [CALCITE-6468] RelDecorrelator + * throws AssertionError if correlated variable is used as Aggregate group key. + */ @Test void testCorrVarOnAggregateKey() { final FrameworkConfig frameworkConfig = config().build(); final RelBuilder builder = RelBuilder.create(frameworkConfig); @@ -112,62 +114,63 @@ public static Frameworks.ConfigBuilder config() { + " (SELECT deptno, sum(sal) AS total FROM emp GROUP BY deptno)\n" + " SELECT 1 FROM agg_sal s1" + " WHERE s1.total > (SELECT avg(total) FROM agg_sal s2 WHERE s1.deptno = s2.deptno)"; + final RelNode originalRel; try { final SqlNode parse = planner.parse(sql); final SqlNode validate = planner.validate(parse); - final RelNode originalRel = planner.rel(validate).rel; - final HepProgram hepProgram = HepProgram.builder() - .addRuleCollection( - ImmutableList.of( - // SubQuery program rules - CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, - CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, - CoreRules.JOIN_SUB_QUERY_TO_CORRELATE, - // plus FilterAggregateTransposeRule - CoreRules.FILTER_AGGREGATE_TRANSPOSE)) - .build(); - final Program program = - Programs.of(hepProgram, true, Objects.requireNonNull(cluster.getMetadataProvider())); - final RelNode before = - program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), - Collections.emptyList(), Collections.emptyList()); - final String planBefore = "" - + "LogicalProject(EXPR$0=[1])\n" - + " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n" - + " LogicalFilter(condition=[>($1, $2)])\n" - + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" - + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalAggregate(group=[{}], EXPR$0=[AVG($0)])\n" - + " LogicalProject(TOTAL=[$1])\n" - + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" - + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; - assertThat(before, hasTree(planBefore)); - - // Check decorrelation does not fail here - final RelNode after = RelDecorrelator.decorrelateQuery(before, builder); - - // Verify plan - final String planAfter = "" - + "LogicalProject(EXPR$0=[1])\n" - + " LogicalJoin(condition=[AND(=($0, $2), >($1, $3))], joinType=[inner])\n" - + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n" - + " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n" - + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" - + " LogicalProject(DEPTNO=[$0], SAL=[$1])\n" - + " LogicalFilter(condition=[IS NOT NULL($0)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; - assertThat(after, hasTree(planAfter)); + originalRel = planner.rel(validate).rel; } catch (Exception e) { throw TestUtil.rethrow(e); } + final HepProgram hepProgram = HepProgram.builder() + .addRuleCollection( + ImmutableList.of( + // SubQuery program rules + CoreRules.FILTER_SUB_QUERY_TO_CORRELATE, + CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE, + CoreRules.JOIN_SUB_QUERY_TO_CORRELATE, + // plus FilterAggregateTransposeRule + CoreRules.FILTER_AGGREGATE_TRANSPOSE)) + .build(); + final Program program = + Programs.of(hepProgram, true, Objects.requireNonNull(cluster.getMetadataProvider())); + final RelNode before = + program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), + Collections.emptyList(), Collections.emptyList()); + final String planBefore = "" + + "LogicalProject(EXPR$0=[1])\n" + + " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n" + + " LogicalFilter(condition=[>($1, $2)])\n" + + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" + + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalAggregate(group=[{}], EXPR$0=[AVG($0)])\n" + + " LogicalProject(TOTAL=[$1])\n" + + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" + + " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(before, hasTree(planBefore)); + + // Check decorrelation does not fail here + final RelNode after = RelDecorrelator.decorrelateQuery(before, builder); + + // Verify plan + final String planAfter = "" + + "LogicalProject(EXPR$0=[1])\n" + + " LogicalJoin(condition=[AND(=($0, $2), >($1, $3))], joinType=[inner])\n" + + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n" + + " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n" + + " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n" + + " LogicalProject(DEPTNO=[$0], SAL=[$1])\n" + + " LogicalFilter(condition=[IS NOT NULL($0)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + assertThat(after, hasTree(planAfter)); } }