diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java index fc210ed26cc..2eb8dfff959 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java @@ -122,7 +122,8 @@ public QueryEnvironment(TypeFactory typeFactory, CalciteSchema rootSchema, Worke // SUB-QUERY Threshold is useless as we are encoding all IN clause in-line anyway .withInSubQueryThreshold(Integer.MAX_VALUE) .addRelBuilderConfigTransform(c -> c.withPushJoinCondition(true)) - .addRelBuilderConfigTransform(c -> c.withAggregateUnique(true))) + .addRelBuilderConfigTransform(c -> c.withAggregateUnique(true)) + .addRelBuilderConfigTransform(c -> c.withPruneInputOfAggregate(false))) .build(); _optProgram = getOptProgram(); _traitProgram = getTraitProgram(); diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json index 9ab9edebdab..7ce235acf0b 100644 --- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json +++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json @@ -424,7 +424,81 @@ ] }, { - "description": "nexted reused tmp table SEMI JOINs", + "description": "multiple SEMI JOINs on same column multiple conditions and multiple columns", + "sql": "EXPLAIN PLAN FOR WITH tmp1 AS ( SELECT * FROM a WHERE col2 NOT IN ('foo', 'bar') ) SELECT * FROM a WHERE col2 IN (SELECT col1 FROM tmp1) AND col2 IN (SELECT col1 FROM b WHERE col3 > 100) AND col3 IN (SELECT col3 from b WHERE col3 < 100)", + "output": [ + "Execution Plan", + "\nLogicalJoin(condition=[=($2, $7)], joinType=[semi])", + "\n LogicalJoin(condition=[=($1, $7)], joinType=[semi])", + "\n LogicalJoin(condition=[=($1, $7)], joinType=[semi])", + "\n LogicalTableScan(table=[[a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col1=[$0])", + "\n LogicalFilter(condition=[AND(<>($1, 'bar'), <>($1, 'foo'))])", + "\n LogicalTableScan(table=[[a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col1=[$0])", + "\n LogicalFilter(condition=[>($2, 100)])", + "\n LogicalTableScan(table=[[b]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col3=[$2])", + "\n LogicalFilter(condition=[<($2, 100)])", + "\n LogicalTableScan(table=[[b]])", + "\n" + ] + }, + { + "description": "multiple SEMI JOINs with Agg", + "sql": "EXPLAIN PLAN FOR WITH tmp1 AS ( SELECT * FROM a WHERE col2 NOT IN ('foo', 'bar') ) SELECT COUNT(*) FROM a WHERE col2 IN (SELECT col1 FROM tmp1) AND col3 IN (SELECT col3 from b WHERE col3 < 100)", + "output": [ + "Execution Plan", + "\nLogicalAggregate(group=[{}], agg#0=[COUNT($0)])", + "\n PinotLogicalExchange(distribution=[hash])", + "\n LogicalAggregate(group=[{}], agg#0=[COUNT()])", + "\n LogicalJoin(condition=[=($0, $1)], joinType=[semi])", + "\n LogicalProject(col3=[$1])", + "\n LogicalJoin(condition=[=($0, $2)], joinType=[semi])", + "\n LogicalProject(col2=[$1], col3=[$2])", + "\n LogicalTableScan(table=[[a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col1=[$0])", + "\n LogicalFilter(condition=[AND(<>($1, 'bar'), <>($1, 'foo'))])", + "\n LogicalTableScan(table=[[a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col3=[$2])", + "\n LogicalFilter(condition=[<($2, 100)])", + "\n LogicalTableScan(table=[[b]])", + "\n" + ] + }, + { + "description": "multiple SEMI JOINs with group-by and having clause", + "sql": "EXPLAIN PLAN FOR WITH tmp1 AS ( SELECT * FROM a WHERE col2 NOT IN ('foo', 'bar') ) SELECT col1, SUM(col3) FROM a WHERE col2 IN (SELECT col1 FROM tmp1) AND col3 IN (SELECT col3 from b WHERE col3 < 100) GROUP BY col1 HAVING COUNT(*) > 10", + "output": [ + "Execution Plan", + "\nLogicalProject(col1=[$0], EXPR$1=[$1])", + "\n LogicalFilter(condition=[>($2, 10)])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])", + "\n PinotLogicalExchange(distribution=[hash[0]])", + "\n LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])", + "\n LogicalJoin(condition=[=($1, $2)], joinType=[semi])", + "\n LogicalProject(col1=[$0], col3=[$2])", + "\n LogicalJoin(condition=[=($1, $3)], joinType=[semi])", + "\n LogicalProject(col1=[$0], col2=[$1], col3=[$2])", + "\n LogicalTableScan(table=[[a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col1=[$0])", + "\n LogicalFilter(condition=[AND(<>($1, 'bar'), <>($1, 'foo'))])", + "\n LogicalTableScan(table=[[a]])", + "\n PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])", + "\n LogicalProject(col3=[$2])", + "\n LogicalFilter(condition=[<($2, 100)])", + "\n LogicalTableScan(table=[[b]])", + "\n" + ] + }, + { + "description": "reused tmp table in SEMI JOINs", "sql": "EXPLAIN PLAN FOR WITH tmp1 AS ( SELECT * FROM a WHERE col2 NOT IN ('foo', 'bar') ), tmp2 AS ( SELECT * FROM b WHERE col1 IN (SELECT col1 FROM tmp1) AND col3 < 100 ) SELECT * FROM tmp2 WHERE col3 IN (SELECT col3 from tmp1)", "output": [ "Execution Plan", diff --git a/pinot-query-runtime/src/test/resources/queries/WithStatements.json b/pinot-query-runtime/src/test/resources/queries/WithStatements.json index 0d8063e1c5f..61e0940a1c5 100644 --- a/pinot-query-runtime/src/test/resources/queries/WithStatements.json +++ b/pinot-query-runtime/src/test/resources/queries/WithStatements.json @@ -117,6 +117,27 @@ ["b", "bob", 196883] ] }, + { + "description": "multi 'with' table and semi-joins on same column multiple conditions and other column condition", + "sql": "WITH t1 AS ( SELECT * FROM {tbl1} WHERE intCol > 1 ) SELECT boolCol, strCol1, strCol2, intCol FROM {tbl2} WHERE strCol1 IN (SELECT strCol FROM t1) AND strCol2 IN (SELECT strCol2 FROM {tbl2} WHERE intCol > 1000) AND strCol1 IN (SELECT strCol1 FROM {tbl2} WHERE boolCol)", + "outputs": [ + [true, "b", "bob", 196883] + ] + }, + { + "description": "multi 'with' table and semi-joins and group-by with having", + "sql": "WITH t1 AS ( SELECT * FROM {tbl1} WHERE intCol > 1 ) SELECT strCol1, SUM(doubleCol) FROM {tbl2} WHERE strCol1 IN (SELECT strCol FROM t1) AND intCol IN (SELECT intCol FROM {tbl1} WHERE intCol < 100) GROUP BY strCol1 HAVING COUNT(*) < 10", + "outputs": [ + ["a", 275.12] + ] + }, + { + "description": "multi 'with' table and semi-joins and agg", + "sql": "WITH t1 AS ( SELECT * FROM {tbl1} WHERE intCol > 1 ) SELECT COUNT(*) FROM {tbl2} WHERE strCol1 IN (SELECT strCol FROM t1) AND intCol IN (SELECT intCol FROM {tbl1} WHERE intCol < 100)", + "outputs": [ + [2] + ] + }, { "description": "nested 'with' on agg table: (with a as ( ... ), select ... ", "sql": "WITH agg1 AS (SELECT strCol1, strCol2, sum(intCol) AS sumVal FROM {tbl2} GROUP BY strCol1, strCol2) SELECT strCol1, avg(sumVal) AS avgVal FROM agg1 GROUP BY strCol1",