diff --git a/src/main/scala-spark-3.0/org/apache/spark/extension/package.scala b/src/main/scala-spark-3.0/org/apache/spark/extension/package.scala
deleted file mode 120000
index 57501beb..00000000
--- a/src/main/scala-spark-3.0/org/apache/spark/extension/package.scala
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../scala-spark-3.5/org/apache/spark/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.0/org/apache/spark/sql/extension/package.scala b/src/main/scala-spark-3.0/org/apache/spark/sql/extension/package.scala
new file mode 120000
index 00000000..823ef3e0
--- /dev/null
+++ b/src/main/scala-spark-3.0/org/apache/spark/sql/extension/package.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.5/org/apache/spark/sql/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.1/org/apache/spark/extension/package.scala b/src/main/scala-spark-3.1/org/apache/spark/extension/package.scala
deleted file mode 120000
index 57501beb..00000000
--- a/src/main/scala-spark-3.1/org/apache/spark/extension/package.scala
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../scala-spark-3.5/org/apache/spark/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.1/org/apache/spark/sql/extension/package.scala b/src/main/scala-spark-3.1/org/apache/spark/sql/extension/package.scala
new file mode 120000
index 00000000..823ef3e0
--- /dev/null
+++ b/src/main/scala-spark-3.1/org/apache/spark/sql/extension/package.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.5/org/apache/spark/sql/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.2/org/apache/spark/extension/package.scala b/src/main/scala-spark-3.2/org/apache/spark/extension/package.scala
deleted file mode 120000
index 57501beb..00000000
--- a/src/main/scala-spark-3.2/org/apache/spark/extension/package.scala
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../scala-spark-3.5/org/apache/spark/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.2/org/apache/spark/sql/extension/package.scala b/src/main/scala-spark-3.2/org/apache/spark/sql/extension/package.scala
new file mode 120000
index 00000000..823ef3e0
--- /dev/null
+++ b/src/main/scala-spark-3.2/org/apache/spark/sql/extension/package.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.5/org/apache/spark/sql/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.3/org/apache/spark/extension/package.scala b/src/main/scala-spark-3.3/org/apache/spark/extension/package.scala
deleted file mode 120000
index 57501beb..00000000
--- a/src/main/scala-spark-3.3/org/apache/spark/extension/package.scala
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../scala-spark-3.5/org/apache/spark/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.3/org/apache/spark/sql/extension/package.scala b/src/main/scala-spark-3.3/org/apache/spark/sql/extension/package.scala
new file mode 120000
index 00000000..823ef3e0
--- /dev/null
+++ b/src/main/scala-spark-3.3/org/apache/spark/sql/extension/package.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.5/org/apache/spark/sql/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.4/org/apache/spark/extension/package.scala b/src/main/scala-spark-3.4/org/apache/spark/extension/package.scala
deleted file mode 120000
index 57501beb..00000000
--- a/src/main/scala-spark-3.4/org/apache/spark/extension/package.scala
+++ /dev/null
@@ -1 +0,0 @@
-../../../../../scala-spark-3.5/org/apache/spark/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.4/org/apache/spark/sql/extension/package.scala b/src/main/scala-spark-3.4/org/apache/spark/sql/extension/package.scala
new file mode 120000
index 00000000..823ef3e0
--- /dev/null
+++ b/src/main/scala-spark-3.4/org/apache/spark/sql/extension/package.scala
@@ -0,0 +1 @@
+../../../../../../scala-spark-3.5/org/apache/spark/sql/extension/package.scala
\ No newline at end of file
diff --git a/src/main/scala-spark-3.5/org/apache/spark/extension/package.scala b/src/main/scala-spark-3.5/org/apache/spark/sql/extension/package.scala
similarity index 78%
rename from src/main/scala-spark-3.5/org/apache/spark/extension/package.scala
rename to src/main/scala-spark-3.5/org/apache/spark/sql/extension/package.scala
index cd89e27e..05e9b915 100644
--- a/src/main/scala-spark-3.5/org/apache/spark/extension/package.scala
+++ b/src/main/scala-spark-3.5/org/apache/spark/sql/extension/package.scala
@@ -14,13 +14,17 @@
  * limitations under the License.
  */
 
-package org.apache.spark
+package org.apache.spark.sql
 
-import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.Expression
 
 package object extension {
+  implicit class ColumnExtension(col: Column) {
+    // Column.expr exists in this Spark version
+    def sql: String = col.expr.sql
+  }
+
   implicit class ExpressionExtension(expr: Expression) {
-    def toColumn: Column = new Column(expr)
+    def column: Column = new Column(expr)
   }
 }
diff --git a/src/main/scala-spark-4.0/org/apache/spark/extension/package.scala b/src/main/scala-spark-4.0/org/apache/spark/sql/extension/extension.scala
similarity index 72%
rename from src/main/scala-spark-4.0/org/apache/spark/extension/package.scala
rename to src/main/scala-spark-4.0/org/apache/spark/sql/extension/extension.scala
index e7ad5006..05877de2 100644
--- a/src/main/scala-spark-4.0/org/apache/spark/extension/package.scala
+++ b/src/main/scala-spark-4.0/org/apache/spark/sql/extension/extension.scala
@@ -14,13 +14,18 @@
  * limitations under the License.
  */
 
-package org.apache.spark
+package org.apache.spark.sql
 
-import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.internal.ExpressionUtils.{column => toColumn, expression}
 
 package object extension {
+  implicit class ColumnExtension(col: Column) {
+    def expr: Expression = expression(col)
+    def sql: String = col.node.sql
+  }
+
   implicit class ExpressionExtension(expr: Expression) {
-    def toColumn: Column = Column(expr)
+    def column: Column = toColumn(expr)
   }
 }
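For reference, a minimal sketch of how calling code picks up the relocated implicits after this change (the value names below are illustrative only; the `sql` and `column` extensions are the ones introduced in the package objects above, and `col("value").expr` resolves either to the Column member on Spark 3.5 or to the ColumnExtension method on Spark 4.0):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}
import org.apache.spark.sql.functions.col

// Column -> SQL text via the new ColumnExtension implicit
val sqlText: String = (col("value") * 2).sql

// Catalyst Expression -> Column via the renamed ExpressionExtension method
val valueExpr: Expression = col("value").expr
val valueColumn: Column = valueExpr.column // was valueExpr.toColumn before this change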
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala
index 26174488..be6d1f4e 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala
@@ -16,12 +16,12 @@
 
 package uk.co.gresearch.spark.diff.comparator
 
-import org.apache.spark.extension.ExpressionExtension
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
 import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, BinaryOperator, Expression}
+import org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}
 import org.apache.spark.sql.types.{BooleanType, DataType}
 import org.apache.spark.sql.{Column, Encoder}
 import uk.co.gresearch.spark.BinaryLikeWithNewChildrenInternal
@@ -32,7 +32,7 @@ trait EquivDiffComparator[T] extends DiffComparator {
 
 private trait ExpressionEquivDiffComparator[T] extends EquivDiffComparator[T] {
   def equiv(left: Expression, right: Expression): EquivExpression[T]
-  def equiv(left: Column, right: Column): Column = equiv(left.expr, right.expr).toColumn
+  def equiv(left: Column, right: Column): Column = equiv(left.expr, right.expr).column
 }
 
 trait TypedEquivDiffComparator[T] extends EquivDiffComparator[T] with TypedDiffComparator
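The comparator change above only swaps `.toColumn` for `.column`; the underlying pattern is to build a Catalyst expression from two columns and wrap the result back into a Column. A hedged, self-contained illustration of that pattern using the standard EqualNullSafe expression rather than this project's EquivExpression type (the helper name is illustrative, not part of the patch):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.EqualNullSafe
import org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}

// null-safe equality built at the expression level, then wrapped back into a Column
def nullSafeEquiv(left: Column, right: Column): Column =
  EqualNullSafe(left.expr, right.expr).column

// e.g. df.withColumn("same", nullSafeEquiv(df("a"), df("b")))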
diff --git a/src/main/scala/uk/co/gresearch/spark/package.scala b/src/main/scala/uk/co/gresearch/spark/package.scala
index 4edd9272..682e1063 100644
--- a/src/main/scala/uk/co/gresearch/spark/package.scala
+++ b/src/main/scala/uk/co/gresearch/spark/package.scala
@@ -16,10 +16,11 @@
 
 package uk.co.gresearch
 
-import org.apache.spark.extension.ExpressionExtension
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
+import org.apache.spark.sql.ColumnName
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.spark.sql.extension.{ColumnExtension, ExpressionExtension}
 import org.apache.spark.sql.functions.{col, count, lit, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, LongType, TimestampType}
@@ -272,7 +273,7 @@ package object spark extends Logging with SparkVersion with BuildVersion {
    * result tick value column
    */
   def timestampToDotNetTicks(timestampColumn: Column): Column =
-    unixEpochTenthMicrosToDotNetTicks(UnixMicros.unixMicros(timestampColumn.expr).toColumn * 10)
+    unixEpochTenthMicrosToDotNetTicks(UnixMicros.unixMicros(timestampColumn.expr).column * 10)
 
   /**
    * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType.
@@ -717,7 +718,7 @@ package object spark extends Logging with SparkVersion with BuildVersion {
     if (partitionColumns.isEmpty)
       throw new IllegalArgumentException(s"partition columns must not be empty")
 
-    if (partitionColumns.exists(!_.expr.isInstanceOf[NamedExpression]))
+    if (partitionColumns.exists(col => !col.isInstanceOf[ColumnName] && !col.expr.isInstanceOf[NamedExpression]))
       throw new IllegalArgumentException(s"partition columns must be named: ${partitionColumns.mkString(",")}")
 
     val requiresCaching = writePartitionedByRequiresCaching(ds)
@@ -739,11 +740,11 @@ package object spark extends Logging with SparkVersion with BuildVersion {
         unpersistHandle.get.setDataFrame(ds.sparkSession.emptyDataFrame)
       case _ =>
     }
-
-    val partitionColumnsMap = partitionColumns.map(c => c.expr.asInstanceOf[NamedExpression].name -> c).toMap
-    val partitionColumnNames = partitionColumnsMap.keys.map(col).toSeq
-    val rangeColumns = partitionColumnNames ++ moreFileColumns
-    val sortColumns = partitionColumnNames ++ moreFileColumns ++ moreFileOrder
+    // resolve partition column names
+    val partitionColumnNames = ds.select(partitionColumns: _*).queryExecution.analyzed.output.map(_.name)
+    val partitionColumnsMap = partitionColumnNames.zip(partitionColumns).toMap
+    val rangeColumns = partitionColumnNames.map(col) ++ moreFileColumns
+    val sortColumns = partitionColumnNames.map(col) ++ moreFileColumns ++ moreFileOrder
     ds.toDF
       .call(ds => partitionColumnsMap.foldLeft(ds) { case (ds, (name, col)) => ds.withColumn(name, col) })
       .when(partitions.isEmpty)
diff --git a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
index 44799919..1e8a2b78 100644
--- a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
@@ -18,7 +18,7 @@ package uk.co.gresearch.spark
 
 import org.apache.spark.{SparkFiles, TaskContext}
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.{Descending, SortOrder}
+import org.apache.spark.sql.extension.ColumnExtension
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -517,9 +517,9 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
 
     try {
       // testing with descending order is only supported for a single column
-      val desc = columns.map(_.expr) match {
-        case Seq(SortOrder(_, Descending, _, _)) => true
-        case _                                   => false
+      val desc = columns.map(_.sql) match {
+        case Seq(so) if so.contains("DESC") => true
+        case _                              => false
       }
 
       // assert row numbers are correct
diff --git a/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala b/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala
index 62db45eb..aff959fa 100644
--- a/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala
@@ -270,6 +270,12 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
     }
   }
 
+  test("write with un-named partition columns") {
+    assertThrows[IllegalArgumentException] {
+      values.writePartitionedBy(Seq($"id" + 1))
+    }
+  }
+
   test("write dataframe") {
     withTempPath { dir =>
       withUnpersist() { handle =>
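The writePartitionedBy change above stops casting partition columns to NamedExpression and instead reads their names off the analyzed plan, which is why un-named columns such as $"id" + 1 are now rejected up front (see the new test). A rough, self-contained sketch of that resolution step, with method and variable names that are illustrative rather than part of the patch:

import org.apache.spark.sql.{Column, DataFrame}

// The names a set of columns would carry after selection, e.g.
// Seq($"id", $"time".cast("date").as("day"))  =>  Seq("id", "day")
def resolvedColumnNames(df: DataFrame, columns: Seq[Column]): Seq[String] =
  df.select(columns: _*).queryExecution.analyzed.output.map(_.name)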