feat: CometNativeWriteExec support with native scan as a child #2839
Changes from all commits
The previous single `test("basic parquet write")`, which was guarded by `assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)` with the comment "no support for fully native scan as input yet", is refactored into shared helpers, and a second test now covers a fully native scan as the child of the write. The suite after this change:

```scala
import java.io.File

import scala.util.Random

import org.apache.spark.sql.{CometTestBase, DataFrame}
import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOptions}

class CometParquetWriterSuite extends CometTestBase {

  private def createTestData(inputDir: File): String = {
    val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
    val schema = FuzzDataGenerator.generateSchema(
      SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
    val df = FuzzDataGenerator.generateDataFrame(
      new Random(42),
      spark,
      schema,
      1000,
      DataGenOptions(generateNegativeZero = false))
    withSQLConf(
      CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
      df.write.parquet(inputPath)
    }
    inputPath
  }

  private def writeWithCometNativeWriteExec(
      inputPath: String,
      outputPath: String): Option[QueryExecution] = {
    val df = spark.read.parquet(inputPath)

    // Use a listener to capture the execution plan during write
    var capturedPlan: Option[QueryExecution] = None

    val listener = new org.apache.spark.sql.util.QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
        // Capture plans from write operations
        if (funcName == "save" || funcName.contains("command")) {
          capturedPlan = Some(qe)
        }
      }

      override def onFailure(
          funcName: String,
          qe: QueryExecution,
          exception: Exception): Unit = {}
    }

    spark.listenerManager.register(listener)

    try {
      // Perform native write
      df.write.parquet(outputPath)

      // Wait for listener to be called with timeout
      val maxWaitTimeMs = 15000
      val checkIntervalMs = 100
      val maxIterations = maxWaitTimeMs / checkIntervalMs
      var iterations = 0

      while (capturedPlan.isEmpty && iterations < maxIterations) {
        Thread.sleep(checkIntervalMs)
        iterations += 1
      }

      // Verify that CometNativeWriteExec was used
      assert(
        capturedPlan.isDefined,
        s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")

      capturedPlan.foreach { qe =>
        val executedPlan = qe.executedPlan
        val hasNativeWrite = executedPlan.exists {
          case _: CometNativeWriteExec => true
          case d: DataWritingCommandExec =>
            d.child.exists {
              case _: CometNativeWriteExec => true
              case _ => false
            }
          case _ => false
        }

        assert(
          hasNativeWrite,
          s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
      }
    } finally {
      spark.listenerManager.unregister(listener)
    }
    capturedPlan
  }

  private def verifyWrittenFile(outputPath: String): Unit = {
    // Verify the data was written correctly
    val resultDf = spark.read.parquet(outputPath)
    assert(resultDf.count() == 1000, "Expected 1000 rows to be written")

    // Verify multiple part files were created
    val outputDir = new File(outputPath)
    val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
    // With 1000 rows and default parallelism, we should get multiple partitions
    assert(partFiles.length > 1, "Expected multiple part files to be created")
```
Contributor: should we check the exact number of partitions? Example: if you write a df hash-partitioned by 50, we should have 50 files.

Author: I just moved that logic. Since this is a pretty early proof-of-concept feature from @andygrove, I'm not too inclined to change test behavior in this PR.
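For reference, a hypothetical tightened check along the lines of this comment: force a known partition count on the input before writing, then assert that exact count of part files. This sketch assumes `inputPath` and `outputPath` as in the helpers above, and that every partition is non-empty so each writes exactly one file:

```scala
// Hypothetical exact-count variant of the part-file check: repartition the
// input to a known partition count before writing it out.
val df = spark.read.parquet(inputPath).repartition(50)
df.write.parquet(outputPath)

val partFiles = new File(outputPath)
  .listFiles()
  .filter(_.getName.startsWith("part-"))
// Each non-empty partition is expected to produce exactly one part file.
assert(partFiles.length == 50, s"Expected 50 part files, found ${partFiles.length}")
```

With 1000 fuzz-generated rows and default parallelism, the existing `partFiles.length > 1` assertion remains the weaker but less brittle check.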
```scala
    // (verifyWrittenFile, continued)
    // read with and without Comet and compare
    var sparkDf: DataFrame = null
    var cometDf: DataFrame = null
    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
      sparkDf = spark.read.parquet(outputPath)
    }
    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
      cometDf = spark.read.parquet(outputPath)
    }
    checkAnswer(sparkDf, cometDf)
  }

  test("basic parquet write") {
    withTempPath { dir =>
      val outputPath = new File(dir, "output.parquet").getAbsolutePath

      // Create test data and write it to a temp parquet file first
      withTempPath { inputDir =>
        val inputPath = createTestData(inputDir)

        withSQLConf(
          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
          CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
          CometConf.COMET_EXEC_ENABLED.key -> "true") {
          writeWithCometNativeWriteExec(inputPath, outputPath)
          verifyWrittenFile(outputPath)
        }
      }
    }
  }

  test("basic parquet write with native scan child") {
    withTempPath { dir =>
      val outputPath = new File(dir, "output.parquet").getAbsolutePath

      // Create test data and write it to a temp parquet file first
      withTempPath { inputDir =>
        val inputPath = createTestData(inputDir)

        withSQLConf(
          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
          CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
          CometConf.COMET_EXEC_ENABLED.key -> "true") {

          withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
            val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
            capturedPlan.foreach { qe =>
              val executedPlan = qe.executedPlan
              val hasNativeScan = executedPlan.exists {
                case _: CometNativeScanExec => true
                case _ => false
              }

              assert(
                hasNativeScan,
                s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
            }

            verifyWrittenFile(outputPath)
          }
        }
      }
    }
  }
}
```
Contributor: nit: use `sparkContext.listenerBus.waitUntilEmpty()` or `org.scalatest.concurrent.Eventually#eventually`.

Author: Thanks, I'll update the tests soon to use this approach.
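The suggestion presumably targets the sleep-and-poll wait in `writeWithCometNativeWriteExec`. A minimal sketch of the `eventually` variant, assuming the suite imports ScalaTest's `Eventually` (`capturedPlan` is the variable set by the `QueryExecutionListener` above):

```scala
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.{Seconds, Span}

// Kick off the write, then retry the check until the listener has fired or
// the 15 second timeout elapses, instead of a hand-rolled Thread.sleep loop.
df.write.parquet(outputPath)
eventually(timeout(Span(15, Seconds))) {
  assert(capturedPlan.isDefined, "no execution plan captured yet")
}
```

The `sparkContext.listenerBus.waitUntilEmpty()` alternative blocks until all queued listener events have been delivered; note that `listenerBus` is `private[spark]`, so it is directly reachable only from code compiled under an `org.apache.spark` package.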