Skip to content

Commit

Permalink
[SPARK-48655][SQL] SPJ: Add tests for shuffle skipping for aggregate …
Browse files Browse the repository at this point in the history
…queries

### What changes were proposed in this pull request?
  Add a unit test in KeyGroupedPartitioningSuite to verify that aggregation can also skip the shuffle if the grouping key matches the partition key

### Why are the changes needed?
  This lacked test coverage

### Does this PR introduce _any_ user-facing change?
  No

### How was this patch tested?
  Ran unit test

### Was this patch authored or co-authored using generative AI tooling?
  No

Closes apache#47015 from szehon-ho/spj_test.

Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Chao Sun <chao@openai.com>
  • Loading branch information
szehon-ho authored and sunchao committed Jun 21, 2024
1 parent 9414211 commit 7e5a461
Showing 1 changed file with 23 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50)))
}

/**
 * Collects every [[ShuffleExchangeExec]] node appearing anywhere in the given
 * physical plan, with no filtering on the shuffle's parent operator (unlike
 * `collectShuffles`, which only keeps shuffles feeding a sort-merge join).
 */
private def collectAllShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
  val shuffleExchange: PartialFunction[SparkPlan, ShuffleExchangeExec] = {
    case exchange: ShuffleExchangeExec => exchange
  }
  collect(plan)(shuffleExchange)
}

private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = {
// here we skip collecting shuffle operators that are not associated with SMJ
collect(plan) {
Expand Down Expand Up @@ -346,6 +352,23 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Column.create("price", FloatType),
Column.create("time", TimestampType))

test("SPARK-48655: group by on partition keys should not introduce additional shuffle") {
  // Partition the items table by identity(id) so the table's key-grouped
  // partitioning exactly matches the GROUP BY key below.
  val items_partitions = Array(identity("id"))
  createTable(items, itemsColumns, items_partitions)
  sql(s"INSERT INTO testcat.ns.$items VALUES " +
    s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    s"(1, 'aa', 41.0, cast('2020-01-02' as timestamp)), " +
    s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
    s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

  // Grouping by the partition key means each partition already holds all rows
  // for its group, so the aggregate should run without any exchange.
  val df = sql(s"SELECT MAX(price) AS res FROM testcat.ns.$items GROUP BY id")
  val shuffles = collectAllShuffles(df.queryExecution.executedPlan)
  // Fixed: the original failure message was inverted ("should contain shuffle
  // when not grouping by partition values") relative to the isEmpty check.
  assert(shuffles.isEmpty,
    "should not contain any shuffle when grouping by partition keys")

  // MAX(price) per id: id=1 -> 41.0, id=2 -> 10.0, id=3 -> 15.5.
  checkAnswer(df.sort("res"), Seq(Row(10.0), Row(15.5), Row(41.0)))
}

test("partitioned join: join with two partition keys and matching & sorted partitions") {
val items_partitions = Array(bucket(8, "id"), days("arrive_time"))
createTable(items, itemsColumns, items_partitions)
Expand Down

0 comments on commit 7e5a461

Please sign in to comment.