Skip to content

Commit

Permalink
Minor: address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rubenada committed Jul 12, 2024
1 parent 485f3c1 commit 5d281fc
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,8 @@ protected RexNode removeCorrelationExpr(
// Now add the corVars from the input, starting from
// position oldGroupKeyCount.
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
// Verify if the CorDef position was already added to the mapNewInputToProjOutputs
// during the previous group key processing
final Integer pos = mapNewInputToProjOutputs.get(entry.getValue());
if (pos == null) {
RexInputRef.add2(projects, entry.getValue(), newInputOutput);
Expand Down
109 changes: 56 additions & 53 deletions core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ public static Frameworks.ConfigBuilder config() {
assertThat(after, hasTree(planAfter));
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-2744">[CALCITE-2744]
* RelDecorrelator use wrong output map for LogicalAggregate decorrelate</a>. */
/**
* Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6468">[CALCITE-6468] RelDecorrelator
* throws AssertionError if correlated variable is used as Aggregate group key</a>.
*/
@Test void testCorrVarOnAggregateKey() {
final FrameworkConfig frameworkConfig = config().build();
final RelBuilder builder = RelBuilder.create(frameworkConfig);
Expand All @@ -112,62 +114,63 @@ public static Frameworks.ConfigBuilder config() {
+ " (SELECT deptno, sum(sal) AS total FROM emp GROUP BY deptno)\n"
+ " SELECT 1 FROM agg_sal s1"
+ " WHERE s1.total > (SELECT avg(total) FROM agg_sal s2 WHERE s1.deptno = s2.deptno)";
final RelNode originalRel;
try {
final SqlNode parse = planner.parse(sql);
final SqlNode validate = planner.validate(parse);
final RelNode originalRel = planner.rel(validate).rel;
final HepProgram hepProgram = HepProgram.builder()
.addRuleCollection(
ImmutableList.of(
// SubQuery program rules
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
// plus FilterAggregateTransposeRule
CoreRules.FILTER_AGGREGATE_TRANSPOSE))
.build();
final Program program =
Programs.of(hepProgram, true, Objects.requireNonNull(cluster.getMetadataProvider()));
final RelNode before =
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
final String planBefore = ""
+ "LogicalProject(EXPR$0=[1])\n"
+ " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+ " LogicalFilter(condition=[>($1, $2)])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{}], EXPR$0=[AVG($0)])\n"
+ " LogicalProject(TOTAL=[$1])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(before, hasTree(planBefore));

// Check decorrelation does not fail here
final RelNode after = RelDecorrelator.decorrelateQuery(before, builder);

// Verify plan
final String planAfter = ""
+ "LogicalProject(EXPR$0=[1])\n"
+ " LogicalJoin(condition=[AND(=($0, $2), >($1, $3))], joinType=[inner])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalProject(DEPTNO=[$0], SAL=[$1])\n"
+ " LogicalFilter(condition=[IS NOT NULL($0)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(after, hasTree(planAfter));
originalRel = planner.rel(validate).rel;
} catch (Exception e) {
throw TestUtil.rethrow(e);
}

final HepProgram hepProgram = HepProgram.builder()
.addRuleCollection(
ImmutableList.of(
// SubQuery program rules
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
// plus FilterAggregateTransposeRule
CoreRules.FILTER_AGGREGATE_TRANSPOSE))
.build();
final Program program =
Programs.of(hepProgram, true, Objects.requireNonNull(cluster.getMetadataProvider()));
final RelNode before =
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
final String planBefore = ""
+ "LogicalProject(EXPR$0=[1])\n"
+ " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+ " LogicalFilter(condition=[>($1, $2)])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{}], EXPR$0=[AVG($0)])\n"
+ " LogicalProject(TOTAL=[$1])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalFilter(condition=[=($cor0.DEPTNO, $0)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(before, hasTree(planBefore));

// Check decorrelation does not fail here
final RelNode after = RelDecorrelator.decorrelateQuery(before, builder);

// Verify plan
final String planAfter = ""
+ "LogicalProject(EXPR$0=[1])\n"
+ " LogicalJoin(condition=[AND(=($0, $2), >($1, $3))], joinType=[inner])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$0], TOTAL=[$1])\n"
+ " LogicalAggregate(group=[{0}], TOTAL=[SUM($1)])\n"
+ " LogicalProject(DEPTNO=[$0], SAL=[$1])\n"
+ " LogicalFilter(condition=[IS NOT NULL($0)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(after, hasTree(planAfter));
}
}

0 comments on commit 5d281fc

Please sign in to comment.