From 87298db43d9a33fa3a3986f274442a17aad74dc3 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 9 Aug 2023 10:27:07 -0700 Subject: [PATCH] [SPARK-44503][SQL] Project any PARTITION BY expressions not already returned from Python UDTF TABLE arguments ### What changes were proposed in this pull request? This PR adds a projection when any Python UDTF TABLE argument contains PARTITION BY expressions that are not simple attributes that are already present in the output of the relation. For example: ``` CREATE TABLE t(d DATE, y INT) USING PARQUET; INSERT INTO t VALUES ... SELECT * FROM UDTF(TABLE(t) PARTITION BY EXTRACT(YEAR FROM d) ORDER BY y ASC); ``` This will generate a plan like: ``` +- Sort (y ASC) +- RepartitionByExpressions (partition_by_0) +- Project (t.d, t.y, EXTRACT(YEAR FROM t.d) AS partition_by_0) +- LogicalRelation "t" ``` ### Why are the changes needed? We project the PARTITION BY expressions so that their resulting values appear in attributes that the Python UDTF interpreter can simply inspect in order to know when the partition boundaries have changed. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This PR adds unit test coverage. Closes #42351 from dtenedor/partition-by-execution. Authored-by: Daniel Tenedorio Signed-off-by: Takuya UESHIN --- ...ctionTableSubqueryArgumentExpression.scala | 77 +++++++++-- .../execution/python/PythonUDTFSuite.scala | 127 ++++++++++++++++-- 2 files changed, 184 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala index e7a4888125df1..daa0751eedf22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ -104,23 +104,80 @@ case class FunctionTableSubqueryArgumentExpression( // the query plan. var subquery = plan if (partitionByExpressions.nonEmpty) { - subquery = RepartitionByExpression( - partitionExpressions = partitionByExpressions, - child = subquery, - optNumPartitions = None) + // Add a projection to project each of the partitioning expressions that it is not a simple + // attribute that is already present in the plan output. Then add a sort operation by the + // partition keys (plus any explicit ORDER BY items) since after the hash-based shuffle + // operation, the rows from several partitions may arrive interleaved. In this way, the Python + // UDTF evaluator is able to inspect the values of the partitioning expressions for adjacent + // rows in order to determine when each partition ends and the next one begins. + subquery = Project( + projectList = subquery.output ++ extraProjectedPartitioningExpressions, + child = subquery) + val partitioningAttributes = partitioningExpressionIndexes.map(i => subquery.output(i)) + subquery = Sort( + order = partitioningAttributes.map(e => SortOrder(e, Ascending)) ++ orderByExpressions, + global = false, + child = RepartitionByExpression( + partitionExpressions = partitioningAttributes, + optNumPartitions = None, + child = subquery)) } if (withSinglePartition) { subquery = Repartition( numPartitions = 1, shuffle = true, child = subquery) - } - if (orderByExpressions.nonEmpty) { - subquery = Sort( - order = orderByExpressions, - global = false, - child = subquery) + if (orderByExpressions.nonEmpty) { + subquery = Sort( + order = orderByExpressions, + global = false, + child = subquery) + } } Project(Seq(Alias(CreateStruct(subquery.output), "c")()), subquery) } + + /** + * These are the indexes of the PARTITION BY expressions within the concatenation of the child's + * output attributes and the [[extraProjectedPartitioningExpressions]]. We send these indexes to + * the Python UDTF evaluator so it knows which expressions to compare on adjacent rows to know + * when the partition has changed. + */ + lazy val partitioningExpressionIndexes: Seq[Int] = { + val extraPartitionByExpressionsToIndexes: Map[Expression, Int] = + extraProjectedPartitioningExpressions.map(_.child).zipWithIndex.toMap + partitionByExpressions.map { e => + subqueryOutputs.get(e).getOrElse { + extraPartitionByExpressionsToIndexes.get(e).get + plan.output.length + } + } + } + + private lazy val extraProjectedPartitioningExpressions: Seq[Alias] = { + partitionByExpressions.filter { e => + !subqueryOutputs.contains(e) + }.zipWithIndex.map { case (expr, index) => + Alias(expr, s"partition_by_$index")() + } + } + + private lazy val subqueryOutputs: Map[Expression, Int] = plan.output.zipWithIndex.toMap } + +object FunctionTableSubqueryArgumentExpression { + /** + * Returns a sequence of zero-based integer indexes identifying the values of a Python UDTF's + * 'eval' method's *args list that correspond to partitioning columns of the input TABLE argument. + */ + def partitionChildIndexes(udtfArguments: Seq[Expression]): Seq[Int] = { + udtfArguments.zipWithIndex.flatMap { case (expr, index) => + expr match { + case f: FunctionTableSubqueryArgumentExpression => + f.partitioningExpressionIndexes.map(_ + index) + case _ => + Seq() + } + } + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index 8f1bf172bbdac..43f61a7c61e8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Repartition, RepartitionByExpression, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, FunctionTableSubqueryArgumentExpression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Repartition, RepartitionByExpression, Sort, SubqueryAlias} import org.apache.spark.sql.functions.lit import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -112,7 +113,9 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { test("SPARK-44503: Specify PARTITION BY and ORDER BY for TABLE arguments") { // Positive tests assume(shouldTestPythonUDFs) - def failure(plan: LogicalPlan): Unit = fail(s"Unexpected plan: $plan") + def failure(plan: LogicalPlan): Unit = { + fail(s"Unexpected plan: $plan") + } sql( """ |SELECT * FROM testUDTF( @@ -120,8 +123,12 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { | PARTITION BY X) |""".stripMargin).queryExecution.analyzed .collectFirst { case r: RepartitionByExpression => r }.get match { - case RepartitionByExpression(_, SubqueryAlias(_, _: LocalRelation), _, _) => - case other => failure(other) + case RepartitionByExpression( + _, Project( + _, SubqueryAlias( + _, _: LocalRelation)), _, _) => + case other => + failure(other) } sql( """ @@ -130,8 +137,11 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { | WITH SINGLE PARTITION) |""".stripMargin).queryExecution.analyzed .collectFirst { case r: Repartition => r }.get match { - case Repartition(1, true, SubqueryAlias(_, _: LocalRelation)) => - case other => failure(other) + case Repartition( + 1, true, SubqueryAlias( + _, _: LocalRelation)) => + case other => + failure(other) } sql( """ @@ -140,8 +150,13 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { | PARTITION BY SUBSTR(X, 2) ORDER BY (X, Y)) |""".stripMargin).queryExecution.analyzed .collectFirst { case r: Sort => r }.get match { - case Sort(_, false, RepartitionByExpression(_, SubqueryAlias(_, _: LocalRelation), _, _)) => - case other => failure(other) + case Sort( + _, false, RepartitionByExpression( + _, Project( + _, SubqueryAlias( + _, _: LocalRelation)), _, _)) => + case other => + failure(other) } sql( """ @@ -150,8 +165,12 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { | WITH SINGLE PARTITION ORDER BY (X, Y)) |""".stripMargin).queryExecution.analyzed .collectFirst { case r: Sort => r }.get match { - case Sort(_, false, Repartition(1, true, SubqueryAlias(_, _: LocalRelation))) => - case other => failure(other) + case Sort( + _, false, Repartition( + 1, true, SubqueryAlias( + _, _: LocalRelation))) => + case other => + failure(other) } // Negative tests withTable("t") { @@ -172,4 +191,92 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { stop = 30)) } } + + test("SPARK-44503: Compute partition child indexes for various UDTF argument lists") { + // Each of the following tests calls the PythonUDTF.partitionChildIndexes with a list of + // expressions and then checks the PARTITION BY child expression indexes that come out. + val projectList = Seq( + Alias(Literal(42), "a")(), + Alias(Literal(43), "b")()) + val projectTwoValues = Project( + projectList = projectList, + child = OneRowRelation()) + // There are no UDTF TABLE arguments, so there are no PARTITION BY child expression indexes. + val partitionChildIndexes = FunctionTableSubqueryArgumentExpression.partitionChildIndexes(_) + assert(partitionChildIndexes(Seq( + Literal(41))) == + Seq.empty[Int]) + assert(partitionChildIndexes(Seq( + Literal(41), + Literal("abc"))) == + Seq.empty[Int]) + // The UDTF TABLE argument has no PARTITION BY expressions, so there are no PARTITION BY child + // expression indexes. + assert(partitionChildIndexes(Seq( + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues))) == + Seq.empty[Int]) + // The UDTF TABLE argument has two PARTITION BY expressions which are equal to the output + // attributes from the provided relation, in order. Therefore the PARTITION BY child expression + // indexes are 0 and 1. + assert(partitionChildIndexes(Seq( + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues, + partitionByExpressions = projectTwoValues.output))) == + Seq(0, 1)) + // The UDTF TABLE argument has one PARTITION BY expression which is equal to the first output + // attribute from the provided relation. Therefore the PARTITION BY child expression index is 0. + assert(partitionChildIndexes(Seq( + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues, + partitionByExpressions = Seq(projectList.head.toAttribute)))) == + Seq(0)) + // The UDTF TABLE argument has one PARTITION BY expression which is equal to the second output + // attribute from the provided relation. Therefore the PARTITION BY child expression index is 1. + assert(partitionChildIndexes(Seq( + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues, + partitionByExpressions = Seq(projectList.last.toAttribute)))) == + Seq(1)) + // The UDTF has one scalar argument, then one TABLE argument, then another scalar argument. The + // TABLE argument has two PARTITION BY expressions which are equal to the output attributes from + // the provided relation, in order. Therefore the PARTITION BY child expression indexes are 1 + // and 2, because they begin at an offset of 1 from the zero-based start of the list of values + // provided to the UDTF 'eval' method. + assert(partitionChildIndexes(Seq( + Literal(41), + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues, + partitionByExpressions = projectTwoValues.output), + Literal("abc"))) == + Seq(1, 2)) + // Same as above, but the PARTITION BY expressions are new expressions which must be projected + // after all the attributes from the relation provided to the UDTF TABLE argument. Therefore the + // PARTITION BY child indexes are 3 and 4 because they begin at an offset of 3 from the + // zero-based start of the list of values provided to the UDTF 'eval' method. + assert(partitionChildIndexes(Seq( + Literal(41), + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues, + partitionByExpressions = Seq(Literal(42), Literal(43))), + Literal("abc"))) == + Seq(3, 4)) + // Same as above, but the PARTITION BY list comprises just one addition expression. + assert(partitionChildIndexes(Seq( + Literal(41), + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues, + partitionByExpressions = Seq(Add(projectList.head.toAttribute, Literal(1)))), + Literal("abc"))) == + Seq(3)) + // Same as above, but the PARTITION BY list comprises one literal value and one addition + // expression. + assert(partitionChildIndexes(Seq( + Literal(41), + FunctionTableSubqueryArgumentExpression( + plan = projectTwoValues, + partitionByExpressions = Seq(Literal(42), Add(projectList.head.toAttribute, Literal(1)))), + Literal("abc"))) == + Seq(3, 4)) + } }