Commit fe49e40

feat: CometNativeWriteExec support with native scan as a child (#2839)
1 parent 9d82669 commit fe49e40

2 files changed: +144, -88 lines


spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 11 additions & 0 deletions
@@ -536,6 +536,17 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
           firstNativeOp = true
         }
 
+        // CometNativeWriteExec is special: it has two separate plans:
+        // 1. A protobuf plan (nativeOp) describing the write operation
+        // 2. A Spark plan (child) that produces the data to write
+        // The serializedPlanOpt is a def that always returns Some(...) by serializing
+        // nativeOp on-demand, so it doesn't need convertBlock(). However, its child
+        // (e.g., CometNativeScanExec) may need its own serialization. Reset the flag
+        // so children can start their own native execution blocks.
+        if (op.isInstanceOf[CometNativeWriteExec]) {
+          firstNativeOp = true
+        }
+
         newPlan
       case op =>
         firstNativeOp = true
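
For context, a minimal sketch of how this change is exercised end to end. It is not part of the diff: it mirrors the configuration used by the new test below, assumes a suite extending CometTestBase (where withSQLConf is available), and uses placeholder inputPath/outputPath values.

// Sketch only: config keys are taken from CometParquetWriterSuite; paths are placeholders.
import org.apache.comet.CometConf
import org.apache.spark.sql.execution.command.DataWritingCommandExec

withSQLConf(
  CometConf.COMET_EXEC_ENABLED.key -> "true",
  CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
  CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
  CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
  // The read plans as CometNativeScanExec (native_datafusion impl); the write wraps it in
  // CometNativeWriteExec, and the flag reset above lets the scan child open its own
  // native execution block.
  spark.read.parquet(inputPath).write.parquet(outputPath)
}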

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 133 additions & 88 deletions
@@ -24,7 +24,7 @@ import java.io.File
 import scala.util.Random
 
 import org.apache.spark.sql.{CometTestBase, DataFrame}
-import org.apache.spark.sql.comet.CometNativeWriteExec
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.internal.SQLConf
@@ -34,122 +34,167 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOptions}
 
 class CometParquetWriterSuite extends CometTestBase {
 
-  test("basic parquet write") {
-    // no support for fully native scan as input yet
-    assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
+  private def createTestData(inputDir: File): String = {
+    val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
+    val schema = FuzzDataGenerator.generateSchema(
+      SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
+    val df = FuzzDataGenerator.generateDataFrame(
+      new Random(42),
+      spark,
+      schema,
+      1000,
+      DataGenOptions(generateNegativeZero = false))
+    withSQLConf(
+      CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
+      df.write.parquet(inputPath)
+    }
+    inputPath
+  }
+
+  private def writeWithCometNativeWriteExec(
+      inputPath: String,
+      outputPath: String): Option[QueryExecution] = {
+    val df = spark.read.parquet(inputPath)
+
+    // Use a listener to capture the execution plan during write
+    var capturedPlan: Option[QueryExecution] = None
+
+    val listener = new org.apache.spark.sql.util.QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        // Capture plans from write operations
+        if (funcName == "save" || funcName.contains("command")) {
+          capturedPlan = Some(qe)
+        }
+      }
+
+      override def onFailure(
+          funcName: String,
+          qe: QueryExecution,
+          exception: Exception): Unit = {}
+    }
+
+    spark.listenerManager.register(listener)
+
+    try {
+      // Perform native write
+      df.write.parquet(outputPath)
+
+      // Wait for listener to be called with timeout
+      val maxWaitTimeMs = 15000
+      val checkIntervalMs = 100
+      val maxIterations = maxWaitTimeMs / checkIntervalMs
+      var iterations = 0
+
+      while (capturedPlan.isEmpty && iterations < maxIterations) {
+        Thread.sleep(checkIntervalMs)
+        iterations += 1
+      }
+
+      // Verify that CometNativeWriteExec was used
+      assert(
+        capturedPlan.isDefined,
+        s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+
+      capturedPlan.foreach { qe =>
+        val executedPlan = qe.executedPlan
+        val hasNativeWrite = executedPlan.exists {
+          case _: CometNativeWriteExec => true
+          case d: DataWritingCommandExec =>
+            d.child.exists {
+              case _: CometNativeWriteExec => true
+              case _ => false
+            }
+          case _ => false
+        }
+
+        assert(
+          hasNativeWrite,
+          s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+      }
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+    capturedPlan
+  }
+
+  private def verifyWrittenFile(outputPath: String): Unit = {
+    // Verify the data was written correctly
+    val resultDf = spark.read.parquet(outputPath)
+    assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+
+    // Verify multiple part files were created
+    val outputDir = new File(outputPath)
+    val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
+    // With 1000 rows and default parallelism, we should get multiple partitions
+    assert(partFiles.length > 1, "Expected multiple part files to be created")
+
+    // read with and without Comet and compare
+    var sparkDf: DataFrame = null
+    var cometDf: DataFrame = null
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
+      sparkDf = spark.read.parquet(outputPath)
+    }
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
+      cometDf = spark.read.parquet(outputPath)
+    }
+    checkAnswer(sparkDf, cometDf)
+  }
 
+  test("basic parquet write") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath
 
       // Create test data and write it to a temp parquet file first
       withTempPath { inputDir =>
-        val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
-        val schema = FuzzDataGenerator.generateSchema(
-          SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
-        val df = FuzzDataGenerator.generateDataFrame(
-          new Random(42),
-          spark,
-          schema,
-          1000,
-          DataGenOptions(generateNegativeZero = false))
-        withSQLConf(
-          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
-          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
-          df.write.parquet(inputPath)
-        }
+        val inputPath = createTestData(inputDir)
 
         withSQLConf(
           CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
           SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
           CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
-          val df = spark.read.parquet(inputPath)
-
-          // Use a listener to capture the execution plan during write
-          var capturedPlan: Option[QueryExecution] = None
-
-          val listener = new org.apache.spark.sql.util.QueryExecutionListener {
-            override def onSuccess(
-                funcName: String,
-                qe: QueryExecution,
-                durationNs: Long): Unit = {
-              // Capture plans from write operations
-              if (funcName == "save" || funcName.contains("command")) {
-                capturedPlan = Some(qe)
-              }
-            }
 
-            override def onFailure(
-                funcName: String,
-                qe: QueryExecution,
-                exception: Exception): Unit = {}
-          }
+          writeWithCometNativeWriteExec(inputPath, outputPath)
 
-          spark.listenerManager.register(listener)
-
-          try {
-            // Perform native write
-            df.write.parquet(outputPath)
+          verifyWrittenFile(outputPath)
+        }
+      }
+    }
+  }
 
-            // Wait for listener to be called with timeout
-            val maxWaitTimeMs = 15000
-            val checkIntervalMs = 100
-            val maxIterations = maxWaitTimeMs / checkIntervalMs
-            var iterations = 0
+  test("basic parquet write with native scan child") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
 
-            while (capturedPlan.isEmpty && iterations < maxIterations) {
-              Thread.sleep(checkIntervalMs)
-              iterations += 1
-            }
+      // Create test data and write it to a temp parquet file first
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
 
-            // Verify that CometNativeWriteExec was used
-            assert(
-              capturedPlan.isDefined,
-              s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+        withSQLConf(
+          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+          CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+          CometConf.COMET_EXEC_ENABLED.key -> "true") {
 
+          withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
+            val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
             capturedPlan.foreach { qe =>
               val executedPlan = qe.executedPlan
-              val hasNativeWrite = executedPlan.exists {
-                case _: CometNativeWriteExec => true
-                case d: DataWritingCommandExec =>
-                  d.child.exists {
-                    case _: CometNativeWriteExec => true
-                    case _ => false
-                  }
+              val hasNativeScan = executedPlan.exists {
+                case _: CometNativeScanExec => true
                 case _ => false
               }
 
               assert(
-                hasNativeWrite,
-                s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+                hasNativeScan,
+                s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
             }
-          } finally {
-            spark.listenerManager.unregister(listener)
-          }
 
-          // Verify the data was written correctly
-          val resultDf = spark.read.parquet(outputPath)
-          assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
-
-          // Verify multiple part files were created
-          val outputDir = new File(outputPath)
-          val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
-          // With 1000 rows and default parallelism, we should get multiple partitions
-          assert(partFiles.length > 1, "Expected multiple part files to be created")
-
-          // read with and without Comet and compare
-          var sparkDf: DataFrame = null
-          var cometDf: DataFrame = null
-          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
-            sparkDf = spark.read.parquet(outputPath)
-          }
-          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
-            cometDf = spark.read.parquet(outputPath)
+            verifyWrittenFile(outputPath)
           }
-          checkAnswer(sparkDf, cometDf)
         }
       }
     }
   }
-
 }
