Skip to content

Commit

Permalink
[SPARK-44503][SQL] Project any PARTITION BY expressions not already r…
Browse files Browse the repository at this point in the history
…eturned from Python UDTF TABLE arguments

### What changes were proposed in this pull request?

This PR adds a projection when any Python UDTF TABLE argument contains PARTITION BY expressions that are not simple attributes that are already present in the output of the relation.

For example:

```
CREATE TABLE t(d DATE, y INT) USING PARQUET;
INSERT INTO t VALUES ...
SELECT * FROM UDTF(TABLE(t) PARTITION BY EXTRACT(YEAR FROM d) ORDER BY y ASC);
```

This will generate a plan like:

```
+- Sort (y ASC)
  +- RepartitionByExpressions (partition_by_0)
    +- Project (t.d, t.y, EXTRACT(YEAR FROM t.d) AS partition_by_0)
      +- LogicalRelation "t"
```

### Why are the changes needed?

We project the PARTITION BY expressions so that their resulting values appear in attributes that the Python UDTF interpreter can simply inspect in order to know when the partition boundaries have changed.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR adds unit test coverage.

Closes apache#42351 from dtenedor/partition-by-execution.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
  • Loading branch information
dtenedor authored and ueshin committed Aug 9, 2023
1 parent be9ffb3 commit 87298db
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,80 @@ case class FunctionTableSubqueryArgumentExpression(
// the query plan.
var subquery = plan
if (partitionByExpressions.nonEmpty) {
subquery = RepartitionByExpression(
partitionExpressions = partitionByExpressions,
child = subquery,
optNumPartitions = None)
// Add a projection to project each of the partitioning expressions that it is not a simple
// attribute that is already present in the plan output. Then add a sort operation by the
// partition keys (plus any explicit ORDER BY items) since after the hash-based shuffle
// operation, the rows from several partitions may arrive interleaved. In this way, the Python
// UDTF evaluator is able to inspect the values of the partitioning expressions for adjacent
// rows in order to determine when each partition ends and the next one begins.
subquery = Project(
projectList = subquery.output ++ extraProjectedPartitioningExpressions,
child = subquery)
val partitioningAttributes = partitioningExpressionIndexes.map(i => subquery.output(i))
subquery = Sort(
order = partitioningAttributes.map(e => SortOrder(e, Ascending)) ++ orderByExpressions,
global = false,
child = RepartitionByExpression(
partitionExpressions = partitioningAttributes,
optNumPartitions = None,
child = subquery))
}
if (withSinglePartition) {
subquery = Repartition(
numPartitions = 1,
shuffle = true,
child = subquery)
}
if (orderByExpressions.nonEmpty) {
subquery = Sort(
order = orderByExpressions,
global = false,
child = subquery)
if (orderByExpressions.nonEmpty) {
subquery = Sort(
order = orderByExpressions,
global = false,
child = subquery)
}
}
Project(Seq(Alias(CreateStruct(subquery.output), "c")()), subquery)
}

/**
* These are the indexes of the PARTITION BY expressions within the concatenation of the child's
* output attributes and the [[extraProjectedPartitioningExpressions]]. We send these indexes to
* the Python UDTF evaluator so it knows which expressions to compare on adjacent rows to know
* when the partition has changed.
*/
lazy val partitioningExpressionIndexes: Seq[Int] = {
val extraPartitionByExpressionsToIndexes: Map[Expression, Int] =
extraProjectedPartitioningExpressions.map(_.child).zipWithIndex.toMap
partitionByExpressions.map { e =>
subqueryOutputs.get(e).getOrElse {
extraPartitionByExpressionsToIndexes.get(e).get + plan.output.length
}
}
}

private lazy val extraProjectedPartitioningExpressions: Seq[Alias] = {
partitionByExpressions.filter { e =>
!subqueryOutputs.contains(e)
}.zipWithIndex.map { case (expr, index) =>
Alias(expr, s"partition_by_$index")()
}
}

private lazy val subqueryOutputs: Map[Expression, Int] = plan.output.zipWithIndex.toMap
}

object FunctionTableSubqueryArgumentExpression {
/**
* Returns a sequence of zero-based integer indexes identifying the values of a Python UDTF's
* 'eval' method's *args list that correspond to partitioning columns of the input TABLE argument.
*/
def partitionChildIndexes(udtfArguments: Seq[Expression]): Seq[Int] = {
udtfArguments.zipWithIndex.flatMap { case (expr, index) =>
expr match {
case f: FunctionTableSubqueryArgumentExpression =>
f.partitioningExpressionIndexes.map(_ + index)
case _ =>
Seq()
}
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.python

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Repartition, RepartitionByExpression, Sort, SubqueryAlias}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, FunctionTableSubqueryArgumentExpression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -112,16 +113,22 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
test("SPARK-44503: Specify PARTITION BY and ORDER BY for TABLE arguments") {
// Positive tests
assume(shouldTestPythonUDFs)
def failure(plan: LogicalPlan): Unit = fail(s"Unexpected plan: $plan")
def failure(plan: LogicalPlan): Unit = {
fail(s"Unexpected plan: $plan")
}
sql(
"""
|SELECT * FROM testUDTF(
| TABLE(VALUES (1), (1) AS tab(x))
| PARTITION BY X)
|""".stripMargin).queryExecution.analyzed
.collectFirst { case r: RepartitionByExpression => r }.get match {
case RepartitionByExpression(_, SubqueryAlias(_, _: LocalRelation), _, _) =>
case other => failure(other)
case RepartitionByExpression(
_, Project(
_, SubqueryAlias(
_, _: LocalRelation)), _, _) =>
case other =>
failure(other)
}
sql(
"""
Expand All @@ -130,8 +137,11 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
| WITH SINGLE PARTITION)
|""".stripMargin).queryExecution.analyzed
.collectFirst { case r: Repartition => r }.get match {
case Repartition(1, true, SubqueryAlias(_, _: LocalRelation)) =>
case other => failure(other)
case Repartition(
1, true, SubqueryAlias(
_, _: LocalRelation)) =>
case other =>
failure(other)
}
sql(
"""
Expand All @@ -140,8 +150,13 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
| PARTITION BY SUBSTR(X, 2) ORDER BY (X, Y))
|""".stripMargin).queryExecution.analyzed
.collectFirst { case r: Sort => r }.get match {
case Sort(_, false, RepartitionByExpression(_, SubqueryAlias(_, _: LocalRelation), _, _)) =>
case other => failure(other)
case Sort(
_, false, RepartitionByExpression(
_, Project(
_, SubqueryAlias(
_, _: LocalRelation)), _, _)) =>
case other =>
failure(other)
}
sql(
"""
Expand All @@ -150,8 +165,12 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
| WITH SINGLE PARTITION ORDER BY (X, Y))
|""".stripMargin).queryExecution.analyzed
.collectFirst { case r: Sort => r }.get match {
case Sort(_, false, Repartition(1, true, SubqueryAlias(_, _: LocalRelation))) =>
case other => failure(other)
case Sort(
_, false, Repartition(
1, true, SubqueryAlias(
_, _: LocalRelation))) =>
case other =>
failure(other)
}
// Negative tests
withTable("t") {
Expand All @@ -172,4 +191,92 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
stop = 30))
}
}

test("SPARK-44503: Compute partition child indexes for various UDTF argument lists") {
// Each of the following tests calls the PythonUDTF.partitionChildIndexes with a list of
// expressions and then checks the PARTITION BY child expression indexes that come out.
val projectList = Seq(
Alias(Literal(42), "a")(),
Alias(Literal(43), "b")())
val projectTwoValues = Project(
projectList = projectList,
child = OneRowRelation())
// There are no UDTF TABLE arguments, so there are no PARTITION BY child expression indexes.
val partitionChildIndexes = FunctionTableSubqueryArgumentExpression.partitionChildIndexes(_)
assert(partitionChildIndexes(Seq(
Literal(41))) ==
Seq.empty[Int])
assert(partitionChildIndexes(Seq(
Literal(41),
Literal("abc"))) ==
Seq.empty[Int])
// The UDTF TABLE argument has no PARTITION BY expressions, so there are no PARTITION BY child
// expression indexes.
assert(partitionChildIndexes(Seq(
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues))) ==
Seq.empty[Int])
// The UDTF TABLE argument has two PARTITION BY expressions which are equal to the output
// attributes from the provided relation, in order. Therefore the PARTITION BY child expression
// indexes are 0 and 1.
assert(partitionChildIndexes(Seq(
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues,
partitionByExpressions = projectTwoValues.output))) ==
Seq(0, 1))
// The UDTF TABLE argument has one PARTITION BY expression which is equal to the first output
// attribute from the provided relation. Therefore the PARTITION BY child expression index is 0.
assert(partitionChildIndexes(Seq(
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues,
partitionByExpressions = Seq(projectList.head.toAttribute)))) ==
Seq(0))
// The UDTF TABLE argument has one PARTITION BY expression which is equal to the second output
// attribute from the provided relation. Therefore the PARTITION BY child expression index is 1.
assert(partitionChildIndexes(Seq(
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues,
partitionByExpressions = Seq(projectList.last.toAttribute)))) ==
Seq(1))
// The UDTF has one scalar argument, then one TABLE argument, then another scalar argument. The
// TABLE argument has two PARTITION BY expressions which are equal to the output attributes from
// the provided relation, in order. Therefore the PARTITION BY child expression indexes are 1
// and 2, because they begin at an offset of 1 from the zero-based start of the list of values
// provided to the UDTF 'eval' method.
assert(partitionChildIndexes(Seq(
Literal(41),
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues,
partitionByExpressions = projectTwoValues.output),
Literal("abc"))) ==
Seq(1, 2))
// Same as above, but the PARTITION BY expressions are new expressions which must be projected
// after all the attributes from the relation provided to the UDTF TABLE argument. Therefore the
// PARTITION BY child indexes are 3 and 4 because they begin at an offset of 3 from the
// zero-based start of the list of values provided to the UDTF 'eval' method.
assert(partitionChildIndexes(Seq(
Literal(41),
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues,
partitionByExpressions = Seq(Literal(42), Literal(43))),
Literal("abc"))) ==
Seq(3, 4))
// Same as above, but the PARTITION BY list comprises just one addition expression.
assert(partitionChildIndexes(Seq(
Literal(41),
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues,
partitionByExpressions = Seq(Add(projectList.head.toAttribute, Literal(1)))),
Literal("abc"))) ==
Seq(3))
// Same as above, but the PARTITION BY list comprises one literal value and one addition
// expression.
assert(partitionChildIndexes(Seq(
Literal(41),
FunctionTableSubqueryArgumentExpression(
plan = projectTwoValues,
partitionByExpressions = Seq(Literal(42), Add(projectList.head.toAttribute, Literal(1)))),
Literal("abc"))) ==
Seq(3, 4))
}
}

0 comments on commit 87298db

Please sign in to comment.