@@ -24,7 +24,7 @@ import java.io.File
 import scala.util.Random
 
 import org.apache.spark.sql.{CometTestBase, DataFrame}
-import org.apache.spark.sql.comet.CometNativeWriteExec
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.internal.SQLConf
@@ -34,122 +34,167 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOptions}
 
 class CometParquetWriterSuite extends CometTestBase {
 
-  test("basic parquet write") {
-    // no support for fully native scan as input yet
-    assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
+  private def createTestData(inputDir: File): String = {
+    val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
+    val schema = FuzzDataGenerator.generateSchema(
+      SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
+    val df = FuzzDataGenerator.generateDataFrame(
+      new Random(42),
+      spark,
+      schema,
+      1000,
+      DataGenOptions(generateNegativeZero = false))
+    withSQLConf(
+      CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
+      df.write.parquet(inputPath)
+    }
+    inputPath
+  }
+
+  private def writeWithCometNativeWriteExec(
+      inputPath: String,
+      outputPath: String): Option[QueryExecution] = {
+    val df = spark.read.parquet(inputPath)
+
+    // Use a listener to capture the execution plan during write
+    var capturedPlan: Option[QueryExecution] = None
+
+    val listener = new org.apache.spark.sql.util.QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        // Capture plans from write operations
+        if (funcName == "save" || funcName.contains("command")) {
+          capturedPlan = Some(qe)
+        }
+      }
+
+      override def onFailure(
+          funcName: String,
+          qe: QueryExecution,
+          exception: Exception): Unit = {}
+    }
+
+    spark.listenerManager.register(listener)
+
+    try {
+      // Perform native write
+      df.write.parquet(outputPath)
+
+      // Wait for listener to be called with timeout
+      val maxWaitTimeMs = 15000
+      val checkIntervalMs = 100
+      val maxIterations = maxWaitTimeMs / checkIntervalMs
+      var iterations = 0
+
+      while (capturedPlan.isEmpty && iterations < maxIterations) {
+        Thread.sleep(checkIntervalMs)
+        iterations += 1
+      }
+
+      // Verify that CometNativeWriteExec was used
+      assert(
+        capturedPlan.isDefined,
+        s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+
+      capturedPlan.foreach { qe =>
+        val executedPlan = qe.executedPlan
+        val hasNativeWrite = executedPlan.exists {
+          case _: CometNativeWriteExec => true
+          case d: DataWritingCommandExec =>
+            d.child.exists {
+              case _: CometNativeWriteExec => true
+              case _ => false
+            }
+          case _ => false
+        }
+
+        assert(
+          hasNativeWrite,
+          s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+      }
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+    capturedPlan
+  }
+
+  private def verifyWrittenFile(outputPath: String): Unit = {
+    // Verify the data was written correctly
+    val resultDf = spark.read.parquet(outputPath)
+    assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+
+    // Verify multiple part files were created
+    val outputDir = new File(outputPath)
+    val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
+    // With 1000 rows and default parallelism, we should get multiple partitions
+    assert(partFiles.length > 1, "Expected multiple part files to be created")
+
+    // read with and without Comet and compare
+    var sparkDf: DataFrame = null
+    var cometDf: DataFrame = null
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
+      sparkDf = spark.read.parquet(outputPath)
+    }
+    withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
+      cometDf = spark.read.parquet(outputPath)
+    }
+    checkAnswer(sparkDf, cometDf)
+  }
 
+  test("basic parquet write") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath
 
       // Create test data and write it to a temp parquet file first
       withTempPath { inputDir =>
-        val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
-        val schema = FuzzDataGenerator.generateSchema(
-          SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
-        val df = FuzzDataGenerator.generateDataFrame(
-          new Random(42),
-          spark,
-          schema,
-          1000,
-          DataGenOptions(generateNegativeZero = false))
-        withSQLConf(
-          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
-          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
-          df.write.parquet(inputPath)
-        }
+        val inputPath = createTestData(inputDir)
 
         withSQLConf(
           CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
           SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
           CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> "true") {
-          val df = spark.read.parquet(inputPath)
-
-          // Use a listener to capture the execution plan during write
-          var capturedPlan: Option[QueryExecution] = None
-
-          val listener = new org.apache.spark.sql.util.QueryExecutionListener {
-            override def onSuccess(
-                funcName: String,
-                qe: QueryExecution,
-                durationNs: Long): Unit = {
-              // Capture plans from write operations
-              if (funcName == "save" || funcName.contains("command")) {
-                capturedPlan = Some(qe)
-              }
-            }
 
-            override def onFailure(
-                funcName: String,
-                qe: QueryExecution,
-                exception: Exception): Unit = {}
-          }
+          writeWithCometNativeWriteExec(inputPath, outputPath)
 
-          spark.listenerManager.register(listener)
-
-          try {
-            // Perform native write
-            df.write.parquet(outputPath)
+          verifyWrittenFile(outputPath)
+        }
+      }
+    }
+  }
 
-            // Wait for listener to be called with timeout
-            val maxWaitTimeMs = 15000
-            val checkIntervalMs = 100
-            val maxIterations = maxWaitTimeMs / checkIntervalMs
-            var iterations = 0
+  test("basic parquet write with native scan child") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
 
-            while (capturedPlan.isEmpty && iterations < maxIterations) {
-              Thread.sleep(checkIntervalMs)
-              iterations += 1
-            }
+      // Create test data and write it to a temp parquet file first
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
 
-            // Verify that CometNativeWriteExec was used
-            assert(
-              capturedPlan.isDefined,
-              s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+        withSQLConf(
+          CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+          SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+          CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+          CometConf.COMET_EXEC_ENABLED.key -> "true") {
 
+          withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
+            val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
             capturedPlan.foreach { qe =>
               val executedPlan = qe.executedPlan
-              val hasNativeWrite = executedPlan.exists {
-                case _: CometNativeWriteExec => true
-                case d: DataWritingCommandExec =>
-                  d.child.exists {
-                    case _: CometNativeWriteExec => true
-                    case _ => false
-                  }
+              val hasNativeScan = executedPlan.exists {
+                case _: CometNativeScanExec => true
                 case _ => false
               }
 
              assert(
-                hasNativeWrite,
-                s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+                hasNativeScan,
+                s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
             }
-          } finally {
-            spark.listenerManager.unregister(listener)
-          }
 
-          // Verify the data was written correctly
-          val resultDf = spark.read.parquet(outputPath)
-          assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
-
-          // Verify multiple part files were created
-          val outputDir = new File(outputPath)
-          val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
-          // With 1000 rows and default parallelism, we should get multiple partitions
-          assert(partFiles.length > 1, "Expected multiple part files to be created")
-
-          // read with and without Comet and compare
-          var sparkDf: DataFrame = null
-          var cometDf: DataFrame = null
-          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
-            sparkDf = spark.read.parquet(outputPath)
-          }
-          withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
-            cometDf = spark.read.parquet(outputPath)
+            verifyWrittenFile(outputPath)
           }
-          checkAnswer(sparkDf, cometDf)
         }
       }
     }
   }
-
 }
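
For reference, the plan-capture pattern that the refactored `writeWithCometNativeWriteExec` helper relies on can be reduced to the standalone sketch below. It is illustrative only: the helper name `capturePlanOfWrite` and the 15 s / 100 ms polling values are assumptions, and only the `QueryExecutionListener` usage mirrors the patch above.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Hypothetical helper: run a Parquet write and return the QueryExecution that Spark
// reports for it, so a test can inspect the executed physical plan.
def capturePlanOfWrite(df: DataFrame, outputPath: String): Option[QueryExecution] = {
  var captured: Option[QueryExecution] = None
  val listener = new QueryExecutionListener {
    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
      // DataFrameWriter.parquet surfaces here as "save" (or a command-style name)
      if (funcName == "save" || funcName.contains("command")) {
        captured = Some(qe)
      }
    }
    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
  }
  val session = df.sparkSession
  session.listenerManager.register(listener)
  try {
    df.write.parquet(outputPath)
    // The listener fires asynchronously, so poll briefly instead of asserting immediately
    var waitedMs = 0
    while (captured.isEmpty && waitedMs < 15000) {
      Thread.sleep(100)
      waitedMs += 100
    }
  } finally {
    session.listenerManager.unregister(listener)
  }
  captured
}

// Example (hypothetical usage): assert the captured plan contains a given operator.
// val plan = capturePlanOfWrite(spark.read.parquet(inputPath), outputPath)
// assert(plan.exists(_.executedPlan.exists(_.isInstanceOf[CometNativeWriteExec])))
```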