diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 124188b64d..b9dbd56464 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -551,6 +551,17 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { firstNativeOp = true } + // CometNativeWriteExec is special: it has two separate plans: + // 1. A protobuf plan (nativeOp) describing the write operation + // 2. A Spark plan (child) that produces the data to write + // The serializedPlanOpt is a def that always returns Some(...) by serializing + // nativeOp on-demand, so it doesn't need convertBlock(). However, its child + // (e.g., CometNativeScanExec) may need its own serialization. Reset the flag + // so children can start their own native execution blocks. + if (op.isInstanceOf[CometNativeWriteExec]) { + firstNativeOp = true + } + newPlan case op => firstNativeOp = true diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala index e4b8b53856..2ea697fd4d 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala @@ -24,7 +24,7 @@ import java.io.File import scala.util.Random import org.apache.spark.sql.{CometTestBase, DataFrame} -import org.apache.spark.sql.comet.CometNativeWriteExec +import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.internal.SQLConf @@ -34,122 +34,167 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOpt class CometParquetWriterSuite extends CometTestBase { - test("basic parquet write") { - // no support for fully native scan as input yet - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) + private def createTestData(inputDir: File): String = { + val inputPath = new File(inputDir, "input.parquet").getAbsolutePath + val schema = FuzzDataGenerator.generateSchema( + SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false)) + val df = FuzzDataGenerator.generateDataFrame( + new Random(42), + spark, + schema, + 1000, + DataGenOptions(generateNegativeZero = false)) + withSQLConf( + CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") { + df.write.parquet(inputPath) + } + inputPath + } + + private def writeWithCometNativeWriteExec( + inputPath: String, + outputPath: String): Option[QueryExecution] = { + val df = spark.read.parquet(inputPath) + + // Use a listener to capture the execution plan during write + var capturedPlan: Option[QueryExecution] = None + + val listener = new org.apache.spark.sql.util.QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + // Capture plans from write operations + if (funcName == "save" || funcName.contains("command")) { + capturedPlan = Some(qe) + } + } + + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception): Unit = {} + } + + spark.listenerManager.register(listener) + + try { + // Perform native write + df.write.parquet(outputPath) + + // 
Wait for listener to be called with timeout + val maxWaitTimeMs = 15000 + val checkIntervalMs = 100 + val maxIterations = maxWaitTimeMs / checkIntervalMs + var iterations = 0 + + while (capturedPlan.isEmpty && iterations < maxIterations) { + Thread.sleep(checkIntervalMs) + iterations += 1 + } + + // Verify that CometNativeWriteExec was used + assert( + capturedPlan.isDefined, + s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured") + + capturedPlan.foreach { qe => + val executedPlan = qe.executedPlan + val hasNativeWrite = executedPlan.exists { + case _: CometNativeWriteExec => true + case d: DataWritingCommandExec => + d.child.exists { + case _: CometNativeWriteExec => true + case _ => false + } + case _ => false + } + + assert( + hasNativeWrite, + s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}") + } + } finally { + spark.listenerManager.unregister(listener) + } + capturedPlan + } + + private def verifyWrittenFile(outputPath: String): Unit = { + // Verify the data was written correctly + val resultDf = spark.read.parquet(outputPath) + assert(resultDf.count() == 1000, "Expected 1000 rows to be written") + + // Verify multiple part files were created + val outputDir = new File(outputPath) + val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-")) + // With 1000 rows and default parallelism, we should get multiple partitions + assert(partFiles.length > 1, "Expected multiple part files to be created") + + // read with and without Comet and compare + var sparkDf: DataFrame = null + var cometDf: DataFrame = null + withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") { + sparkDf = spark.read.parquet(outputPath) + } + withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") { + cometDf = spark.read.parquet(outputPath) + } + checkAnswer(sparkDf, cometDf) + } + test("basic parquet write") { withTempPath { dir => val outputPath = new File(dir, "output.parquet").getAbsolutePath // Create test data and write it to a temp parquet file first withTempPath { inputDir => - val inputPath = new File(inputDir, "input.parquet").getAbsolutePath - val schema = FuzzDataGenerator.generateSchema( - SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false)) - val df = FuzzDataGenerator.generateDataFrame( - new Random(42), - spark, - schema, - 1000, - DataGenOptions(generateNegativeZero = false)) - withSQLConf( - CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false", - SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") { - df.write.parquet(inputPath) - } + val inputPath = createTestData(inputDir) withSQLConf( CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true", SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax", CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true") { - val df = spark.read.parquet(inputPath) - - // Use a listener to capture the execution plan during write - var capturedPlan: Option[QueryExecution] = None - - val listener = new org.apache.spark.sql.util.QueryExecutionListener { - override def onSuccess( - funcName: String, - qe: QueryExecution, - durationNs: Long): Unit = { - // Capture plans from write operations - if (funcName == "save" || funcName.contains("command")) { - capturedPlan = Some(qe) - } - } - override def onFailure( - funcName: String, - qe: QueryExecution, - exception: Exception): Unit = {} - } + writeWithCometNativeWriteExec(inputPath, outputPath) - 
spark.listenerManager.register(listener) - - try { - // Perform native write - df.write.parquet(outputPath) + verifyWrittenFile(outputPath) + } + } + } + } - // Wait for listener to be called with timeout - val maxWaitTimeMs = 15000 - val checkIntervalMs = 100 - val maxIterations = maxWaitTimeMs / checkIntervalMs - var iterations = 0 + test("basic parquet write with native scan child") { + withTempPath { dir => + val outputPath = new File(dir, "output.parquet").getAbsolutePath - while (capturedPlan.isEmpty && iterations < maxIterations) { - Thread.sleep(checkIntervalMs) - iterations += 1 - } + // Create test data and write it to a temp parquet file first + withTempPath { inputDir => + val inputPath = createTestData(inputDir) - // Verify that CometNativeWriteExec was used - assert( - capturedPlan.isDefined, - s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured") + withSQLConf( + CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax", + CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true") { + withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") { + val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath) capturedPlan.foreach { qe => val executedPlan = qe.executedPlan - val hasNativeWrite = executedPlan.exists { - case _: CometNativeWriteExec => true - case d: DataWritingCommandExec => - d.child.exists { - case _: CometNativeWriteExec => true - case _ => false - } + val hasNativeScan = executedPlan.exists { + case _: CometNativeScanExec => true case _ => false } assert( - hasNativeWrite, - s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}") + hasNativeScan, + s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}") } - } finally { - spark.listenerManager.unregister(listener) - } - // Verify the data was written correctly - val resultDf = spark.read.parquet(outputPath) - assert(resultDf.count() == 1000, "Expected 1000 rows to be written") - - // Verify multiple part files were created - val outputDir = new File(outputPath) - val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-")) - // With 1000 rows and default parallelism, we should get multiple partitions - assert(partFiles.length > 1, "Expected multiple part files to be created") - - // read with and without Comet and compare - var sparkDf: DataFrame = null - var cometDf: DataFrame = null - withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") { - sparkDf = spark.read.parquet(outputPath) - } - withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") { - cometDf = spark.read.parquet(outputPath) + verifyWrittenFile(outputPath) } - checkAnswer(sparkDf, cometDf) } } } } - }
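
Reviewer note, not part of the patch: the new test covers the case where CometNativeWriteExec sits on top of a fully native DataFusion scan, which is why CometExecRule now resets firstNativeOp for the write node so the child can open its own serialized native block. Below is a minimal sketch of the configuration that drives this path outside the test harness; it assumes a SparkSession `spark` with Comet already enabled, and the input/output paths are illustrative only.

    import org.apache.spark.sql.execution.command.DataWritingCommandExec
    import org.apache.comet.CometConf

    // Illustrative paths only.
    val inputPath = "/tmp/input.parquet"
    val outputPath = "/tmp/output.parquet"

    // Enable Comet execution and the native Parquet writer; DataWritingCommandExec
    // is still gated behind the allow-incompat flag, so opt in explicitly.
    spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
    spark.conf.set(CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key, "true")
    spark.conf.set(
      CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]),
      "true")

    // Select the fully native DataFusion scan so the write's child is a
    // CometNativeScanExec, exercising the firstNativeOp reset added in CometExecRule.
    spark.conf.set(CometConf.COMET_NATIVE_SCAN_IMPL.key, "native_datafusion")

    // Read with the native scan and write through CometNativeWriteExec.
    spark.read.parquet(inputPath).write.parquet(outputPath)

This mirrors the settings used in "basic parquet write with native scan child"; the test additionally verifies via a QueryExecutionListener that both CometNativeWriteExec and CometNativeScanExec appear in the executed plan.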