Commit cf0c913

sync

dtenedor committed Sep 12, 2024
2 parents 61a0edb + 557bd0c
Showing 8 changed files with 100 additions and 98 deletions.
@@ -344,7 +344,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
extendedResolutionRules : _*),
Batch("Remove TempResolvedColumn", Once, RemoveTempResolvedColumn),
Batch("Post-Hoc Resolution", Once,
Seq(ResolveCommandsWithIfExists, RemovePipeOperators) ++
Seq(ResolveCommandsWithIfExists) ++
postHocResolutionRules: _*),
Batch("Remove Unresolved Hints", Once,
new ResolveHints.RemoveAllHints),
@@ -2727,12 +2727,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
t => t.containsAnyPattern(AGGREGATE_EXPRESSION, PYTHON_UDF) && t.containsPattern(PROJECT),
ruleId) {
case Project(projectList, child) if containsAggregates(projectList) =>
if (child.isInstanceOf[PipeOperatorSelect]) {
// If we used the pipe operator |> SELECT clause to specify an aggregate function, this is
// invalid; return an error message instructing the user to use the pipe operator
// |> AGGREGATE clause for this purpose instead.
throw QueryCompilationErrors.pipeOperatorSelectContainsAggregateFunction(projectList.head)
}
Aggregate(Nil, projectList, child)
}
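
The branch removed above used to raise the aggregate-in-|>-SELECT error from inside GlobalAggregates; this commit moves that validation into the new PipeSelect expression added below, leaving the rule with only the plain conversion of a Project containing aggregates into a global Aggregate. A minimal sketch of that remaining behavior, assuming a local Spark session (the demo object name is made up for illustration):

import org.apache.spark.sql.SparkSession

// Sketch: an aggregate in a SELECT list with no GROUP BY is analyzed into a
// global Aggregate with an empty grouping list, the Aggregate(Nil,
// projectList, child) shape that GlobalAggregates produces.
object GlobalAggregatesDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.range(3).createOrReplaceTempView("t")
    // The analyzed plan's root is an Aggregate, not a Project.
    println(spark.sql("SELECT sum(id) FROM t").queryExecution.analyzed.treeString)
    spark.stop()
  }
}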

@@ -2753,17 +2747,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* Removes placeholder PipeOperator* logical plan nodes and checks invariants.
*/
object RemovePipeOperators extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsPattern(PIPE_OPERATOR_SELECT), ruleId) {
case PipeOperatorSelect(child) =>
child
}
}

/**
* This rule finds aggregate expressions that are not in an aggregate operator. For example,
* those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.trees.TreePattern.{PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
* Represents a SELECT clause when used with the |> SQL pipe operator.
* We use this to make sure that no aggregate functions exist in the SELECT expressions.
*/
case class PipeSelect(child: Expression)
  extends UnaryExpression with RuntimeReplaceable {
  final override val nodePatterns: Seq[TreePattern] = Seq(PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE)
  override def withNewChildInternal(newChild: Expression): Expression = PipeSelect(newChild)
  override def replacement: Expression = {
    def visit(e: Expression): Unit = e match {
      case a: AggregateFunction =>
        // Using an aggregate function in the pipe operator |> SELECT clause is invalid;
        // throw an error instructing the user to use the pipe operator |> AGGREGATE
        // clause for this purpose instead.
        throw QueryCompilationErrors.pipeOperatorSelectContainsAggregateFunction(a)
      case _: WindowExpression =>
        // Window functions are allowed in pipe SELECT operators, so do not traverse into children.
      case _ =>
        e.children.foreach(visit)
    }
    visit(child)
    child
  }
}
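
The visit traversal above is the heart of the change: it rejects any aggregate function reachable in the expression tree, but deliberately stops at WindowExpression nodes so that windowed aggregates remain legal in |> SELECT. A self-contained sketch of the same logic over a toy expression tree (all names below are illustrative, not Catalyst's):

// Toy mirror of PipeSelect.replacement's visit(): throw on a bare aggregate,
// skip anything under a window expression, otherwise recurse into children.
object PipeSelectCheckDemo {
  sealed trait Expr { def children: Seq[Expr] }
  case class Lit(v: Int) extends Expr { val children: Seq[Expr] = Nil }
  case class Add(l: Expr, r: Expr) extends Expr { val children: Seq[Expr] = Seq(l, r) }
  case class Agg(name: String, child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }
  case class Win(fn: Expr) extends Expr { val children: Seq[Expr] = Seq(fn) }

  def check(e: Expr): Unit = e match {
    case Agg(name, _) =>
      throw new IllegalArgumentException(
        s"aggregate $name not allowed in |> SELECT; use |> AGGREGATE instead")
    case _: Win => // windowed aggregates are allowed; do not traverse further
    case other => other.children.foreach(check)
  }

  def main(args: Array[String]): Unit = {
    check(Add(Lit(1), Win(Agg("sum", Lit(2))))) // passes: the aggregate is under a window
    check(Agg("sum", Lit(1))) // throws the error above
  }
}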
@@ -5721,6 +5721,16 @@ class AstBuilder extends DataTypeAstBuilder
operationNotAllowed("Operator pipe SQL syntax using |>", ctx)
}
Option(ctx.selectClause).map { c =>
  def updateProject(p: Project): Project = {
    val newProjectList: Seq[NamedExpression] = p.projectList.map {
      case a: Alias =>
        a.withNewChildren(Seq(PipeSelect(a.child)))
          .asInstanceOf[NamedExpression]
      case other =>
        other
    }
    p.copy(projectList = newProjectList)
  }
withSelectQuerySpecification(
ctx = ctx,
selectClause = c,
@@ -5730,12 +5740,12 @@
havingClause = null,
windowClause = null,
left) match {
// The input should always be a projection since we only pass a context for the SELECT
// The input should generally be a projection since we only pass a context for the SELECT
// clause here and pass "null" for all other clauses.
case p: Project =>
p.copy(child = PipeOperatorSelect(p.child))
updateProject(p)
case d @ Distinct(p: Project) =>
d.copy(child = p.copy(child = PipeOperatorSelect(p.child)))
d.copy(child = updateProject(p))
case other =>
throw SparkException.internalError(s"Unrecognized matched logical plan: $other")
}
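
With updateProject in place, every aliased item in a |> SELECT list is wrapped in PipeSelect at parse time; because PipeSelect is RuntimeReplaceable, it validates the expression during analysis and then vanishes, leaving the alias child behind. A quick way to observe the wrapper, sketched under the assumption of a spark-shell session on a build with the pipe-operator syntax enabled:

// Hypothetical spark-shell session; table and column names are illustrative.
spark.sql("create table t(x int, y string) using csv")
// The analyzed plan shows the wrapper before it is replaced at runtime:
// Project [pipeselect((x + 1)) AS z], matching the golden files below.
println(spark.sql("table t |> select x + 1 as z").queryExecution.analyzed.treeString)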

This file was deleted.

@@ -47,7 +47,6 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractGenerator" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractWindowExpressions" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$GlobalAggregates" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$RemovePipeOperators" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggAliasInGroupBy" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases" ::
@@ -66,7 +66,7 @@ InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_d
table t
|> select 1 as x
-- !query analysis
Project [1 AS x#x]
Project [pipeselect(1) AS x#x]
+- SubqueryAlias spark_catalog.default.t
+- Relation spark_catalog.default.t[x#x,y#x] csv

@@ -85,7 +85,7 @@ table t
|> select x, y
|> select x + length(y) as z
-- !query analysis
Project [(x#x + length(y#x)) AS z#x]
Project [pipeselect((x#x + length(y#x))) AS z#x]
+- Project [x#x, y#x]
+- SubqueryAlias spark_catalog.default.t
+- Relation spark_catalog.default.t[x#x,y#x] csv
@@ -95,7 +95,7 @@ Project [(x#x + length(y#x)) AS z#x]
values (0), (1) tab(col)
|> select col * 2 as result
-- !query analysis
Project [(col#x * 2) AS result#x]
Project [pipeselect((col#x * 2)) AS result#x]
+- SubqueryAlias tab
+- LocalRelation [col#x]

@@ -104,7 +104,7 @@ Project [(col#x * 2) AS result#x]
(select * from t union all select * from t)
|> select x + length(y) as result
-- !query analysis
Project [(x#x + length(y#x)) AS result#x]
Project [pipeselect((x#x + length(y#x))) AS result#x]
+- Union false, false
:- Project [x#x, y#x]
: +- SubqueryAlias spark_catalog.default.t
@@ -155,7 +155,7 @@ Project [col#x.i1 AS i1#x]
table t
|> select (select a from other where x = a limit 1) as result
-- !query analysis
Project [scalar-subquery#x [x#x] AS result#x]
Project [pipeselect(scalar-subquery#x [x#x]) AS result#x]
: +- GlobalLimit 1
: +- LocalLimit 1
: +- Project [a#x]
@@ -170,7 +170,7 @@ Project [scalar-subquery#x [x#x] AS result#x]
table t
|> select (select any_value(a) from other where x = a limit 1) as result
-- !query analysis
Project [scalar-subquery#x [x#x] AS result#x]
Project [pipeselect(scalar-subquery#x [x#x]) AS result#x]
: +- GlobalLimit 1
: +- LocalLimit 1
: +- Aggregate [any_value(a#x, false) AS any_value(a)#x]
@@ -185,8 +185,8 @@ Project [scalar-subquery#x [x#x] AS result#x]
table t
|> select x + length(x) as z, z + 1 as plus_one
-- !query analysis
Project [z#x, (z#x + 1) AS plus_one#x]
+- Project [x#x, y#x, (x#x + length(cast(x#x as string))) AS z#x]
Project [z#x, pipeselect((z#x + 1)) AS plus_one#x]
+- Project [x#x, y#x, pipeselect((x#x + length(cast(x#x as string)))) AS z#x]
+- SubqueryAlias spark_catalog.default.t
+- Relation spark_catalog.default.t[x#x,y#x] csv

@@ -196,8 +196,8 @@ table t
|> select first_value(x) over (partition by y) as result
-- !query analysis
Project [result#x]
+- Project [x#x, y#x, result#x, result#x]
+- Window [first_value(x#x, false) windowspecdefinition(y#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS result#x], [y#x]
+- Project [x#x, y#x, _we0#x, pipeselect(_we0#x) AS result#x]
+- Window [first_value(x#x, false) windowspecdefinition(y#x, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#x], [y#x]
+- Project [x#x, y#x]
+- SubqueryAlias spark_catalog.default.t
+- Relation spark_catalog.default.t[x#x,y#x] csv
@@ -213,8 +213,8 @@ select 1 x, 2 y, 3 z
-- !query analysis
Project [a2#x]
+- Project [(1 + sum(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, x#x, a2#x]
+- Project [x#x, y#x, _w1#x, z#x, _we0#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, a2#x, (cast(1 as bigint) + _we0#xL) AS (1 + sum(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, a2#x]
+- Window [avg(_w1#x) windowspecdefinition(y#x, z#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS a2#x], [y#x], [z#x ASC NULLS FIRST]
+- Project [x#x, y#x, _w1#x, z#x, _we0#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, _we2#x, (cast(1 as bigint) + _we0#xL) AS (1 + sum(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))#xL, avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x, pipeselect(_we2#x) AS a2#x]
+- Window [avg(_w1#x) windowspecdefinition(y#x, z#x ASC NULLS FIRST, specifiedwindowframe(RangeFrame, unboundedpreceding$(), currentrow$())) AS _we2#x], [y#x], [z#x ASC NULLS FIRST]
+- Window [sum(x#x) windowspecdefinition(specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS _we0#xL, avg(y#x) windowspecdefinition(specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS avg(y) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#x]
+- Project [x#x, y#x, (x#x + 1) AS _w1#x, z#x]
+- Project [1 AS x#x, 2 AS y#x, 3 AS z#x]
@@ -246,23 +246,20 @@ Distinct


-- !query
<<<<<<< HEAD


table t
|> select sum(x) as result
-- !query analysis
org.apache.spark.sql.AnalysisException
org.apache.spark.sql.catalyst.parser.ParseException
{
"errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION",
"sqlState" : "0A000",
"errorClass" : "PARSE_SYNTAX_ERROR",
"sqlState" : "42601",
"messageParameters" : {
"expr" : "sum(x#x) AS result#xL"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 19,
"stopIndex" : 34,
"fragment" : "sum(x) as result"
} ]
"error" : "'<<'",
"hint" : ""
}
}


Expand All @@ -275,14 +272,14 @@ org.apache.spark.sql.AnalysisException
"errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION",
"sqlState" : "0A000",
"messageParameters" : {
"expr" : "y#x"
"expr" : "sum(x#x)"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 19,
"stopIndex" : 19,
"fragment" : "y"
"startIndex" : 34,
"stopIndex" : 39,
"fragment" : "sum(x)"
} ]
}

@@ -82,6 +82,7 @@ table t
table t
|> select distinct x, y;

<<<<<<< HEAD

-- SELECT operators: negative tests.
---------------------------------------
@@ -231,25 +231,22 @@ struct<x:int,y:string>


-- !query
<<<<<<< HEAD


table t
|> select sum(x) as result
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
org.apache.spark.sql.catalyst.parser.ParseException
{
"errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION",
"sqlState" : "0A000",
"errorClass" : "PARSE_SYNTAX_ERROR",
"sqlState" : "42601",
"messageParameters" : {
"expr" : "sum(x#x) AS result#xL"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 19,
"stopIndex" : 34,
"fragment" : "sum(x) as result"
} ]
"error" : "'<<'",
"hint" : ""
}
}


@@ -264,14 +261,14 @@ org.apache.spark.sql.AnalysisException
"errorClass" : "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION",
"sqlState" : "0A000",
"messageParameters" : {
"expr" : "y#x"
"expr" : "sum(x#x)"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 19,
"stopIndex" : 19,
"fragment" : "y"
"startIndex" : 34,
"stopIndex" : 39,
"fragment" : "sum(x)"
} ]
}

