Merged
11 changes: 11 additions & 0 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -551,6 +551,17 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
firstNativeOp = true
}

// CometNativeWriteExec is special: it has two separate plans:
// 1. A protobuf plan (nativeOp) describing the write operation
// 2. A Spark plan (child) that produces the data to write
// The serializedPlanOpt is a def that always returns Some(...) by serializing
// nativeOp on-demand, so it doesn't need convertBlock(). However, its child
// (e.g., CometNativeScanExec) may need its own serialization. Reset the flag
// so children can start their own native execution blocks.
if (op.isInstanceOf[CometNativeWriteExec]) {
firstNativeOp = true
}

newPlan
case op =>
firstNativeOp = true
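The comment above describes why the rule resets firstNativeOp for a write node. As a rough, hypothetical illustration (these are not Comet's real classes; nativeOp is modelled as raw bytes for brevity): an operator whose serialized plan is injected by convertBlock() exposes it as an Option that may be empty, whereas the write operator can always produce its serialized plan on demand from nativeOp, so only its child subtree needs a fresh native block.

// Hypothetical shapes for illustration only; not Comet's actual API.
trait BlockOperator {
  // For most native operators this is populated by convertBlock() when the
  // operator is the first one in a contiguous native block.
  def serializedPlanOpt: Option[Array[Byte]]
}

case class HypotheticalNativeWrite(nativeOp: Array[Byte]) extends BlockOperator {
  // A def, not a stored field: always defined, derived from nativeOp on demand,
  // so convertBlock() never needs to run for the write operator itself.
  override def serializedPlanOpt: Option[Array[Byte]] = Some(nativeOp)
}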
CometParquetWriterSuite.scala
@@ -24,7 +24,7 @@ import java.io.File
import scala.util.Random

import org.apache.spark.sql.{CometTestBase, DataFrame}
import org.apache.spark.sql.comet.CometNativeWriteExec
import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.internal.SQLConf
@@ -34,122 +34,167 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOpt

class CometParquetWriterSuite extends CometTestBase {

test("basic parquet write") {
// no support for fully native scan as input yet
assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
private def createTestData(inputDir: File): String = {
val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
val schema = FuzzDataGenerator.generateSchema(
SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
val df = FuzzDataGenerator.generateDataFrame(
new Random(42),
spark,
schema,
1000,
DataGenOptions(generateNegativeZero = false))
withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
df.write.parquet(inputPath)
}
inputPath
}

private def writeWithCometNativeWriteExec(
inputPath: String,
outputPath: String): Option[QueryExecution] = {
val df = spark.read.parquet(inputPath)

// Use a listener to capture the execution plan during write
var capturedPlan: Option[QueryExecution] = None

val listener = new org.apache.spark.sql.util.QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
// Capture plans from write operations
if (funcName == "save" || funcName.contains("command")) {
capturedPlan = Some(qe)
}
}

override def onFailure(
funcName: String,
qe: QueryExecution,
exception: Exception): Unit = {}
}

spark.listenerManager.register(listener)

try {
// Perform native write
df.write.parquet(outputPath)

// Wait for listener to be called with timeout
@wForget (Member) commented on Dec 4, 2025:

nit: use sparkContext.listenerBus.waitUntilEmpty() or org.scalatest.concurrent.Eventually#eventually

Member replied:

Thanks, I'll update the tests soon to use this approach
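A minimal sketch of the reviewer's suggestion using ScalaTest's Eventually (illustrative only; it assumes the surrounding test's capturedPlan variable and that scalatest is on the test classpath, and it is not the actual follow-up change):

import org.scalatest.concurrent.Eventually._
import org.scalatest.time.{Millis, Seconds, Span}

// Retry the assertion every 100 ms until it passes or 15 s elapse,
// replacing the hand-rolled Thread.sleep polling loop below.
eventually(timeout(Span(15, Seconds)), interval(Span(100, Millis))) {
  assert(capturedPlan.isDefined, "no execution plan captured yet")
}

// Alternative suggested above: drain Spark's listener bus before asserting.
// spark.sparkContext.listenerBus.waitUntilEmpty()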

val maxWaitTimeMs = 15000
val checkIntervalMs = 100
val maxIterations = maxWaitTimeMs / checkIntervalMs
var iterations = 0

while (capturedPlan.isEmpty && iterations < maxIterations) {
Thread.sleep(checkIntervalMs)
iterations += 1
}

// Verify that CometNativeWriteExec was used
assert(
capturedPlan.isDefined,
s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")

capturedPlan.foreach { qe =>
val executedPlan = qe.executedPlan
val hasNativeWrite = executedPlan.exists {
case _: CometNativeWriteExec => true
case d: DataWritingCommandExec =>
d.child.exists {
case _: CometNativeWriteExec => true
case _ => false
}
case _ => false
}

assert(
hasNativeWrite,
s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
}
} finally {
spark.listenerManager.unregister(listener)
}
capturedPlan
}

private def verifyWrittenFile(outputPath: String): Unit = {
// Verify the data was written correctly
val resultDf = spark.read.parquet(outputPath)
assert(resultDf.count() == 1000, "Expected 1000 rows to be written")

// Verify multiple part files were created
val outputDir = new File(outputPath)
val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
// With 1000 rows and default parallelism, we should get multiple partitions
assert(partFiles.length > 1, "Expected multiple part files to be created")
Contributor commented:

should we check exact number of partitions? example: if you write a df hash partitioned by 50 we should have 50 files

Contributor (Author) replied:

I just moved that logic. Since this is a pretty early proof-of-concept feature from @andygrove I'm not too inclined to change test behavior in this PR.
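For reference, a hedged sketch of the reviewer's idea (illustrative only; spark, inputPath and outputPath are assumed from the surrounding test, and this is not part of the PR):

import java.io.File

// Repartition to a known partition count before writing, then assert that
// exactly that many part files were produced.
val numPartitions = 50
spark.read
  .parquet(inputPath)
  .repartition(numPartitions)
  .write
  .parquet(outputPath)

val partFiles = new File(outputPath).listFiles().filter(_.getName.startsWith("part-"))
assert(
  partFiles.length == numPartitions,
  s"Expected $numPartitions part files, found ${partFiles.length}")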


// read with and without Comet and compare
var sparkDf: DataFrame = null
var cometDf: DataFrame = null
withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
sparkDf = spark.read.parquet(outputPath)
}
withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
cometDf = spark.read.parquet(outputPath)
}
checkAnswer(sparkDf, cometDf)
}

test("basic parquet write") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

// Create test data and write it to a temp parquet file first
withTempPath { inputDir =>
val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
val schema = FuzzDataGenerator.generateSchema(
SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
val df = FuzzDataGenerator.generateDataFrame(
new Random(42),
spark,
schema,
1000,
DataGenOptions(generateNegativeZero = false))
withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
df.write.parquet(inputPath)
}
val inputPath = createTestData(inputDir)

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
val df = spark.read.parquet(inputPath)

// Use a listener to capture the execution plan during write
var capturedPlan: Option[QueryExecution] = None

val listener = new org.apache.spark.sql.util.QueryExecutionListener {
override def onSuccess(
funcName: String,
qe: QueryExecution,
durationNs: Long): Unit = {
// Capture plans from write operations
if (funcName == "save" || funcName.contains("command")) {
capturedPlan = Some(qe)
}
}

override def onFailure(
funcName: String,
qe: QueryExecution,
exception: Exception): Unit = {}
}
writeWithCometNativeWriteExec(inputPath, outputPath)

spark.listenerManager.register(listener)

try {
// Perform native write
df.write.parquet(outputPath)
verifyWrittenFile(outputPath)
}
}
}
}

// Wait for listener to be called with timeout
val maxWaitTimeMs = 15000
val checkIntervalMs = 100
val maxIterations = maxWaitTimeMs / checkIntervalMs
var iterations = 0
test("basic parquet write with native scan child") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

while (capturedPlan.isEmpty && iterations < maxIterations) {
Thread.sleep(checkIntervalMs)
iterations += 1
}
// Create test data and write it to a temp parquet file first
withTempPath { inputDir =>
val inputPath = createTestData(inputDir)

// Verify that CometNativeWriteExec was used
assert(
capturedPlan.isDefined,
s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
capturedPlan.foreach { qe =>
val executedPlan = qe.executedPlan
val hasNativeWrite = executedPlan.exists {
case _: CometNativeWriteExec => true
case d: DataWritingCommandExec =>
d.child.exists {
case _: CometNativeWriteExec => true
case _ => false
}
val hasNativeScan = executedPlan.exists {
case _: CometNativeScanExec => true
case _ => false
}

assert(
hasNativeWrite,
s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
hasNativeScan,
s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
}
} finally {
spark.listenerManager.unregister(listener)
}

// Verify the data was written correctly
val resultDf = spark.read.parquet(outputPath)
assert(resultDf.count() == 1000, "Expected 1000 rows to be written")

// Verify multiple part files were created
val outputDir = new File(outputPath)
val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
// With 1000 rows and default parallelism, we should get multiple partitions
assert(partFiles.length > 1, "Expected multiple part files to be created")

// read with and without Comet and compare
var sparkDf: DataFrame = null
var cometDf: DataFrame = null
withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
sparkDf = spark.read.parquet(outputPath)
}
withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
cometDf = spark.read.parquet(outputPath)
verifyWrittenFile(outputPath)
}
checkAnswer(sparkDf, cometDf)
}
}
}
}

}