diff --git a/.github/actions/build/action.yml b/.github/actions/build/action.yml
index ab596939..8cf9d7d8 100644
--- a/.github/actions/build/action.yml
+++ b/.github/actions/build/action.yml
@@ -45,9 +45,9 @@ runs:
env:
JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED
run: |
- mvn --batch-mode --update-snapshots clean compile test-compile
- mvn --batch-mode package -DskipTests -Dmaven.test.skip=true
- mvn --batch-mode install -DskipTests -Dmaven.test.skip=true -Dgpg.skip
+ mvn --batch-mode --update-snapshots -Dspotless.check.skip clean compile test-compile
+ mvn --batch-mode package -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true
+ mvn --batch-mode install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip
shell: bash
- name: Upload Binaries
diff --git a/.github/actions/test-jvm/action.yml b/.github/actions/test-jvm/action.yml
index 76e55dde..51f107f3 100644
--- a/.github/actions/test-jvm/action.yml
+++ b/.github/actions/test-jvm/action.yml
@@ -73,7 +73,7 @@ runs:
- name: Scala and Java Tests
env:
JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED
- run: mvn --batch-mode test
+ run: mvn --batch-mode -Dspotless.check.skip test
shell: bash
- name: Diff App test
diff --git a/.github/actions/test-python/action.yml b/.github/actions/test-python/action.yml
index 7ca019a3..1f18472e 100644
--- a/.github/actions/test-python/action.yml
+++ b/.github/actions/test-python/action.yml
@@ -100,7 +100,7 @@ runs:
shell: bash
- name: Install Spark Extension
- run: mvn --batch-mode install -DskipTests -Dmaven.test.skip=true -Dgpg.skip
+ run: mvn --batch-mode install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip
shell: bash
- name: Python Integration Tests
diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml
index 38a1cd56..30a48ca0 100644
--- a/.github/workflows/check.yml
+++ b/.github/workflows/check.yml
@@ -2,9 +2,45 @@ name: Check
on:
workflow_call:
+
jobs:
+ lint:
+ name: Scala lint
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+
+ - name: Cache Maven packages
+ uses: actions/cache@v3
+ with:
+ path: ~/.m2/repository
+ key: ${{ runner.os }}-mvn-lint-${{ hashFiles('pom.xml') }}
+
+ - name: Setup JDK ${{ inputs.java-compat-version }}
+ uses: actions/setup-java@v3
+ with:
+ java-version: '11'
+ distribution: 'zulu'
+
+ - name: Check
+ id: check
+ run: |
+ mvn --batch-mode spotless:check
+ shell: bash
+
+ - name: Changes
+ if: failure() && steps.check.outcome == 'failure'
+ run: |
+ mvn --batch-mode spotless:apply
+ git diff
+ shell: bash
+
config:
- name: Configure check
+ name: Configure compat
runs-on: ubuntu-latest
outputs:
major-version: ${{ steps.versions.outputs.major-version }}
@@ -32,8 +68,8 @@ jobs:
echo "release-major-version=${release_version/.*/}" >> "$GITHUB_OUTPUT"
shell: bash
- check:
- name: Check (Spark ${{ matrix.spark-compat-version }} Scala ${{ matrix.scala-compat-version }})
+ compat:
+ name: Compat (Spark ${{ matrix.spark-compat-version }} Scala ${{ matrix.scala-compat-version }})
needs: config
runs-on: ubuntu-latest
if: needs.config.outputs.major-version == needs.config.outputs.release-major-version
diff --git a/pom.xml b/pom.xml
index 7fbb0aff..3066f2e7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -235,7 +235,7 @@
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
- <version>2.41.0</version>
+ <version>2.30.0</version>
@@ -244,6 +244,16 @@
+
+ <executions>
+ <execution>
+ <id>spotless-check</id>
+ <phase>compile</phase>
+ <goals>
+ <goal>check</goal>
+ </goals>
+ </execution>
+ </executions>
diff --git a/src/main/scala/uk/co/gresearch/package.scala b/src/main/scala/uk/co/gresearch/package.scala
index 188db09a..5053f2a5 100644
--- a/src/main/scala/uk/co/gresearch/package.scala
+++ b/src/main/scala/uk/co/gresearch/package.scala
@@ -20,28 +20,28 @@ package object gresearch {
trait ConditionalCall[T] {
def call(f: T => T): T
- def either[R](f: T => R): ConditionalCallOr[T,R]
+ def either[R](f: T => R): ConditionalCallOr[T, R]
}
- trait ConditionalCallOr[T,R] {
+ trait ConditionalCallOr[T, R] {
def or(f: T => R): R
}
case class TrueCall[T](t: T) extends ConditionalCall[T] {
override def call(f: T => T): T = f(t)
- override def either[R](f: T => R): ConditionalCallOr[T,R] = TrueCallOr[T,R](f(t))
+ override def either[R](f: T => R): ConditionalCallOr[T, R] = TrueCallOr[T, R](f(t))
}
case class FalseCall[T](t: T) extends ConditionalCall[T] {
override def call(f: T => T): T = t
- override def either[R](f: T => R): ConditionalCallOr[T,R] = FalseCallOr[T,R](t)
+ override def either[R](f: T => R): ConditionalCallOr[T, R] = FalseCallOr[T, R](t)
}
- case class TrueCallOr[T,R](r: R) extends ConditionalCallOr[T,R] {
+ case class TrueCallOr[T, R](r: R) extends ConditionalCallOr[T, R] {
override def or(f: T => R): R = r
}
- case class FalseCallOr[T,R](t: T) extends ConditionalCallOr[T,R] {
+ case class FalseCallOr[T, R](t: T) extends ConditionalCallOr[T, R] {
override def or(f: T => R): R = f(t)
}
@@ -71,16 +71,17 @@ package object gresearch {
*
* which either needs many temporary variables or duplicate code.
*
- * @param condition condition
- * @return the function result
+ * @param condition
+ * condition
+ * @return
+ * the function result
*/
def on(condition: Boolean): ConditionalCall[T] = {
if (condition) TrueCall[T](t) else FalseCall[T](t)
}
/**
- * Allows to call a function on the decorated instance conditionally.
- * This is an alias for the `on` function.
+ * Allows calling a function on the decorated instance conditionally. This is an alias for the `on` function.
*
* This allows fluent code like
*
@@ -103,8 +104,10 @@ package object gresearch {
*
* which either needs many temporary variables or duplicate code.
*
- * @param condition condition
- * @return the function result
+ * @param condition
+ * condition
+ * @return
+ * the function result
*/
def when(condition: Boolean): ConditionalCall[T] = on(condition)
@@ -131,8 +134,10 @@ package object gresearch {
*
* where the effective sequence of operations is not clear.
*
- * @param f function
- * @return the function result
+ * @param f
+ * function
+ * @return
+ * the function result
*/
def call[R](f: T => R): R = f(t)
}
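
For reference, a minimal sketch of the conditional-call API reformatted above, mirroring how App.scala (further down in this diff) uses it; it assumes the implicit decorator from the uk.co.gresearch package object is in scope, and the reader configuration is illustrative only:

import org.apache.spark.sql.{DataFrame, SparkSession}
import uk.co.gresearch._

// configure the reader only when a format is given, then load a path or read a table
def readInput(spark: SparkSession, path: String, format: Option[String]): DataFrame =
  spark.read
    .when(format.isDefined).call(_.format(format.get))
    .when(format.isDefined).either(_.load(path)).or(_.table(path))
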
diff --git a/src/main/scala/uk/co/gresearch/spark/Backticks.scala b/src/main/scala/uk/co/gresearch/spark/Backticks.scala
index efaaf19b..95cdd41e 100644
--- a/src/main/scala/uk/co/gresearch/spark/Backticks.scala
+++ b/src/main/scala/uk/co/gresearch/spark/Backticks.scala
@@ -34,12 +34,15 @@ object Backticks {
* col(Backticks.column_name("a.column", "a.field")) // produces "`a.column`.`a.field`"
* }}}
*
- * @param string a string
- * @param strings more strings
+ * @param string
+ * a string
+ * @param strings
+ * more strings
* @return
*/
@scala.annotation.varargs
def column_name(string: String, strings: String*): String = (string +: strings)
- .map(s => if (s.contains(".") && !s.startsWith("`") && !s.endsWith("`")) s"`$s`" else s).mkString(".")
+ .map(s => if (s.contains(".") && !s.startsWith("`") && !s.endsWith("`")) s"`$s`" else s)
+ .mkString(".")
}
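
A short usage sketch of Backticks.column_name, taken from the Scaladoc example above; only names containing dots are quoted:

import org.apache.spark.sql.functions.col
import uk.co.gresearch.spark.Backticks

val nested = Backticks.column_name("a.column", "a.field") // "`a.column`.`a.field`"
val plain  = Backticks.column_name("id")                  // "id" has no dot, stays unquoted
val column = col(nested)                                  // resolve the nested field
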
diff --git a/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala b/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala
index 60241e0d..9b831833 100644
--- a/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala
+++ b/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala
@@ -37,7 +37,7 @@ trait BuildVersion {
}
lazy val VersionString: String = props.getProperty("project.version")
-
+
lazy val BuildSparkMajorVersion: Int = props.getProperty("spark.major.version").toInt
lazy val BuildSparkMinorVersion: Int = props.getProperty("spark.minor.version").toInt
lazy val BuildSparkPatchVersion: Int = props.getProperty("spark.patch.version").split("-").head.toInt
diff --git a/src/main/scala/uk/co/gresearch/spark/Histogram.scala b/src/main/scala/uk/co/gresearch/spark/Histogram.scala
index c4ef8c49..c52cc679 100644
--- a/src/main/scala/uk/co/gresearch/spark/Histogram.scala
+++ b/src/main/scala/uk/co/gresearch/spark/Histogram.scala
@@ -23,20 +23,25 @@ import uk.co.gresearch.ExtendedAny
import scala.collection.JavaConverters
object Histogram {
+
/**
- * Compute the histogram of a column when aggregated by aggregate columns.
- * Thresholds are expected to be provided in ascending order.
- * The result dataframe contains the aggregate and histogram columns only.
- * For each threshold value in thresholds, there will be a column named s"≤threshold".
- * There will also be a final column called s">last_threshold", that counts the remaining
- * values that exceed the last threshold.
+ * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
+ * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
+ * in thresholds, there will be a column named s"≤threshold". There will also be a final column called
+ * s">last_threshold", that counts the remaining values that exceed the last threshold.
*
- * @param df dataset to compute histogram from
- * @param thresholds sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
- * @param valueColumn histogram is computed for values of this column
- * @param aggregateColumns histogram is computed against these columns
- * @tparam T type of histogram thresholds
- * @return dataframe with aggregate and histogram columns
+ * @param df
+ * dataset to compute histogram from
+ * @param thresholds
+ * sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
+ * @param valueColumn
+ * histogram is computed for values of this column
+ * @param aggregateColumns
+ * histogram is computed against these columns
+ * @tparam T
+ * type of histogram thresholds
+ * @return
+ * dataframe with aggregate and histogram columns
*/
def of[D, T](df: Dataset[D], thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = {
if (thresholds.isEmpty)
@@ -49,10 +54,9 @@ object Histogram {
df.toDF()
.withColumn(s"≤${thresholds.head}", when(valueColumn <= thresholds.head, 1).otherwise(0))
- .call(
- bins.foldLeft(_) { case (df, bin) =>
- df.withColumn(s"≤${bin.last}", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0))
- })
+ .call(bins.foldLeft(_) { case (df, bin) =>
+ df.withColumn(s"≤${bin.last}", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0))
+ })
.withColumn(s">${thresholds.last}", when(valueColumn > thresholds.last, 1).otherwise(0))
.groupBy(aggregateColumns: _*)
.agg(
@@ -63,22 +67,31 @@ object Histogram {
}
/**
- * Compute the histogram of a column when aggregated by aggregate columns.
- * Thresholds are expected to be provided in ascending order.
- * The result dataframe contains the aggregate and histogram columns only.
- * For each threshold value in thresholds, there will be a column named s"≤threshold".
- * There will also be a final column called s">last_threshold", that counts the remaining
- * values that exceed the last threshold.
+ * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
+ * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
+ * in thresholds, there will be a column named s"≤threshold". There will also be a final column called
+ * s">last_threshold", that counts the remaining values that exceed the last threshold.
*
- * @param df dataset to compute histogram from
- * @param thresholds sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
- * @param valueColumn histogram is computed for values of this column
- * @param aggregateColumns histogram is computed against these columns
- * @tparam T type of histogram thresholds
- * @return dataframe with aggregate and histogram columns
+ * @param df
+ * dataset to compute histogram from
+ * @param thresholds
+ * sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
+ * @param valueColumn
+ * histogram is computed for values of this column
+ * @param aggregateColumns
+ * histogram is computed against these columns
+ * @tparam T
+ * type of histogram thresholds
+ * @return
+ * dataframe with aggregate and histogram columns
*/
@scala.annotation.varargs
- def of[D, T](df: Dataset[D], thresholds: java.util.List[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =
+ def of[D, T](
+ df: Dataset[D],
+ thresholds: java.util.List[T],
+ valueColumn: Column,
+ aggregateColumns: Column*
+ ): DataFrame =
of(df, JavaConverters.iterableAsScalaIterable(thresholds).toSeq, valueColumn, aggregateColumns: _*)
}
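
A hedged usage sketch of Histogram.of as declared above; the column names and integer thresholds are assumptions for illustration:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import uk.co.gresearch.spark.Histogram

// per "group": counts of rows with value <= 100, <= 200, and > 200
def valueHistogram(df: DataFrame): DataFrame =
  Histogram.of(df, Seq(100, 200), col("value"), col("group"))
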
diff --git a/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala b/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala
index e5015e65..5bc7d48c 100644
--- a/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala
+++ b/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala
@@ -21,10 +21,12 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, functions}
import org.apache.spark.sql.functions.{coalesce, col, lit, max, monotonically_increasing_id, spark_partition_id, sum}
import org.apache.spark.storage.StorageLevel
-case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
- storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
- unpersistHandle: UnpersistHandle = UnpersistHandle.Noop,
- orderColumns: Seq[Column] = Seq.empty) {
+case class RowNumbersFunc(
+ rowNumberColumnName: String = "row_number",
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ unpersistHandle: UnpersistHandle = UnpersistHandle.Noop,
+ orderColumns: Seq[Column] = Seq.empty
+) {
def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc =
this.copy(rowNumberColumnName = rowNumberColumnName)
@@ -39,7 +41,11 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
this.copy(orderColumns = orderColumns)
def of[D](df: Dataset[D]): DataFrame = {
- if (storageLevel.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)) {
+ if (
+ storageLevel.equals(
+ StorageLevel.NONE
+ ) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)
+ ) {
throw new IllegalArgumentException(s"Storage level $storageLevel not supported with Spark 3.5.0 and above.")
}
@@ -53,7 +59,9 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
val partitionOffsetColumnName = prefix + "partition_offset"
// if no order is given, we preserve existing order
- val dfOrdered = if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id()) else df.orderBy(orderColumns: _*)
+ val dfOrdered =
+ if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id())
+ else df.orderBy(orderColumns: _*)
val order = if (orderColumns.isEmpty) Seq(col(monoIdColumnName)) else orderColumns
// add partition ids and local row numbers
@@ -66,17 +74,22 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
.withColumn(localRowNumberColumnName, functions.row_number().over(localRowNumberWindow))
// compute row offset for the partitions
- val cumRowNumbersWindow = Window.orderBy(partitionIdColumnName)
+ val cumRowNumbersWindow = Window
+ .orderBy(partitionIdColumnName)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
val partitionOffsets = dfWithLocalRowNumbers
.groupBy(partitionIdColumnName)
.agg(max(localRowNumberColumnName).alias(maxLocalRowNumberColumnName))
.withColumn(cumRowNumbersColumnName, sum(maxLocalRowNumberColumnName).over(cumRowNumbersWindow))
- .select(col(partitionIdColumnName) + 1 as partitionIdColumnName, col(cumRowNumbersColumnName).as(partitionOffsetColumnName))
+ .select(
+ col(partitionIdColumnName) + 1 as partitionIdColumnName,
+ col(cumRowNumbersColumnName).as(partitionOffsetColumnName)
+ )
// compute global row number by adding local row number with partition offset
val partitionOffsetColumn = coalesce(col(partitionOffsetColumnName), lit(0))
- dfWithLocalRowNumbers.join(partitionOffsets, Seq(partitionIdColumnName), "left")
+ dfWithLocalRowNumbers
+ .join(partitionOffsets, Seq(partitionIdColumnName), "left")
.withColumn(rowNumberColumnName, col(localRowNumberColumnName) + partitionOffsetColumn)
.drop(monoIdColumnName, partitionIdColumnName, localRowNumberColumnName, partitionOffsetColumnName)
}
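
A sketch using the case class above directly; the library may expose a more convenient RowNumbers entry point elsewhere in this file (not shown in this diff), and the order column is an assumption:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import uk.co.gresearch.spark.RowNumbersFunc

// appends a global "row_number" column ordered by "id"; the default storage level is MEMORY_AND_DISK
def withRowNumbers(df: DataFrame): DataFrame =
  RowNumbersFunc(orderColumns = Seq(col("id"))).of(df)
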
diff --git a/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala b/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala
index 863b4b6b..36f597d3 100644
--- a/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala
+++ b/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala
@@ -51,7 +51,7 @@ case class SilentUnpersistHandle() extends UnpersistHandle {
}
}
-case class NoopUnpersistHandle() extends UnpersistHandle{
+case class NoopUnpersistHandle() extends UnpersistHandle {
override def setDataFrame(dataframe: DataFrame): DataFrame = dataframe
override def apply(): Unit = {}
override def apply(blocking: Boolean): Unit = {}
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/App.scala b/src/main/scala/uk/co/gresearch/spark/diff/App.scala
index 0df86d24..22a6be79 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/App.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/App.scala
@@ -23,31 +23,28 @@ import uk.co.gresearch._
object App {
// define available options
- case class Options(master: Option[String] = None,
- appName: Option[String] = None,
- hive: Boolean = false,
-
- leftPath: Option[String] = None,
- rightPath: Option[String] = None,
- outputPath: Option[String] = None,
-
- leftFormat: Option[String] = None,
- rightFormat: Option[String] = None,
- outputFormat: Option[String] = None,
-
- leftSchema: Option[String] = None,
- rightSchema: Option[String] = None,
-
- leftOptions: Map[String, String] = Map.empty,
- rightOptions: Map[String, String] = Map.empty,
- outputOptions: Map[String, String] = Map.empty,
-
- ids: Seq[String] = Seq.empty,
- ignore: Seq[String] = Seq.empty,
- saveMode: SaveMode = SaveMode.ErrorIfExists,
- filter: Set[String] = Set.empty,
- statistics: Boolean = false,
- diffOptions: DiffOptions = DiffOptions.default)
+ case class Options(
+ master: Option[String] = None,
+ appName: Option[String] = None,
+ hive: Boolean = false,
+ leftPath: Option[String] = None,
+ rightPath: Option[String] = None,
+ outputPath: Option[String] = None,
+ leftFormat: Option[String] = None,
+ rightFormat: Option[String] = None,
+ outputFormat: Option[String] = None,
+ leftSchema: Option[String] = None,
+ rightSchema: Option[String] = None,
+ leftOptions: Map[String, String] = Map.empty,
+ rightOptions: Map[String, String] = Map.empty,
+ outputOptions: Map[String, String] = Map.empty,
+ ids: Seq[String] = Seq.empty,
+ ignore: Seq[String] = Seq.empty,
+ saveMode: SaveMode = SaveMode.ErrorIfExists,
+ filter: Set[String] = Set.empty,
+ statistics: Boolean = false,
+ diffOptions: DiffOptions = DiffOptions.default
+ )
// read options from args
val programName = s"spark-extension_${spark.BuildScalaCompatVersionString}-${spark.VersionString}.jar"
@@ -105,11 +102,13 @@ object App {
note("Input and output")
opt[String]('f', "format")
.valueName("")
- .action((x, c) => c.copy(
- leftFormat = c.leftFormat.orElse(Some(x)),
- rightFormat = c.rightFormat.orElse(Some(x)),
- outputFormat = c.outputFormat.orElse(Some(x))
- ))
+ .action((x, c) =>
+ c.copy(
+ leftFormat = c.leftFormat.orElse(Some(x)),
+ rightFormat = c.rightFormat.orElse(Some(x)),
+ outputFormat = c.outputFormat.orElse(Some(x))
+ )
+ )
.text("input and output file format (csv, json, parquet, ...)")
opt[String]("left-format")
.valueName("")
@@ -127,10 +126,12 @@ object App {
note("")
opt[String]('s', "schema")
.valueName("")
- .action((x, c) => c.copy(
- leftSchema = c.leftSchema.orElse(Some(x)),
- rightSchema = c.rightSchema.orElse(Some(x))
- ))
+ .action((x, c) =>
+ c.copy(
+ leftSchema = c.leftSchema.orElse(Some(x)),
+ rightSchema = c.rightSchema.orElse(Some(x))
+ )
+ )
.text("input schema")
opt[String]("left-schema")
.valueName("")
@@ -182,7 +183,9 @@ object App {
.optional()
.valueName("")
.action((x, c) => c.copy(filter = c.filter + x))
- .text(s"Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)")
+ .text(
+ s"Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)"
+ )
opt[Unit]("statistics")
.optional()
.action((_, c) => c.copy(statistics = true))
@@ -245,45 +248,83 @@ object App {
help("help").text("prints this usage text")
}
- def read(spark: SparkSession, format: Option[String], path: String, schema: Option[String], options: Map[String, String]): DataFrame =
+ def read(
+ spark: SparkSession,
+ format: Option[String],
+ path: String,
+ schema: Option[String],
+ options: Map[String, String]
+ ): DataFrame =
spark.read
- .when(format.isDefined).call(_.format(format.get))
+ .when(format.isDefined)
+ .call(_.format(format.get))
.options(options)
- .when(schema.isDefined).call(_.schema(schema.get))
- .when(format.isDefined).either(_.load(path)).or(_.table(path))
+ .when(schema.isDefined)
+ .call(_.schema(schema.get))
+ .when(format.isDefined)
+ .either(_.load(path))
+ .or(_.table(path))
- def write(df: DataFrame, format: Option[String], path: String, options: Map[String, String], saveMode: SaveMode, filter: Set[String], saveStats: Boolean, diffOptions: DiffOptions): Unit =
- df.when(filter.nonEmpty).call(_.where(col(diffOptions.diffColumn).isInCollection(filter)))
- .when(saveStats).call(_.groupBy(diffOptions.diffColumn).count.orderBy(diffOptions.diffColumn))
+ def write(
+ df: DataFrame,
+ format: Option[String],
+ path: String,
+ options: Map[String, String],
+ saveMode: SaveMode,
+ filter: Set[String],
+ saveStats: Boolean,
+ diffOptions: DiffOptions
+ ): Unit =
+ df.when(filter.nonEmpty)
+ .call(_.where(col(diffOptions.diffColumn).isInCollection(filter)))
+ .when(saveStats)
+ .call(_.groupBy(diffOptions.diffColumn).count.orderBy(diffOptions.diffColumn))
.write
- .when(format.isDefined).call(_.format(format.get))
+ .when(format.isDefined)
+ .call(_.format(format.get))
.options(options)
.mode(saveMode)
- .when(format.isDefined).either(_.save(path)).or(_.saveAsTable(path))
+ .when(format.isDefined)
+ .either(_.save(path))
+ .or(_.saveAsTable(path))
def main(args: Array[String]): Unit = {
// parse options
val options = parser.parse(args, Options()) match {
case Some(options) => options
- case None => sys.exit(1)
+ case None => sys.exit(1)
}
val unknownFilters = options.filter.filter(filter => !options.diffOptions.diffValues.contains(filter))
if (unknownFilters.nonEmpty) {
- throw new RuntimeException(s"Filter ${unknownFilters.mkString("'", "', '", "'")} not allowed, " +
- s"these are the configured diff values: ${options.diffOptions.diffValues.mkString("'", "', '", "'")}")
+ throw new RuntimeException(
+ s"Filter ${unknownFilters.mkString("'", "', '", "'")} not allowed, " +
+ s"these are the configured diff values: ${options.diffOptions.diffValues.mkString("'", "', '", "'")}"
+ )
}
// create spark session
- val spark = SparkSession.builder()
+ val spark = SparkSession
+ .builder()
.appName(options.appName.get)
- .when(options.hive).call(_.enableHiveSupport())
- .when(options.master.isDefined).call(_.master(options.master.get))
+ .when(options.hive)
+ .call(_.enableHiveSupport())
+ .when(options.master.isDefined)
+ .call(_.master(options.master.get))
.getOrCreate()
// read and write
val left = read(spark, options.leftFormat, options.leftPath.get, options.leftSchema, options.leftOptions)
val right = read(spark, options.rightFormat, options.rightPath.get, options.rightSchema, options.rightOptions)
val diff = left.diff(right, options.diffOptions, options.ids, options.ignore)
- write(diff, options.outputFormat, options.outputPath.get, options.outputOptions, options.saveMode, options.filter, options.statistics, options.diffOptions)
+ write(
+ diff,
+ options.outputFormat,
+ options.outputPath.get,
+ options.outputOptions,
+ options.saveMode,
+ options.filter,
+ options.statistics,
+ options.diffOptions
+ )
}
}
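
The left.diff(right, ...) call above comes from the implicits in uk.co.gresearch.spark.diff; a minimal programmatic sketch of the same operation via Diff.of (declared in Diff.scala below), with illustrative input data:

import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark.diff.Diff

def example(spark: SparkSession): Unit = {
  import spark.implicits._
  val left  = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
  val right = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")
  // "diff" column first ("N", "C", "D" or "I"), then the id and remaining columns
  Diff.of(left, right, "id").show()
}
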
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala b/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala
index 9fa60f5b..9106de15 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala
@@ -25,113 +25,156 @@ import scala.collection.JavaConverters
/**
* Differ class to diff two Datasets. See Differ.of(…) for details.
- * @param options options for the diffing process
+ * @param options
+ * options for the diffing process
*/
class Differ(options: DiffOptions) {
- private[diff] def checkSchema[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Unit = {
- require(left.columns.length == left.columns.toSet.size &&
- right.columns.length == right.columns.toSet.size,
+ private[diff] def checkSchema[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Unit = {
+ require(
+ left.columns.length == left.columns.toSet.size &&
+ right.columns.length == right.columns.toSet.size,
"The datasets have duplicate columns.\n" +
s"Left column names: ${left.columns.mkString(", ")}\n" +
- s"Right column names: ${right.columns.mkString(", ")}")
+ s"Right column names: ${right.columns.mkString(", ")}"
+ )
val leftNonIgnored = left.columns.diffCaseSensitivity(ignoreColumns)
val rightNonIgnored = right.columns.diffCaseSensitivity(ignoreColumns)
val exceptIgnoredColumnsMsg = if (ignoreColumns.nonEmpty) " except ignored columns" else ""
- require(leftNonIgnored.length == rightNonIgnored.length,
+ require(
+ leftNonIgnored.length == rightNonIgnored.length,
"The number of columns doesn't match.\n" +
s"Left column names$exceptIgnoredColumnsMsg (${leftNonIgnored.length}): ${leftNonIgnored.mkString(", ")}\n" +
- s"Right column names$exceptIgnoredColumnsMsg (${rightNonIgnored.length}): ${rightNonIgnored.mkString(", ")}")
+ s"Right column names$exceptIgnoredColumnsMsg (${rightNonIgnored.length}): ${rightNonIgnored.mkString(", ")}"
+ )
require(leftNonIgnored.length > 0, s"The schema$exceptIgnoredColumnsMsg must not be empty")
// column types must match but we ignore the nullability of columns
- val leftFields = left.schema.fields.filter(f => !ignoreColumns.containsCaseSensitivity(f.name)).map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
- val rightFields = right.schema.fields.filter(f => !ignoreColumns.containsCaseSensitivity(f.name)).map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
+ val leftFields = left.schema.fields
+ .filter(f => !ignoreColumns.containsCaseSensitivity(f.name))
+ .map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
+ val rightFields = right.schema.fields
+ .filter(f => !ignoreColumns.containsCaseSensitivity(f.name))
+ .map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType)
val leftExtraSchema = leftFields.diff(rightFields)
val rightExtraSchema = rightFields.diff(leftFields)
- require(leftExtraSchema.isEmpty && rightExtraSchema.isEmpty,
+ require(
+ leftExtraSchema.isEmpty && rightExtraSchema.isEmpty,
"The datasets do not have the same schema.\n" +
s"Left extra columns: ${leftExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}\n" +
- s"Right extra columns: ${rightExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}")
+ s"Right extra columns: ${rightExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}"
+ )
val columns = leftNonIgnored
val pkColumns = if (idColumns.isEmpty) columns.toList else idColumns
val nonPkColumns = columns.diffCaseSensitivity(pkColumns)
val missingIdColumns = pkColumns.diffCaseSensitivity(columns)
- require(missingIdColumns.isEmpty,
- s"Some id columns do not exist: ${missingIdColumns.mkString(", ")} missing among ${columns.mkString(", ")}")
+ require(
+ missingIdColumns.isEmpty,
+ s"Some id columns do not exist: ${missingIdColumns.mkString(", ")} missing among ${columns.mkString(", ")}"
+ )
val missingIgnoreColumns = ignoreColumns.diffCaseSensitivity(left.columns).diffCaseSensitivity(right.columns)
- require(missingIgnoreColumns.isEmpty,
+ require(
+ missingIgnoreColumns.isEmpty,
s"Some ignore columns do not exist: ${missingIgnoreColumns.mkString(", ")} " +
- s"missing among ${(leftNonIgnored ++ rightNonIgnored).distinct.sorted.mkString(", ")}")
+ s"missing among ${(leftNonIgnored ++ rightNonIgnored).distinct.sorted.mkString(", ")}"
+ )
- require(!pkColumns.containsCaseSensitivity(options.diffColumn),
- s"The id columns must not contain the diff column name '${options.diffColumn}': ${pkColumns.mkString(", ")}")
- require(options.changeColumn.forall(!pkColumns.containsCaseSensitivity(_)),
- s"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString(", ")}")
+ require(
+ !pkColumns.containsCaseSensitivity(options.diffColumn),
+ s"The id columns must not contain the diff column name '${options.diffColumn}': ${pkColumns.mkString(", ")}"
+ )
+ require(
+ options.changeColumn.forall(!pkColumns.containsCaseSensitivity(_)),
+ s"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString(", ")}"
+ )
val diffValueColumns = getDiffColumns(pkColumns, nonPkColumns, left, right, ignoreColumns).map(_._1).diff(pkColumns)
if (Seq(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode)) {
- require(!diffValueColumns.containsCaseSensitivity(options.diffColumn),
+ require(
+ !diffValueColumns.containsCaseSensitivity(options.diffColumn),
s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " +
s"non-id columns must not contain the diff column name '${options.diffColumn}': " +
- s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}")
+ s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}"
+ )
- require(options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),
+ require(
+ options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),
s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " +
s"non-id columns must not contain the change column name '${options.changeColumn.get}': " +
- s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}")
+ s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}"
+ )
} else {
- require(!diffValueColumns.containsCaseSensitivity(options.diffColumn),
+ require(
+ !diffValueColumns.containsCaseSensitivity(options.diffColumn),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce the diff column name '${options.diffColumn}': " +
- s"${nonPkColumns.mkString(", ")}")
+ s"${nonPkColumns.mkString(", ")}"
+ )
- require(options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),
+ require(
+ options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce the change column name '${options.changeColumn.orNull}': " +
- s"${nonPkColumns.mkString(", ")}")
+ s"${nonPkColumns.mkString(", ")}"
+ )
- require(diffValueColumns.forall(!pkColumns.containsCaseSensitivity(_)),
+ require(
+ diffValueColumns.forall(!pkColumns.containsCaseSensitivity(_)),
s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " +
s"together with these non-id columns " +
s"must not produce any id column name '${pkColumns.mkString("', '")}': " +
- s"${nonPkColumns.mkString(", ")}")
+ s"${nonPkColumns.mkString(", ")}"
+ )
}
}
- private def getChangeColumn(existsColumnName: String, valueColumns: Seq[String], left: Dataset[_], right: Dataset[_]): Option[Column] = {
+ private def getChangeColumn(
+ existsColumnName: String,
+ valueColumns: Seq[String],
+ left: Dataset[_],
+ right: Dataset[_]
+ ): Option[Column] = {
options.changeColumn
.map(changeColumn =>
- when(left(existsColumnName).isNull || right(existsColumnName).isNull, lit(null)).
- otherwise(
+ when(left(existsColumnName).isNull || right(existsColumnName).isNull, lit(null))
+ .otherwise(
Some(valueColumns.toSeq)
.filter(_.nonEmpty)
.map(columns =>
concat(
- columns.map(c =>
- when(left(backticks(c)) <=> right(backticks(c)), array()).otherwise(array(lit(c)))
- ): _*
+ columns
+ .map(c => when(left(backticks(c)) <=> right(backticks(c)), array()).otherwise(array(lit(c)))): _*
)
- ).getOrElse(
- array().cast(ArrayType(StringType, containsNull = false))
- )
- ).
- as(changeColumn)
+ )
+ .getOrElse(
+ array().cast(ArrayType(StringType, containsNull = false))
+ )
+ )
+ .as(changeColumn)
)
}
- private[diff] def getDiffColumns[T, U](pkColumns: Seq[String], valueColumns: Seq[String],
- left: Dataset[T], right: Dataset[U],
- ignoreColumns: Seq[String]): Seq[(String, Column)] = {
+ private[diff] def getDiffColumns[T, U](
+ pkColumns: Seq[String],
+ valueColumns: Seq[String],
+ left: Dataset[T],
+ right: Dataset[U],
+ ignoreColumns: Seq[String]
+ ): Seq[(String, Column)] = {
val idColumns = pkColumns.map(c => c -> coalesce(left(backticks(c)), right(backticks(c))).as(c))
val leftValueColumns = left.columns.filterIsInCaseSensitivity(valueColumns)
@@ -145,8 +188,24 @@ class Differ(options: DiffOptions) {
val (leftValues, rightValues) = if (options.sparseMode) {
(
- leftNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c))) else left(backticks(c))))).toMap,
- rightNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c))) else right(backticks(c))))).toMap
+ leftNonPkColumns
+ .map(c =>
+ (
+ handleConfiguredCaseSensitivity(c),
+ c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c)))
+ else left(backticks(c)))
+ )
+ )
+ .toMap,
+ rightNonPkColumns
+ .map(c =>
+ (
+ handleConfiguredCaseSensitivity(c),
+ c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c)))
+ else right(backticks(c)))
+ )
+ )
+ .toMap
)
} else {
(
@@ -189,15 +248,22 @@ class Differ(options: DiffOptions) {
case DiffMode.LeftSide | DiffMode.RightSide =>
// in left-side / right-side mode, we do not prefix columns
(
- if (options.diffMode == DiffMode.LeftSide) valueColumns.map(alias(None, leftValues)) else valueColumns.map(alias(None, rightValues))
- ) ++ (
- if (options.diffMode == DiffMode.LeftSide) leftIgnoredColumns.map(alias(None, leftValues)) else rightIgnoredColumns.map(alias(None, rightValues))
- )
+ if (options.diffMode == DiffMode.LeftSide) valueColumns.map(alias(None, leftValues))
+ else valueColumns.map(alias(None, rightValues))
+ ) ++ (
+ if (options.diffMode == DiffMode.LeftSide) leftIgnoredColumns.map(alias(None, leftValues))
+ else rightIgnoredColumns.map(alias(None, rightValues))
+ )
}
idColumns ++ nonIdColumns
}
- private def doDiff[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String] = Seq.empty): DataFrame = {
+ private def doDiff[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String] = Seq.empty
+ ): DataFrame = {
checkSchema(left, right, idColumns, ignoreColumns)
val columns = left.columns.diffCaseSensitivity(ignoreColumns).toList
@@ -209,18 +275,21 @@ class Differ(options: DiffOptions) {
val existsColumnName = distinctPrefixFor(left.columns) + "exists"
val leftWithExists = left.withColumn(existsColumnName, lit(1))
val rightWithExists = right.withColumn(existsColumnName, lit(1))
- val joinCondition = pkColumns.map(c => leftWithExists(backticks(c)) <=> rightWithExists(backticks(c))).reduce(_ && _)
- val unChanged = valueVolumnsWithComparator.map { case (c, cmp) =>
- cmp.equiv(leftWithExists(backticks(c)), rightWithExists(backticks(c)))
- }.reduceOption(_ && _)
+ val joinCondition =
+ pkColumns.map(c => leftWithExists(backticks(c)) <=> rightWithExists(backticks(c))).reduce(_ && _)
+ val unChanged = valueVolumnsWithComparator
+ .map { case (c, cmp) =>
+ cmp.equiv(leftWithExists(backticks(c)), rightWithExists(backticks(c)))
+ }
+ .reduceOption(_ && _)
val changeCondition = not(unChanged.getOrElse(lit(true)))
val diffActionColumn =
- when(leftWithExists(existsColumnName).isNull, lit(options.insertDiffValue)).
- when(rightWithExists(existsColumnName).isNull, lit(options.deleteDiffValue)).
- when(changeCondition, lit(options.changeDiffValue)).
- otherwise(lit(options.nochangeDiffValue)).
- as(options.diffColumn)
+ when(leftWithExists(existsColumnName).isNull, lit(options.insertDiffValue))
+ .when(rightWithExists(existsColumnName).isNull, lit(options.deleteDiffValue))
+ .when(changeCondition, lit(options.changeDiffValue))
+ .otherwise(lit(options.nochangeDiffValue))
+ .as(options.diffColumn)
val diffColumns = getDiffColumns(pkColumns, valueColumns, left, right, ignoreColumns).map(_._2)
val changeColumn = getChangeColumn(existsColumnName, valueColumns, leftWithExists, rightWithExists)
@@ -228,27 +297,26 @@ class Differ(options: DiffOptions) {
.map(Seq(_))
.getOrElse(Seq.empty[Column])
- leftWithExists.join(rightWithExists, joinCondition, "fullouter")
+ leftWithExists
+ .join(rightWithExists, joinCondition, "fullouter")
.select((diffActionColumn +: changeColumn) ++ diffColumns: _*)
}
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * the same type `T`. Both Datasets must contain the same set of column names and data types.
- * The order of columns in the two Datasets is not relevant as columns are compared based on the
- * name, not the the position.
+ * Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must
+ * contain the same set of column names and data types. The order of columns in the two Datasets is not relevant as
+ * columns are compared based on the name, not the position.
*
- * Optional `id` columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between two Datasets, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset,
- * that do not exist in the right Dataset are marked as `"D"`elete.
+ * Optional `id` columns are used to uniquely identify rows to compare. If values in any non-id column are differing
+ * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset,
+ * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of
+ * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete.
*
- * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
- * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+ * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
+ * changes will exist as respective `"D"`elete and `"I"`nsert.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -281,33 +349,31 @@ class Differ(options: DiffOptions) {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
+ * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
@scala.annotation.varargs
def diff[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame =
doDiff(left, right, idColumns)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
- * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
- * columns are compared based on the name, not the the position.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
+ * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
+ * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the the
+ * position.
*
- * Optional id columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between two Datasets, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset,
- * that do not exist in the right Dataset are marked as `"D"`elete.
+ * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing
+ * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset,
+ * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of
+ * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete.
*
- * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
- * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+ * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
+ * changes will exist as respective `"D"`elete and `"I"`nsert.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -340,32 +406,30 @@ class Differ(options: DiffOptions) {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
+ * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
def diff[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame =
doDiff(left, right, idColumns, ignoreColumns)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
- * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
- * columns are compared based on the name, not the the position.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
+ * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
+ * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the
+ * position.
*
- * Optional id columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between two Datasets, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset,
- * that do not exist in the right Dataset are marked as `"D"`elete.
+ * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing
+ * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset,
+ * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of
+ * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete.
*
- * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
- * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+ * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
+ * changes will exist as respective `"D"`elete and `"I"`nsert.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -398,12 +462,21 @@ class Differ(options: DiffOptions) {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
+ * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
- def diff[T, U](left: Dataset[T], right: Dataset[U], idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): DataFrame = {
- diff(left, right, JavaConverters.iterableAsScalaIterable(idColumns).toSeq, JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq)
+ def diff[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: java.util.List[String],
+ ignoreColumns: java.util.List[String]
+ ): DataFrame = {
+ diff(
+ left,
+ right,
+ JavaConverters.iterableAsScalaIterable(idColumns).toSeq,
+ JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq
+ )
}
/**
@@ -415,21 +488,22 @@ class Differ(options: DiffOptions) {
*/
// no @scala.annotation.varargs here as implicit arguments are explicit in Java
// this signature is redundant to the other diffAs method in Java
- def diffAs[T, U, V](left: Dataset[T], right: Dataset[T], idColumns: String*)
- (implicit diffEncoder: Encoder[V]): Dataset[V] = {
+ def diffAs[T, U, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit
+ diffEncoder: Encoder[V]
+ ): Dataset[V] = {
diffAs(left, right, diffEncoder, idColumns: _*)
}
/**
- * Returns a new Dataset that contains the differences between two Datasets of
- * similar types `T` and `U`.
+ * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])
- (implicit diffEncoder: Encoder[V]): Dataset[V] = {
+ def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit
+ diffEncoder: Encoder[V]
+ ): Dataset[V] = {
diffAs(left, right, diffEncoder, idColumns, ignoreColumns)
}
@@ -441,50 +515,67 @@ class Differ(options: DiffOptions) {
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
@scala.annotation.varargs
- def diffAs[T, V](left: Dataset[T], right: Dataset[T],
- diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {
+ def diffAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {
diffAs(left, right, diffEncoder, idColumns, Seq.empty)
}
/**
- * Returns a new Dataset that contains the differences between two Datasets of
- * similar types `T` and `U`.
+ * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def diffAs[T, U, V](left: Dataset[T], right: Dataset[U],
- diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V] = {
- val nonIdColumns = if (idColumns.isEmpty) Seq.empty else left.columns.diffCaseSensitivity(idColumns).diffCaseSensitivity(ignoreColumns).toSeq
+ def diffAs[T, U, V](
+ left: Dataset[T],
+ right: Dataset[U],
+ diffEncoder: Encoder[V],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Dataset[V] = {
+ val nonIdColumns =
+ if (idColumns.isEmpty) Seq.empty
+ else left.columns.diffCaseSensitivity(idColumns).diffCaseSensitivity(ignoreColumns).toSeq
val encColumns = diffEncoder.schema.fields.map(_.name)
- val diffColumns = Seq(options.diffColumn) ++ getDiffColumns(idColumns, nonIdColumns, left, right, ignoreColumns).map(_._1)
+ val diffColumns =
+ Seq(options.diffColumn) ++ getDiffColumns(idColumns, nonIdColumns, left, right, ignoreColumns).map(_._1)
val extraColumns = encColumns.diffCaseSensitivity(diffColumns)
- require(extraColumns.isEmpty,
+ require(
+ extraColumns.isEmpty,
s"Diff encoder's columns must be part of the diff result schema, " +
- s"these columns are unexpected: ${extraColumns.mkString(", ")}")
+ s"these columns are unexpected: ${extraColumns.mkString(", ")}"
+ )
diff(left, right, idColumns, ignoreColumns).as[V](diffEncoder)
}
/**
- * Returns a new Dataset that contains the differences between two Datasets of
- * similar types `T` and `U`.
+ * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], diffEncoder: Encoder[V],
- idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[V] = {
- diffAs(left, right, diffEncoder,
- JavaConverters.iterableAsScalaIterable(idColumns).toSeq, JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq)
+ def diffAs[T, U, V](
+ left: Dataset[T],
+ right: Dataset[U],
+ diffEncoder: Encoder[V],
+ idColumns: java.util.List[String],
+ ignoreColumns: java.util.List[String]
+ ): Dataset[V] = {
+ diffAs(
+ left,
+ right,
+ diffEncoder,
+ JavaConverters.iterableAsScalaIterable(idColumns).toSeq,
+ JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq
+ )
}
/**
- * Returns a new Dataset that contains the differences between two Dataset of
- * the same type `T` as tuples of type `(String, T, T)`.
+ * Returns a new Dataset that contains the differences between two Datasets of the same type `T` as tuples of type
+ * `(String, T, T)`.
*
* See `diff(Dataset[T], Dataset[T], String*)`.
*/
@@ -495,27 +586,39 @@ class Differ(options: DiffOptions) {
}
/**
- * Returns a new Dataset that contains the differences between two Dataset of
- * similar types `T` and `U` as tuples of type `(String, T, U)`.
+ * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U` as tuples of
+ * type `(String, T, U)`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
- def diffWith[T, U](left: Dataset[T], right: Dataset[U],
- idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] = {
+ def diffWith[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Dataset[(String, T, U)] = {
val df = diff(left, right, idColumns, ignoreColumns)
diffWith(df, idColumns: _*)(left.encoder, right.encoder)
}
/**
- * Returns a new Dataset that contains the differences between two Dataset of
- * similar types `T` and `U` as tuples of type `(String, T, U)`.
+ * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U` as tuples of
+ * type `(String, T, U)`.
*
* See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
- def diffWith[T, U](left: Dataset[T], right: Dataset[U],
- idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[(String, T, U)] = {
- diffWith(left, right,
- JavaConverters.iterableAsScalaIterable(idColumns).toSeq, JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq)
+ def diffWith[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: java.util.List[String],
+ ignoreColumns: java.util.List[String]
+ ): Dataset[(String, T, U)] = {
+ diffWith(
+ left,
+ right,
+ JavaConverters.iterableAsScalaIterable(idColumns).toSeq,
+ JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq
+ )
}
private def columnsOfSide(df: DataFrame, idColumns: Seq[String], sidePrefix: String): Seq[Column] = {
@@ -525,7 +628,7 @@ class Differ(options: DiffOptions) {
.map(c => if (idColumns.contains(c)) col(c) else col(c).as(c.replace(prefix, "")))
}
- private def diffWith[T : Encoder, U : Encoder](diff: DataFrame, idColumns: String*): Dataset[(String, T, U)] = {
+ private def diffWith[T: Encoder, U: Encoder](diff: DataFrame, idColumns: String*): Dataset[(String, T, U)] = {
val leftColumns = columnsOfSide(diff, idColumns, options.leftColumnPrefix)
val rightColumns = columnsOfSide(diff, idColumns, options.rightColumnPrefix)
@@ -540,7 +643,9 @@ class Differ(options: DiffOptions) {
val plan = diff.select(diffColumn, leftStruct, rightStruct).queryExecution.logical
val encoder: Encoder[(String, T, U)] = Encoders.tuple(
- Encoders.STRING, implicitly[Encoder[T]], implicitly[Encoder[U]]
+ Encoders.STRING,
+ implicitly[Encoder[T]],
+ implicitly[Encoder[U]]
)
new Dataset(diff.sparkSession, plan, encoder)
@@ -555,22 +660,20 @@ object Diff {
val default = new Differ(DiffOptions.default)
/**
- * Returns a new DataFrame that contains the differences between two Datasets
- * of the same type `T`. Both Datasets must contain the same set of column names and data types.
- * The order of columns in the two Datasets is not relevant as columns are compared based on the
- * name, not the the position.
+ * Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must
+ * contain the same set of column names and data types. The order of columns in the two Datasets is not relevant as
+ * columns are compared based on the name, not the position.
*
- * Optional id columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between two Datasets, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset,
- * that do not exist in the right Dataset are marked as `"D"`elete.
+ * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing
+ * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset
+ * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of
+ * the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
- * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
- * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+ * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
+ * changes will exist as respective `"D"`elete and `"I"`nsert.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -603,33 +706,31 @@ object Diff {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
+ * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
@scala.annotation.varargs
def of[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame =
default.diff(left, right, idColumns: _*)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
- * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
- * columns are compared based on the name, not the the position.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
+ * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
+ * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the
+ * position.
*
- * Optional id columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between two Datasets, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset,
- * that do not exist in the right Dataset are marked as `"D"`elete.
+ * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing
+ * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset
+ * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of
+ * the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
- * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
- * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+ * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
+ * changes will exist as respective `"D"`elete and `"I"`nsert.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -662,32 +763,35 @@ object Diff {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
+ * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
- def of[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String] = Seq.empty): DataFrame =
+ def of[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String] = Seq.empty
+ ): DataFrame =
default.diff(left, right, idColumns, ignoreColumns)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
- * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
- * columns are compared based on the name, not the the position.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
+ * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
+ * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the
+ * position.
*
- * Optional id columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between two Datasets, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset,
- * that do not exist in the right Dataset are marked as `"D"`elete.
+ * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing
+ * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset
+ * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of
+ * the left Dataset that do not exist in the right Dataset are marked as `"D"`elete.
*
- * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows
- * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+ * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all
+ * changes will exist as respective `"D"`elete and `"I"`nsert.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -720,16 +824,19 @@ object Diff {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are
+ * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*/
- def of[T, U](left: Dataset[T], right: Dataset[U], idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): DataFrame =
+ def of[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: java.util.List[String],
+ ignoreColumns: java.util.List[String]
+ ): DataFrame =
default.diff(left, right, idColumns, ignoreColumns)
/**
- * Returns a new Dataset that contains the differences between two Datasets of
- * the same type `T`.
+ * Returns a new Dataset that contains the differences between two Datasets of the same type `T`.
*
* See `of(Dataset[T], Dataset[T], String*)`.
*
@@ -737,62 +844,72 @@ object Diff {
*/
// no @scala.annotation.varargs here as implicit arguments are explicit in Java
// this signature is redundant to the other ofAs method in Java
- def ofAs[T, V](left: Dataset[T], right: Dataset[T], idColumns: String*)
- (implicit diffEncoder: Encoder[V]): Dataset[V] =
+ def ofAs[T, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit
+ diffEncoder: Encoder[V]
+ ): Dataset[V] =
default.diffAs(left, right, idColumns: _*)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def ofAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String] = Seq.empty)
- (implicit diffEncoder: Encoder[V]): Dataset[V] =
+ def ofAs[T, U, V](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String] = Seq.empty
+ )(implicit diffEncoder: Encoder[V]): Dataset[V] =
default.diffAs(left, right, idColumns, ignoreColumns)
/**
- * Returns a new Dataset that contains the differences between two Datasets of
- * the same type `T`.
+ * Returns a new Dataset that contains the differences between two Datasets of the same type `T`.
*
* See `of(Dataset[T], Dataset[T], String*)`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
@scala.annotation.varargs
- def ofAs[T, V](left: Dataset[T], right: Dataset[T],
- diffEncoder: Encoder[V], idColumns: String*): Dataset[V] =
+ def ofAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] =
default.diffAs(left, right, diffEncoder, idColumns: _*)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def ofAs[T, U, V](left: Dataset[T], right: Dataset[U],
- diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V] =
+ def ofAs[T, U, V](
+ left: Dataset[T],
+ right: Dataset[U],
+ diffEncoder: Encoder[V],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Dataset[V] =
default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def ofAs[T, U, V](left: Dataset[T], right: Dataset[U], diffEncoder: Encoder[V],
- idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[V] =
+ def ofAs[T, U, V](
+ left: Dataset[T],
+ right: Dataset[U],
+ diffEncoder: Encoder[V],
+ idColumns: java.util.List[String],
+ ignoreColumns: java.util.List[String]
+ ): Dataset[V] =
default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns)
/**
- * Returns a new Dataset that contains the differences between two Dataset of
- * the same type `T` as tuples of type `(String, T, T)`.
+ * Returns a new Dataset that contains the differences between two Datasets of the same type `T` as tuples of type
+ * `(String, T, T)`.
*
* See `of(Dataset[T], Dataset[T], String*)`.
*/
@@ -801,22 +918,30 @@ object Diff {
default.diffWith(left, right, idColumns: _*)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U` as tuples of type `(String, T, U)`.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U` as tuples
+ * of type `(String, T, U)`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
- def ofWith[T, U](left: Dataset[T], right: Dataset[U],
- idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] =
+ def ofWith[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Dataset[(String, T, U)] =
default.diffWith(left, right, idColumns, ignoreColumns)
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U` as tuples of type `(String, T, U)`.
+ * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U` as tuples
+ * of type `(String, T, U)`.
*
* See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`.
*/
- def ofWith[T, U](left: Dataset[T], right: Dataset[U],
- idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[(String, T, U)] =
+ def ofWith[T, U](
+ left: Dataset[T],
+ right: Dataset[U],
+ idColumns: java.util.List[String],
+ ignoreColumns: java.util.List[String]
+ ): Dataset[(String, T, U)] =
default.diffWith(left, right, idColumns, ignoreColumns)
}
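
For orientation, a minimal sketch of calling the `Diff.of` and `Diff.ofWith` entry points reformatted above; the SparkSession setup, column names and sample rows are illustrative and not part of this change:

    import org.apache.spark.sql.SparkSession
    import uk.co.gresearch.spark.diff.Diff

    // illustrative local session and example data
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val left = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
    val right = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")

    // DataFrame with the `diff` column ("N", "C", "I" or "D") first, then id and non-id columns
    val diffDf = Diff.of(left, right, "id")

    // tuple variant: (diff value, left row, right row)
    val diffTuples = Diff.ofWith(left.as[(Int, String)], right.as[(Int, String)], Seq("id"), Seq.empty)
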
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala b/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala
index 1c9e2f1d..170d2240 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala
@@ -23,6 +23,7 @@ import uk.co.gresearch.spark.diff.comparator._
import java.time.Duration
object DiffComparators {
+
/**
* The default comparator used in [[DiffOptions.default.defaultComparator]].
*/
@@ -34,17 +35,17 @@ object DiffComparators {
def nullSafeEqual(): DiffComparator = NullSafeEqualDiffComparator
/**
- * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]].
- * The implicit [[Encoder]] of type [[T]] determines the input data type of the comparator.
- * Only columns of that type can be compared.
+ * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. The implicit [[Encoder]] of
+ * type [[T]] determines the input data type of the comparator. Only columns of that type can be compared.
*/
- def equiv[T : Encoder](equiv: math.Equiv[T]): EquivDiffComparator[T] = EquivDiffComparator(equiv)
+ def equiv[T: Encoder](equiv: math.Equiv[T]): EquivDiffComparator[T] = EquivDiffComparator(equiv)
/**
- * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]].
- * Only columns of the given data type `inputType` can be compared.
+ * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. Only columns of the given
+ * data type `inputType` can be compared.
*/
- def equiv[T](equiv: math.Equiv[T], inputType: DataType): EquivDiffComparator[T] = EquivDiffComparator(equiv, inputType)
+ def equiv[T](equiv: math.Equiv[T], inputType: DataType): EquivDiffComparator[T] =
+ EquivDiffComparator(equiv, inputType)
/**
* Return a comparator that uses the given [[math.Equiv]] to compare values of any type.
@@ -52,17 +53,15 @@ object DiffComparators {
def equiv(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivDiffComparator(equiv)
/**
- * This comparator considers values equal when they are less than `epsilon` apart.
- * It can be configured to use `epsilon` as an absolute (`.asAbsolute()`) threshold,
- * or as relative (`.asRelative()`) to the larger value. Further, the threshold itself can be
- * considered equal (`.asInclusive()`) or not equal (`.asExclusive()`):
+ * This comparator considers values equal when they are less than `epsilon` apart. It can be configured to use
+ * `epsilon` as an absolute (`.asAbsolute()`) threshold, or as relative (`.asRelative()`) to the larger value.
+ * Further, the threshold itself can be considered equal (`.asInclusive()`) or not equal (`.asExclusive()`):
*
- *
- * - `DiffComparator.epsilon(epsilon).asAbsolute().asInclusive()`: `abs(left - right) ≤ epsilon`
- * - `DiffComparator.epsilon(epsilon).asAbsolute().asExclusive()`: `abs(left - right) < epsilon`
- * - `DiffComparator.epsilon(epsilon).asRelative().asInclusive()`: `abs(left - right) ≤ epsilon * max(abs(left), abs(right))`
- * - `DiffComparator.epsilon(epsilon).asRelative().asExclusive()`: `abs(left - right) < epsilon * max(abs(left), abs(right))`
- *
+ * - `DiffComparator.epsilon(epsilon).asAbsolute().asInclusive()`: `abs(left - right) ≤ epsilon`
+ * - `DiffComparator.epsilon(epsilon).asAbsolute().asExclusive()`: `abs(left - right) < epsilon`
+ * - `DiffComparator.epsilon(epsilon).asRelative().asInclusive()`: `abs(left - right) ≤ epsilon * max(abs(left),
+ * abs(right))`
+ *   - `DiffComparator.epsilon(epsilon).asRelative().asExclusive()`: `abs(left - right) < epsilon * max(abs(left),
+ *     abs(right))`
*
* Requires compared column types to implement `-`, `*`, `<`, `==`, and `abs`.
*/
@@ -71,8 +70,9 @@ object DiffComparators {
/**
* A comparator for string values.
*
- * With `whitespaceAgnostic` set `true`, differences in white spaces are ignored. This ignores leading and trailing whitespaces as well.
- * With `whitespaceAgnostic` set `false`, this is equal to the default string comparison (see [[default()]]).
+ * With `whitespaceAgnostic` set `true`, differences in whitespace are ignored, including leading and trailing
+ * whitespace. With `whitespaceAgnostic` set `false`, this is equal to the default string comparison (see
+ * [[default()]]).
*/
def string(whitespaceAgnostic: Boolean = true): StringDiffComparator =
if (whitespaceAgnostic) {
@@ -85,11 +85,9 @@ object DiffComparators {
* This comparator considers two `DateType` or `TimestampType` values equal when they are at most `duration` apart.
* Duration is an instance of `java.time.Duration`.
*
- * The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal (`.asExclusive()`):
- *
- * - `DiffComparator.duration(duration).asInclusive()`: `left - right ≤ duration`
- * - `DiffComparator.duration(duration).asExclusive()`: `left - right < duration`
- *
+ * The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal
+ * (`.asExclusive()`):
+ *   - `DiffComparator.duration(duration).asInclusive()`: `left - right ≤ duration`
+ *   - `DiffComparator.duration(duration).asExclusive()`: `left - right < duration`
*/
def duration(duration: Duration): DurationDiffComparator = DurationDiffComparator(duration)
@@ -103,7 +101,9 @@ object DiffComparators {
/**
* This comparator compares two `Map[K,V]` values. They are equal when they match in all their keys and values.
*
- * @param keyOrderSensitive comparator compares key order if true
+ * @param keyOrderSensitive
+ * comparator compares key order if true
*/
- def map[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): DiffComparator = MapDiffComparator[K, V](keyOrderSensitive)
+ def map[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): DiffComparator =
+ MapDiffComparator[K, V](keyOrderSensitive)
}
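
A sketch of constructing the comparators documented above; the thresholds are arbitrary examples and the fluent method names (`asAbsolute`, `asRelative`, `asInclusive`, `asExclusive`) follow the Scaladoc:

    import java.time.Duration
    import uk.co.gresearch.spark.diff.DiffComparators

    // inclusive absolute threshold: abs(left - right) <= 0.001
    val absoluteEpsilon = DiffComparators.epsilon(0.001).asAbsolute().asInclusive()

    // exclusive relative threshold: abs(left - right) < 0.01 * max(abs(left), abs(right))
    val relativeEpsilon = DiffComparators.epsilon(0.01).asRelative().asExclusive()

    // timestamps at most one minute apart compare as equal
    val withinOneMinute = DiffComparators.duration(Duration.ofMinutes(1)).asInclusive()

    // string comparison that ignores differences in whitespace
    val ignoreWhitespace = DiffComparators.string(whitespaceAgnostic = true)
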
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala b/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala
index 7f434038..0bd300f7 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala
@@ -20,7 +20,12 @@ import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types.{DataType, StructField}
import uk.co.gresearch.spark.diff
import uk.co.gresearch.spark.diff.DiffMode.{Default, DiffMode}
-import uk.co.gresearch.spark.diff.comparator.{DefaultDiffComparator, DiffComparator, EquivDiffComparator, TypedDiffComparator}
+import uk.co.gresearch.spark.diff.comparator.{
+ DefaultDiffComparator,
+ DiffComparator,
+ EquivDiffComparator,
+ TypedDiffComparator
+}
import scala.annotation.varargs
import scala.collection.Map
@@ -34,23 +39,20 @@ object DiffMode extends Enumeration {
/**
* The diff mode determines the output columns of the diffing transformation.
*
- * - ColumnByColumn: The diff contains value columns from the left and right dataset,
- * arranged column by column:
- * diff,( changes,) id-1, id-2, …, left-value-1, right-value-1, left-value-2, right-value-2, …
+ * - ColumnByColumn: The diff contains value columns from the left and right dataset, arranged column by column:
+ * diff,( changes,) id-1, id-2, …, left-value-1, right-value-1, left-value-2, right-value-2, …
*
- * - SideBySide: The diff contains value columns from the left and right dataset,
- * arranged side by side:
- * diff,( changes,) id-1, id-2, …, left-value-1, left-value-2, …, right-value-1, right-value-2, …
- * - LeftSide / RightSide: The diff contains value columns from the left / right dataset only.
+ * - SideBySide: The diff contains value columns from the left and right dataset, arranged side by side: diff,(
+ * changes,) id-1, id-2, …, left-value-1, left-value-2, …, right-value-1, right-value-2, …
+ * - LeftSide / RightSide: The diff contains value columns from the left / right dataset only.
*/
val ColumnByColumn, SideBySide, LeftSide, RightSide = Value
/**
- * The diff mode determines the output columns of the diffing transformation.
- * The default diff mode is ColumnByColumn.
+ * The diff mode determines the output columns of the diffing transformation. The default diff mode is ColumnByColumn.
*
- * Default is not a enum value here (hence the def) so that we do not have to include it in every
- * match clause. We will see the respective enum value that Default points to instead.
+ * Default is not an enum value here (hence the def) so that we do not have to include it in every match clause. We
+ * will see the respective enum value that Default points to instead.
*/
def Default: diff.DiffMode.Value = ColumnByColumn
@@ -72,45 +74,62 @@ object DiffMode extends Enumeration {
/**
* Configuration class for diffing Datasets.
*
- * @param diffColumn name of the diff column
- * @param leftColumnPrefix prefix of columns from the left Dataset
- * @param rightColumnPrefix prefix of columns from the right Dataset
- * @param insertDiffValue value in diff column for inserted rows
- * @param changeDiffValue value in diff column for changed rows
- * @param deleteDiffValue value in diff column for deleted rows
- * @param nochangeDiffValue value in diff column for un-changed rows
- * @param changeColumn name of change column
- * @param diffMode diff output format
- * @param sparseMode un-changed values are null on both sides
- * @param defaultComparator default custom comparator
- * @param dataTypeComparators custom comparator for some data type
- * @param columnNameComparators custom comparator for some column name
+ * @param diffColumn
+ * name of the diff column
+ * @param leftColumnPrefix
+ * prefix of columns from the left Dataset
+ * @param rightColumnPrefix
+ * prefix of columns from the right Dataset
+ * @param insertDiffValue
+ * value in diff column for inserted rows
+ * @param changeDiffValue
+ * value in diff column for changed rows
+ * @param deleteDiffValue
+ * value in diff column for deleted rows
+ * @param nochangeDiffValue
+ * value in diff column for un-changed rows
+ * @param changeColumn
+ * name of change column
+ * @param diffMode
+ * diff output format
+ * @param sparseMode
+ * un-changed values are null on both sides
+ * @param defaultComparator
+ * default custom comparator
+ * @param dataTypeComparators
+ * custom comparator for some data type
+ * @param columnNameComparators
+ * custom comparator for some column name
*/
-case class DiffOptions(diffColumn: String,
- leftColumnPrefix: String,
- rightColumnPrefix: String,
- insertDiffValue: String,
- changeDiffValue: String,
- deleteDiffValue: String,
- nochangeDiffValue: String,
- changeColumn: Option[String] = None,
- diffMode: DiffMode = Default,
- sparseMode: Boolean = false,
- defaultComparator: DiffComparator = DefaultDiffComparator,
- dataTypeComparators: Map[DataType, DiffComparator] = Map.empty,
- columnNameComparators: Map[String, DiffComparator] = Map.empty) {
+case class DiffOptions(
+ diffColumn: String,
+ leftColumnPrefix: String,
+ rightColumnPrefix: String,
+ insertDiffValue: String,
+ changeDiffValue: String,
+ deleteDiffValue: String,
+ nochangeDiffValue: String,
+ changeColumn: Option[String] = None,
+ diffMode: DiffMode = Default,
+ sparseMode: Boolean = false,
+ defaultComparator: DiffComparator = DefaultDiffComparator,
+ dataTypeComparators: Map[DataType, DiffComparator] = Map.empty,
+ columnNameComparators: Map[String, DiffComparator] = Map.empty
+) {
// Constructor for Java to construct default options
def this() = this("diff", "left", "right", "I", "C", "D", "N")
- def this(diffColumn: String,
- leftColumnPrefix: String,
- rightColumnPrefix: String,
- insertDiffValue: String,
- changeDiffValue: String,
- deleteDiffValue: String,
- nochangeDiffValue: String,
- changeColumn: Option[String],
- diffMode: DiffMode,
- sparseMode: Boolean) = {
+ def this(
+ diffColumn: String,
+ leftColumnPrefix: String,
+ rightColumnPrefix: String,
+ insertDiffValue: String,
+ changeDiffValue: String,
+ deleteDiffValue: String,
+ nochangeDiffValue: String,
+ changeColumn: Option[String],
+ diffMode: DiffMode,
+ sparseMode: Boolean
+ ) = {
this(
diffColumn,
leftColumnPrefix,
@@ -124,188 +143,215 @@ case class DiffOptions(diffColumn: String,
sparseMode,
DefaultDiffComparator,
Map.empty,
- Map.empty)
+ Map.empty
+ )
}
require(leftColumnPrefix.nonEmpty, "Left column prefix must not be empty")
require(rightColumnPrefix.nonEmpty, "Right column prefix must not be empty")
- require(handleConfiguredCaseSensitivity(leftColumnPrefix) != handleConfiguredCaseSensitivity(rightColumnPrefix),
- s"Left and right column prefix must be distinct: $leftColumnPrefix")
+ require(
+ handleConfiguredCaseSensitivity(leftColumnPrefix) != handleConfiguredCaseSensitivity(rightColumnPrefix),
+ s"Left and right column prefix must be distinct: $leftColumnPrefix"
+ )
val diffValues = Seq(insertDiffValue, changeDiffValue, deleteDiffValue, nochangeDiffValue)
- require(diffValues.distinct.length == diffValues.length,
- s"Diff values must be distinct: $diffValues")
+ require(diffValues.distinct.length == diffValues.length, s"Diff values must be distinct: $diffValues")
- require(!changeColumn.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(diffColumn)),
- s"Change column name must be different to diff column: $diffColumn")
+ require(
+ !changeColumn.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(diffColumn)),
+ s"Change column name must be different to diff column: $diffColumn"
+ )
/**
- * Fluent method to change the diff column name.
- * Returns a new immutable DiffOptions instance with the new diff column name.
- * @param diffColumn new diff column name
- * @return new immutable DiffOptions instance
+ * Fluent method to change the diff column name. Returns a new immutable DiffOptions instance with the new diff column
+ * name.
+ * @param diffColumn
+ * new diff column name
+ * @return
+ * new immutable DiffOptions instance
*/
def withDiffColumn(diffColumn: String): DiffOptions = {
this.copy(diffColumn = diffColumn)
}
/**
- * Fluent method to change the prefix of columns from the left Dataset.
- * Returns a new immutable DiffOptions instance with the new column prefix.
- * @param leftColumnPrefix new column prefix
- * @return new immutable DiffOptions instance
+ * Fluent method to change the prefix of columns from the left Dataset. Returns a new immutable DiffOptions instance
+ * with the new column prefix.
+ * @param leftColumnPrefix
+ * new column prefix
+ * @return
+ * new immutable DiffOptions instance
*/
def withLeftColumnPrefix(leftColumnPrefix: String): DiffOptions = {
this.copy(leftColumnPrefix = leftColumnPrefix)
}
/**
- * Fluent method to change the prefix of columns from the right Dataset.
- * Returns a new immutable DiffOptions instance with the new column prefix.
- * @param rightColumnPrefix new column prefix
- * @return new immutable DiffOptions instance
+ * Fluent method to change the prefix of columns from the right Dataset. Returns a new immutable DiffOptions instance
+ * with the new column prefix.
+ * @param rightColumnPrefix
+ * new column prefix
+ * @return
+ * new immutable DiffOptions instance
*/
def withRightColumnPrefix(rightColumnPrefix: String): DiffOptions = {
this.copy(rightColumnPrefix = rightColumnPrefix)
}
/**
- * Fluent method to change the value of inserted rows in the diff column.
- * Returns a new immutable DiffOptions instance with the new diff value.
- * @param insertDiffValue new diff value
- * @return new immutable DiffOptions instance
+ * Fluent method to change the value of inserted rows in the diff column. Returns a new immutable DiffOptions instance
+ * with the new diff value.
+ * @param insertDiffValue
+ * new diff value
+ * @return
+ * new immutable DiffOptions instance
*/
def withInsertDiffValue(insertDiffValue: String): DiffOptions = {
this.copy(insertDiffValue = insertDiffValue)
}
/**
- * Fluent method to change the value of changed rows in the diff column.
- * Returns a new immutable DiffOptions instance with the new diff value.
- * @param changeDiffValue new diff value
- * @return new immutable DiffOptions instance
+ * Fluent method to change the value of changed rows in the diff column. Returns a new immutable DiffOptions instance
+ * with the new diff value.
+ * @param changeDiffValue
+ * new diff value
+ * @return
+ * new immutable DiffOptions instance
*/
def withChangeDiffValue(changeDiffValue: String): DiffOptions = {
this.copy(changeDiffValue = changeDiffValue)
}
/**
- * Fluent method to change the value of deleted rows in the diff column.
- * Returns a new immutable DiffOptions instance with the new diff value.
- * @param deleteDiffValue new diff value
- * @return new immutable DiffOptions instance
+ * Fluent method to change the value of deleted rows in the diff column. Returns a new immutable DiffOptions instance
+ * with the new diff value.
+ * @param deleteDiffValue
+ * new diff value
+ * @return
+ * new immutable DiffOptions instance
*/
def withDeleteDiffValue(deleteDiffValue: String): DiffOptions = {
this.copy(deleteDiffValue = deleteDiffValue)
}
/**
- * Fluent method to change the value of un-changed rows in the diff column.
- * Returns a new immutable DiffOptions instance with the new diff value.
- * @param nochangeDiffValue new diff value
- * @return new immutable DiffOptions instance
+ * Fluent method to change the value of un-changed rows in the diff column. Returns a new immutable DiffOptions
+ * instance with the new diff value.
+ * @param nochangeDiffValue
+ * new diff value
+ * @return
+ * new immutable DiffOptions instance
*/
def withNochangeDiffValue(nochangeDiffValue: String): DiffOptions = {
this.copy(nochangeDiffValue = nochangeDiffValue)
}
/**
- * Fluent method to change the change column name.
- * Returns a new immutable DiffOptions instance with the new change column name.
- * @param changeColumn new change column name
- * @return new immutable DiffOptions instance
+ * Fluent method to change the change column name. Returns a new immutable DiffOptions instance with the new change
+ * column name.
+ * @param changeColumn
+ * new change column name
+ * @return
+ * new immutable DiffOptions instance
*/
def withChangeColumn(changeColumn: String): DiffOptions = {
this.copy(changeColumn = Some(changeColumn))
}
/**
- * Fluent method to remove change column.
- * Returns a new immutable DiffOptions instance without a change column.
- * @return new immutable DiffOptions instance
+ * Fluent method to remove change column. Returns a new immutable DiffOptions instance without a change column.
+ * @return
+ * new immutable DiffOptions instance
*/
def withoutChangeColumn(): DiffOptions = {
this.copy(changeColumn = None)
}
/**
- * Fluent method to change the diff mode.
- * Returns a new immutable DiffOptions instance with the new diff mode.
- * @return new immutable DiffOptions instance
+ * Fluent method to change the diff mode. Returns a new immutable DiffOptions instance with the new diff mode.
+ * @return
+ * new immutable DiffOptions instance
*/
def withDiffMode(diffMode: DiffMode): DiffOptions = {
this.copy(diffMode = diffMode)
}
/**
- * Fluent method to change the sparse mode.
- * Returns a new immutable DiffOptions instance with the new sparse mode.
- * @return new immutable DiffOptions instance
+ * Fluent method to change the sparse mode. Returns a new immutable DiffOptions instance with the new sparse mode.
+ * @return
+ * new immutable DiffOptions instance
*/
def withSparseMode(sparseMode: Boolean): DiffOptions = {
this.copy(sparseMode = sparseMode)
}
/**
- * Fluent method to add a default comparator.
- * Returns a new immutable DiffOptions instance with the new default comparator.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a default comparator. Returns a new immutable DiffOptions instance with the new default
+ * comparator.
+ * @return
+ * new immutable DiffOptions instance
*/
def withDefaultComparator(diffComparator: DiffComparator): DiffOptions = {
this.copy(defaultComparator = diffComparator)
}
/**
- * Fluent method to add a typed equivalent operator as a default comparator.
- * The encoder defines the input type of the comparator.
- * Returns a new immutable DiffOptions instance with the new default comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators
- * can only be implemented via the `DiffComparator` interface.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a typed equivalent operator as a default comparator. The encoder defines the input type of the
+ * comparator. Returns a new immutable DiffOptions instance with the new default comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @return
+ * new immutable DiffOptions instance
*/
- def withDefaultComparator[T : Encoder](equiv: math.Equiv[T]): DiffOptions = {
+ def withDefaultComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions = {
withDefaultComparator(EquivDiffComparator(equiv))
}
/**
- * Fluent method to add a typed equivalent operator as a default comparator.
- * Returns a new immutable DiffOptions instance with the new default comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators
- * can only be implemented via the `DiffComparator` interface.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a typed equivalent operator as a default comparator. Returns a new immutable DiffOptions
+ * instance with the new default comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @return
+ * new immutable DiffOptions instance
*/
def withDefaultComparator[T](equiv: math.Equiv[T], inputDataType: DataType): DiffOptions = {
withDefaultComparator(EquivDiffComparator(equiv, inputDataType))
}
/**
- * Fluent method to add an equivalent operator as a default comparator.
- * Returns a new immutable DiffOptions instance with the new default comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators
- * can only be implemented via the `DiffComparator` interface.
- * @return new immutable DiffOptions instance
+ * Fluent method to add an equivalent operator as a default comparator. Returns a new immutable DiffOptions instance
+ * with the new default comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @return
+ * new immutable DiffOptions instance
*/
def withDefaultComparator(equiv: math.Equiv[Any]): DiffOptions = {
withDefaultComparator(EquivDiffComparator(equiv))
}
/**
- * Fluent method to add a comparator for its input data type.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a comparator for its input data type. Returns a new immutable DiffOptions instance with the
+ * new comparator.
+ * @return
+ * new immutable DiffOptions instance
*/
def withComparator(diffComparator: TypedDiffComparator): DiffOptions = {
if (dataTypeComparators.contains(diffComparator.inputType)) {
- throw new IllegalArgumentException(
- s"A comparator for data type ${diffComparator.inputType} exists already.")
+ throw new IllegalArgumentException(s"A comparator for data type ${diffComparator.inputType} exists already.")
}
this.copy(dataTypeComparators = dataTypeComparators ++ Map(diffComparator.inputType -> diffComparator))
}
/**
- * Fluent method to add a comparator for one or more data types.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a comparator for one or more data types. Returns a new immutable DiffOptions instance with the
+ * new comparator.
+ * @return
+ * new immutable DiffOptions instance
*/
@varargs
def withComparator(diffComparator: DiffComparator, dataType: DataType, dataTypes: DataType*): DiffOptions = {
@@ -313,8 +359,10 @@ case class DiffOptions(diffColumn: String,
diffComparator match {
case typed: TypedDiffComparator if allDataTypes.exists(_ != typed.inputType) =>
- throw new IllegalArgumentException(s"Comparator with input type ${typed.inputType.simpleString} " +
- s"cannot be used for data type ${allDataTypes.filter(_ != typed.inputType).map(_.simpleString).sorted.mkString(", ")}")
+ throw new IllegalArgumentException(
+ s"Comparator with input type ${typed.inputType.simpleString} " +
+ s"cannot be used for data type ${allDataTypes.filter(_ != typed.inputType).map(_.simpleString).sorted.mkString(", ")}"
+ )
case _ =>
}
@@ -322,15 +370,17 @@ case class DiffOptions(diffColumn: String,
if (existingDataTypes.nonEmpty) {
throw new IllegalArgumentException(
s"A comparator for data type${if (existingDataTypes.length > 1) "s" else ""} " +
- s"${existingDataTypes.map(_.simpleString).sorted.mkString(", ")} exists already.")
+ s"${existingDataTypes.map(_.simpleString).sorted.mkString(", ")} exists already."
+ )
}
this.copy(dataTypeComparators = dataTypeComparators ++ allDataTypes.map(dt => dt -> diffComparator))
}
/**
- * Fluent method to add a comparator for one or more column names.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a comparator for one or more column names. Returns a new immutable DiffOptions instance with
+ * the new comparator.
+ * @return
+ * new immutable DiffOptions instance
*/
@varargs
def withComparator(diffComparator: DiffComparator, columnName: String, columnNames: String*): DiffOptions = {
@@ -339,79 +389,96 @@ case class DiffOptions(diffColumn: String,
if (existingColumnNames.nonEmpty) {
throw new IllegalArgumentException(
s"A comparator for column name${if (existingColumnNames.length > 1) "s" else ""} " +
- s"${existingColumnNames.sorted.mkString(", ")} exists already.")
+ s"${existingColumnNames.sorted.mkString(", ")} exists already."
+ )
}
this.copy(columnNameComparators = columnNameComparators ++ allColumnNames.map(name => name -> diffComparator))
}
/**
- * Fluent method to add a typed equivalent operator as a comparator for its input data type.
- * The encoder defines the input type of the comparator.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators can only
- * be implemented via the `DiffComparator` interface.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a typed equivalent operator as a comparator for its input data type. The encoder defines the
+ * input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @return
+ * new immutable DiffOptions instance
*/
- def withComparator[T : Encoder](equiv: math.Equiv[T]): DiffOptions =
+ def withComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions =
withComparator(EquivDiffComparator(equiv))
/**
- * Fluent method to add a typed equivalent operator as a comparator for one or more column names.
- * The encoder defines the input type of the comparator.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators can only
- * be implemented via the `DiffComparator` interface.
- * @return new immutable DiffOptions instance
+ * Fluent method to add a typed equivalent operator as a comparator for one or more column names. The encoder defines
+ * the input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @return
+ * new immutable DiffOptions instance
*/
- def withComparator[T : Encoder](equiv: math.Equiv[T], columnName: String, columnNames: String*): DiffOptions =
+ def withComparator[T: Encoder](equiv: math.Equiv[T], columnName: String, columnNames: String*): DiffOptions =
withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*)
/**
- * Fluent method to add an equivalent operator as a comparator for one or more column names.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators
- * can only be implemented via the `DiffComparator` interface.
- * @note Java-specific method
- * @return new immutable DiffOptions instance
+ * Fluent method to add an equivalent operator as a comparator for one or more column names. Returns a new immutable
+ * DiffOptions instance with the new comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @note
+ * Java-specific method
+ * @return
+ * new immutable DiffOptions instance
*/
@varargs
- def withComparator[T](equiv: math.Equiv[T], encoder: Encoder[T], columnName: String, columnNames: String*): DiffOptions =
+ def withComparator[T](
+ equiv: math.Equiv[T],
+ encoder: Encoder[T],
+ columnName: String,
+ columnNames: String*
+ ): DiffOptions =
withComparator(EquivDiffComparator(equiv)(encoder), columnName, columnNames: _*)
/**
- * Fluent method to add an equivalent operator as a comparator for one or more data types.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators
- * can only be implemented via the `DiffComparator` interface.
- * @return new immutable DiffOptions instance
+ * Fluent method to add an equivalent operator as a comparator for one or more data types. Returns a new immutable
+ * DiffOptions instance with the new comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @return
+ * new immutable DiffOptions instance
*/
// There is probably no use case of calling this with multiple datatype while T not being Any
// But this is the only way to define withComparator[T](equiv: math.Equiv[T], dataType: DataType)
// without being ambiguous with withComparator(equiv: math.Equiv[Any], dataType: DataType, dataTypes: DataType*)
@varargs
def withComparator[T](equiv: math.Equiv[T], dataType: DataType, dataTypes: DataType*): DiffOptions =
- (dataType +: dataTypes).foldLeft(this)( (options, dataType) =>
+ (dataType +: dataTypes).foldLeft(this)((options, dataType) =>
options.withComparator(EquivDiffComparator(equiv, dataType))
)
/**
- * Fluent method to add an equivalent operator as a comparator for one or more column names.
- * Returns a new immutable DiffOptions instance with the new comparator.
- * @note The `math.Equiv` will not be given any null values. Null-aware comparators
- * can only be implemented via the `DiffComparator` interface.
- * @return new immutable DiffOptions instance
+ * Fluent method to add an equivalent operator as a comparator for one or more column names. Returns a new immutable
+ * DiffOptions instance with the new comparator.
+ * @note
+ * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the
+ * `DiffComparator` interface.
+ * @return
+ * new immutable DiffOptions instance
*/
@varargs
def withComparator(equiv: math.Equiv[Any], columnName: String, columnNames: String*): DiffOptions =
withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*)
private[diff] def comparatorFor(column: StructField): DiffComparator =
- columnNameComparators.get(column.name)
+ columnNameComparators
+ .get(column.name)
.orElse(dataTypeComparators.get(column.dataType))
.getOrElse(defaultComparator)
}
object DiffOptions {
+
/**
* Default diffing options.
*/
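
To illustrate the fluent API reformatted above, a minimal sketch of composing options and handing them to a `Differ`; the chosen values and the "value" column name are examples only:

    import uk.co.gresearch.spark.diff.{DiffComparators, DiffMode, DiffOptions, Differ}

    // each with* call returns a new immutable DiffOptions instance
    val options = DiffOptions.default
      .withDiffColumn("d")
      .withChangeColumn("changes")
      .withDiffMode(DiffMode.SideBySide)
      .withSparseMode(true)
      .withComparator(DiffComparators.string(), "value")

    // Differ wraps the options and exposes diff, diffAs and diffWith
    val differ = new Differ(options)
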
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala
index fed6b2e3..4ebfc79e 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala
@@ -25,25 +25,30 @@ import uk.co.gresearch.spark.diff.comparator.DurationDiffComparator.isNotSupport
import java.time.Duration
/**
- * Compares two timestamps and considers them equal when they are less than
- * (or equal to when inclusive = true) a given duration apart.
+ * Compares two timestamps and considers them equal when they are less than (or equal to when inclusive = true) a given
+ * duration apart.
*
- * @param duration equality threshold
- * @param inclusive duration is considered equal when true
+ * @param duration
+ * equality threshold
+ * @param inclusive
+ * duration is considered equal when true
*/
case class DurationDiffComparator(duration: Duration, inclusive: Boolean = true) extends DiffComparator {
if (isNotSupportedBySpark) {
- throw new UnsupportedOperationException(s"java.time.Duration is not supported by Spark ${spark.SparkCompatVersionString}")
+ throw new UnsupportedOperationException(
+ s"java.time.Duration is not supported by Spark ${spark.SparkCompatVersionString}"
+ )
}
override def equiv(left: Column, right: Column): Column = {
- val inDuration = if (inclusive)
- (diff: Column) => diff <= duration
- else
- (diff: Column) => diff < duration
+ val inDuration =
+ if (inclusive)
+ (diff: Column) => diff <= duration
+ else
+ (diff: Column) => diff < duration
left.isNull && right.isNull ||
- left.isNotNull && right.isNotNull && inDuration(abs(left - right))
+ left.isNotNull && right.isNotNull && inDuration(abs(left - right))
}
def asInclusive(): DurationDiffComparator = if (inclusive) this else copy(inclusive = true)
@@ -52,5 +57,5 @@ case class DurationDiffComparator(duration: Duration, inclusive: Boolean = true)
object DurationDiffComparator extends SparkVersion {
val isSupportedBySpark: Boolean = SparkMajorVersion == 3 && SparkMinorVersion >= 3 || SparkMajorVersion > 3
- val isNotSupportedBySpark: Boolean = ! isSupportedBySpark
+ val isNotSupportedBySpark: Boolean = !isSupportedBySpark
}
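
A sketch of scoping this comparator to every timestamp column via the `DiffOptions.withComparator` overload that takes data types (Spark 3.3+ per the version check above); the five-second tolerance is an arbitrary example:

    import java.time.Duration
    import org.apache.spark.sql.types.TimestampType
    import uk.co.gresearch.spark.diff.{DiffComparators, DiffOptions}

    // timestamps at most five seconds apart count as unchanged, for all TimestampType columns
    val options = DiffOptions.default
      .withComparator(DiffComparators.duration(Duration.ofSeconds(5)), TimestampType)
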
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala
index a9d7658a..b95196e1 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala
@@ -20,17 +20,19 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{abs, greatest}
case class EpsilonDiffComparator(epsilon: Double, relative: Boolean = true, inclusive: Boolean = true)
- extends DiffComparator {
+ extends DiffComparator {
override def equiv(left: Column, right: Column): Column = {
- val threshold = if (relative)
- greatest(abs(left), abs(right)) * epsilon
- else
- epsilon
+ val threshold =
+ if (relative)
+ greatest(abs(left), abs(right)) * epsilon
+ else
+ epsilon
- val inEpsilon = if (inclusive)
- (diff: Column) => diff <= threshold
- else
- (diff: Column) => diff < threshold
+ val inEpsilon =
+ if (inclusive)
+ (diff: Column) => diff <= threshold
+ else
+ (diff: Column) => diff < threshold
left.isNull && right.isNull || left.isNotNull && right.isNotNull && inEpsilon(abs(left - right))
}
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala
index 4d26d393..6a9dc99b 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala
@@ -38,21 +38,23 @@ private trait ExpressionEquivDiffComparator[T] extends EquivDiffComparator[T] {
trait TypedEquivDiffComparator[T] extends EquivDiffComparator[T] with TypedDiffComparator
private[comparator] trait TypedEquivDiffComparatorWithInput[T]
- extends ExpressionEquivDiffComparator[T] with TypedEquivDiffComparator[T] {
+ extends ExpressionEquivDiffComparator[T]
+ with TypedEquivDiffComparator[T] {
def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType)
}
private[comparator] case class InputTypedEquivDiffComparator[T](equiv: math.Equiv[T], inputType: DataType)
- extends TypedEquivDiffComparatorWithInput[T]
-
+ extends TypedEquivDiffComparatorWithInput[T]
object EquivDiffComparator {
- def apply[T : Encoder](equiv: math.Equiv[T]): TypedEquivDiffComparator[T] = EncoderEquivDiffComparator(equiv)
- def apply[T](equiv: math.Equiv[T], inputType: DataType): TypedEquivDiffComparator[T] = InputTypedEquivDiffComparator(equiv, inputType)
+ def apply[T: Encoder](equiv: math.Equiv[T]): TypedEquivDiffComparator[T] = EncoderEquivDiffComparator(equiv)
+ def apply[T](equiv: math.Equiv[T], inputType: DataType): TypedEquivDiffComparator[T] =
+ InputTypedEquivDiffComparator(equiv, inputType)
def apply(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivAnyDiffComparator(equiv)
- private case class EncoderEquivDiffComparator[T : Encoder](equiv: math.Equiv[T])
- extends ExpressionEquivDiffComparator[T] with TypedEquivDiffComparator[T] {
+ private case class EncoderEquivDiffComparator[T: Encoder](equiv: math.Equiv[T])
+ extends ExpressionEquivDiffComparator[T]
+ with TypedEquivDiffComparator[T] {
override def inputType: DataType = encoderFor[T].schema.fields(0).dataType
def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType)
}
@@ -85,9 +87,12 @@ private trait EquivExpression[T] extends BinaryExpression with BinaryLikeWithNew
val eval1 = left.genCode(ctx)
val eval2 = right.genCode(ctx)
val equivRef = ctx.addReferenceObj("equiv", equiv, math.Equiv.getClass.getName.stripSuffix("$"))
- ev.copy(code = eval1.code + eval2.code + code"""
+ ev.copy(
+ code = eval1.code + eval2.code + code"""
boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
- (!${eval1.isNull} && !${eval2.isNull} && $equivRef.equiv(${eval1.value}, ${eval2.value}));""", isNull = FalseLiteral)
+ (!${eval1.isNull} && !${eval2.isNull} && $equivRef.equiv(${eval1.value}, ${eval2.value}));""",
+ isNull = FalseLiteral
+ )
}
}
@@ -100,13 +105,12 @@ private trait EquivOperator[T] extends BinaryOperator with EquivExpression[T] {
}
private case class Equiv[T](left: Expression, right: Expression, equiv: math.Equiv[T], equivInputType: DataType)
- extends EquivOperator[T] {
+ extends EquivOperator[T] {
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Equiv[T] =
- copy(left=newLeft, right=newRight)
+ copy(left = newLeft, right = newRight)
}
-private case class EquivAny(left: Expression, right: Expression, equiv: math.Equiv[Any])
- extends EquivExpression[Any] {
+private case class EquivAny(left: Expression, right: Expression, equiv: math.Equiv[Any]) extends EquivExpression[Any] {
override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): EquivAny =
copy(left = newLeft, right = newRight)
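
For context, a sketch of plugging a custom `math.Equiv` into the diff via `DiffOptions.withComparator`; the case-insensitive equivalence and the "name" column are illustrative:

    import org.apache.spark.sql.{Encoder, Encoders}
    import uk.co.gresearch.spark.diff.DiffOptions

    implicit val stringEncoder: Encoder[String] = Encoders.STRING

    // values compare equal when they match ignoring case; the Equiv never receives nulls
    val caseInsensitive: math.Equiv[String] = (x: String, y: String) => x.equalsIgnoreCase(y)

    // register the equivalence for the "name" column only
    val options = DiffOptions.default.withComparator(caseInsensitive, "name")
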
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala
index d2ca79fc..71c010d8 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala
@@ -27,7 +27,8 @@ case class MapDiffComparator[K, V](private val comparator: EquivDiffComparator[U
override def equiv(left: Column, right: Column): Column = comparator.equiv(left, right)
}
-private case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean) extends math.Equiv[UnsafeMapData] {
+private case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean)
+ extends math.Equiv[UnsafeMapData] {
override def equiv(left: UnsafeMapData, right: UnsafeMapData): Boolean = {
val leftKeys: Array[K] = left.keyArray().toArray(keyType)
@@ -42,15 +43,20 @@ private case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: Da
// can only be evaluated when right has same keys as left
lazy val valuesAreEqual = leftKeysIndices
.map { case (key, index) => index -> rightKeysIndices(key) }
- .map { case (leftIndex, rightIndex) => (leftIndex, rightIndex, leftValues.isNullAt(leftIndex), rightValues.isNullAt(rightIndex)) }
+ .map { case (leftIndex, rightIndex) =>
+ (leftIndex, rightIndex, leftValues.isNullAt(leftIndex), rightValues.isNullAt(rightIndex))
+ }
.map { case (leftIndex, rightIndex, leftIsNull, rightIsNull) =>
leftIsNull && rightIsNull ||
- !leftIsNull && !rightIsNull && leftValues.get(leftIndex, valueType).equals(rightValues.get(rightIndex, valueType))
+ !leftIsNull && !rightIsNull && leftValues
+ .get(leftIndex, valueType)
+ .equals(rightValues.get(rightIndex, valueType))
}
left.numElements() == right.numElements() &&
- (keyOrderSensitive && leftKeys.sameElements(rightKeys) || !keyOrderSensitive && leftKeys.toSet.diff(rightKeys.toSet).isEmpty) &&
- valuesAreEqual.forall(identity)
+ (keyOrderSensitive && leftKeys
+ .sameElements(rightKeys) || !keyOrderSensitive && leftKeys.toSet.diff(rightKeys.toSet).isEmpty) &&
+ valuesAreEqual.forall(identity)
}
}
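
A sketch of using the map comparator refactored above; encoders for the key and value types are required, and the "attributes" column name is an example:

    import org.apache.spark.sql.{Encoder, Encoders}
    import uk.co.gresearch.spark.diff.{DiffComparators, DiffOptions}

    implicit val keyEncoder: Encoder[String] = Encoders.STRING
    implicit val valueEncoder: Encoder[Int] = Encoders.scalaInt

    // maps are equal when keys and values match; key order is ignored here
    val mapComparator = DiffComparators.map[String, Int](keyOrderSensitive = false)
    val options = DiffOptions.default.withComparator(mapComparator, "attributes")
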
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala
index ab243ec7..8b107ce6 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala
@@ -20,7 +20,10 @@ import org.apache.spark.unsafe.types.UTF8String
case object WhitespaceDiffComparator extends TypedEquivDiffComparatorWithInput[UTF8String] with StringDiffComparator {
override val equiv: scala.Equiv[UTF8String] = (x: UTF8String, y: UTF8String) =>
- x.trimAll().toString.replaceAll("\\s+"," ").equals(
- y.trimAll().toString.replaceAll("\\s+"," ")
- )
+ x.trimAll()
+ .toString
+ .replaceAll("\\s+", " ")
+ .equals(
+ y.trimAll().toString.replaceAll("\\s+", " ")
+ )
}
diff --git a/src/main/scala/uk/co/gresearch/spark/diff/package.scala b/src/main/scala/uk/co/gresearch/spark/diff/package.scala
index 53f47e72..e13802db 100644
--- a/src/main/scala/uk/co/gresearch/spark/diff/package.scala
+++ b/src/main/scala/uk/co/gresearch/spark/diff/package.scala
@@ -26,22 +26,21 @@ package object diff {
implicit class DatasetDiff[T](ds: Dataset[T]) {
/**
- * Returns a new DataFrame that contains the differences between this and the other Dataset of
- * the same type `T`. Both Datasets must contain the same set of column names and data types.
- * The order of columns in the two Datasets is not important as one column is compared to the
- * column with the same name of the other Dataset, not the column with the same position.
+   * Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`.
+   * Both Datasets must contain the same set of column names and data types. The order of columns in the two Datasets
+   * is not important, as each column is compared to the column with the same name in the other Dataset, not to the
+   * column at the same position.
*
- * Optional id columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between this and the other Dataset, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the other Dataset, that do not exist in this Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of this Dataset, that
- * do not exist in the other Dataset are marked as `"D"`elete.
+   * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ between
+   * this and the other Dataset, then that row is marked as `"C"`hange, and `"N"`o-change otherwise. Rows of the other
+   * Dataset that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert, and rows
+   * of this Dataset that do not exist in the other Dataset are marked as `"D"`elete.
*
- * If no id columns are given (empty sequence), all columns are considered id columns. Then,
- * no `"C"`hange rows will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+   * If no id columns are given (empty sequence), all columns are considered id columns. Then, no `"C"`hange rows will
+   * appear, as all changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -74,12 +73,11 @@ package object diff {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset
+ * are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*
- * The id column names are take literally, i.e. "a.field" is interpreted as "`a.field`, which is a
- * column name containing a dot. This is not interpreted as a column "a" with a field "field" (struct).
+   * The id column names are taken literally, i.e. "a.field" is interpreted as "`a.field`", which is a column name
+   * containing a dot. It is not interpreted as a column "a" with a field "field" (struct).
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
def diff(other: Dataset[T], idColumns: String*): DataFrame = {
@@ -87,24 +85,23 @@ package object diff {
}
/**
- * Returns a new DataFrame that contains the differences between two Datasets of
- * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types,
- * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as
- * columns are compared based on the name, not the the position.
+   * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both
+   * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The
+   * order of columns in the two Datasets is not relevant as columns are compared by name, not by position.
*
- * Optional id columns are used to uniquely identify rows to compare. If values in any non-id
- * column are differing between this and the other Dataset, then that row is marked as `"C"`hange
- * and `"N"`o-change otherwise. Rows of the other Dataset, that do not exist in this Dataset
- * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of this Dataset, that
- * do not exist in the other Dataset are marked as `"D"`elete.
+   * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column differ between
+   * this and the other Dataset, then that row is marked as `"C"`hange, and `"N"`o-change otherwise. Rows of the other
+   * Dataset that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert, and rows
+   * of this Dataset that do not exist in the other Dataset are marked as `"D"`elete.
*
- * If no id columns are given (empty sequence), all columns are considered id columns. Then,
- * no `"C"`hange rows will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert.
+   * If no id columns are given (empty sequence), all columns are considered id columns. Then, no `"C"`hange rows will
+   * appear, as all changes will exist as respective `"D"`elete and `"I"`nsert rows.
*
* Values in optional ignore columns are not compared but included in the output DataFrame.
*
- * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`,
- * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns).
+ * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"`
+ * strings. The id columns follow, then the non-id columns (all remaining columns).
*
* {{{
* val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
@@ -137,20 +134,18 @@ package object diff {
*
* }}}
*
- * The id columns are in order as given to the method. If no id columns are given then all
- * columns of this Dataset are id columns and appear in the same order. The remaining non-id
- * columns are in the order of this Dataset.
+ * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset
+ * are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset.
*
- * The id column names are take literally, i.e. "a.field" is interpreted as "`a.field`, which is a
- * column name containing a dot. This is not interpreted as a column "a" with a field "field" (struct).
+   * The id column names are taken literally, i.e. "a.field" is interpreted as "`a.field`", which is a column name
+   * containing a dot. It is not interpreted as a column "a" with a field "field" (struct).
*/
def diff[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame = {
Diff.of(this.ds, other, idColumns, ignoreColumns)
}
/**
- * Returns a new DataFrame that contains the differences
- * between this and the other Dataset of the same type `T`.
+ * Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
@@ -162,76 +157,80 @@ package object diff {
}
/**
- * Returns a new DataFrame that contains the differences
- * between this and the other Dataset of similar types `T` and `U`.
+ * Returns a new DataFrame that contains the differences between this and the other Dataset of similar types `T` and
+ * `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* The schema of the returned DataFrame can be configured by the given `DiffOptions`.
*/
- def diff[U](other: Dataset[U], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame = {
+ def diff[U](
+ other: Dataset[U],
+ options: DiffOptions,
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): DataFrame = {
new Differ(options).diff(this.ds, other, idColumns, ignoreColumns)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of the same type `T`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
* This requires an additional implicit `Encoder[U]` for the return type `Dataset[U]`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
- def diffAs[V](other: Dataset[T], idColumns: String*)
- (implicit diffEncoder: Encoder[V]): Dataset[V] = {
+ def diffAs[V](other: Dataset[T], idColumns: String*)(implicit diffEncoder: Encoder[V]): Dataset[V] = {
Diff.ofAs(this.ds, other, idColumns: _*)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of similar types `T` and `U`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
+ * `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])
- (implicit diffEncoder: Encoder[V]): Dataset[V] = {
+ def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit
+ diffEncoder: Encoder[V]
+ ): Dataset[V] = {
Diff.ofAs(this.ds, other, idColumns, ignoreColumns)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of the same type `T`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
- * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
- * The schema of the returned Dataset can be configured by the given `DiffOptions`.
+ * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
+ * Dataset can be configured by the given `DiffOptions`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
- def diffAs[V](other: Dataset[T], options: DiffOptions, idColumns: String*)
- (implicit diffEncoder: Encoder[V]): Dataset[V] = {
+ def diffAs[V](other: Dataset[T], options: DiffOptions, idColumns: String*)(implicit
+ diffEncoder: Encoder[V]
+ ): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, idColumns: _*)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of similar types `T` and `U`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
+ * `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
- * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`.
- * The schema of the returned Dataset can be configured by the given `DiffOptions`.
+ * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
+ * Dataset can be configured by the given `DiffOptions`.
*/
- def diffAs[U, V](other: Dataset[T], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String])
- (implicit diffEncoder: Encoder[V]): Dataset[V] = {
+ def diffAs[U, V](other: Dataset[T], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String])(
+ implicit diffEncoder: Encoder[V]
+ ): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, idColumns, ignoreColumns)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of the same type `T`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
@@ -243,117 +242,114 @@ package object diff {
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of similar types `T` and `U`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
+ * `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
*/
- def diffAs[U, V](other: Dataset[U], diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V] = {
+ def diffAs[U, V](
+ other: Dataset[U],
+ diffEncoder: Encoder[V],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Dataset[V] = {
Diff.ofAs(this.ds, other, diffEncoder, idColumns, ignoreColumns)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of the same type `T`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`.
*
* See `diff(Dataset[T], String*)`.
*
- * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
- * The schema of the returned Dataset can be configured by the given `DiffOptions`.
+ * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
+ * Dataset can be configured by the given `DiffOptions`.
*/
// no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java
- def diffAs[V](other: Dataset[T],
- options: DiffOptions,
- diffEncoder: Encoder[V],
- idColumns: String*): Dataset[V] = {
+ def diffAs[V](other: Dataset[T], options: DiffOptions, diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns: _*)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of similar types `T` and `U`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
+ * `U`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
- * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`.
- * The schema of the returned Dataset can be configured by the given `DiffOptions`.
+ * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned
+ * Dataset can be configured by the given `DiffOptions`.
*/
- def diffAs[U, V](other: Dataset[U],
- options: DiffOptions,
- diffEncoder: Encoder[V],
- idColumns: Seq[String],
- ignoreColumns: Seq[String]): Dataset[V] = {
+ def diffAs[U, V](
+ other: Dataset[U],
+ options: DiffOptions,
+ diffEncoder: Encoder[V],
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Dataset[V] = {
new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns, ignoreColumns)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of the same type `T`
- * as tuples of type `(String, T, T)`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as
+ * tuples of type `(String, T, T)`.
*
* See `diff(Dataset[T], Seq[String])`.
*/
- def diffWith(other: Dataset[T],
- idColumns: String*): Dataset[(String, T, T)] =
+ def diffWith(other: Dataset[T], idColumns: String*): Dataset[(String, T, T)] =
Diff.default.diffWith(this.ds, other, idColumns: _*)
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of similar types `T` and `U`
- * as tuples of type `(String, T, U)`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
+ * `U` as tuples of type `(String, T, U)`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*/
- def diffWith[U](other: Dataset[U],
- idColumns: Seq[String],
- ignoreColumns: Seq[String]): Dataset[(String, T, U)] =
+ def diffWith[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] =
Diff.default.diffWith(this.ds, other, idColumns, ignoreColumns)
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of the same type `T`
- * as tuples of type `(String, T, T)`.
+ * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as
+ * tuples of type `(String, T, T)`.
*
* See `diff(Dataset[T], String*)`.
*
* The schema of the returned Dataset can be configured by the given `DiffOptions`.
*/
- def diffWith(other: Dataset[T],
- options: DiffOptions,
- idColumns: String*): Dataset[(String, T, T)] = {
+ def diffWith(other: Dataset[T], options: DiffOptions, idColumns: String*): Dataset[(String, T, T)] = {
new Differ(options).diffWith(this.ds, other, idColumns: _*)
}
/**
- * Returns a new Dataset that contains the differences
- * between this and the other Dataset of similar types `T` and `U`.
- * as tuples of type `(String, T, T)`.
+   * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and
+   * `U` as tuples of type `(String, T, U)`.
*
* See `diff(Dataset[U], Seq[String], Seq[String])`.
*
* The schema of the returned Dataset can be configured by the given `DiffOptions`.
*/
- def diffWith[U](other: Dataset[U],
- options: DiffOptions,
- idColumns: Seq[String],
- ignoreColumns: Seq[String]): Dataset[(String, T, U)] = {
+ def diffWith[U](
+ other: Dataset[U],
+ options: DiffOptions,
+ idColumns: Seq[String],
+ ignoreColumns: Seq[String]
+ ): Dataset[(String, T, U)] = {
new Differ(options).diffWith(this.ds, other, idColumns, ignoreColumns)
}
}
/**
- * Produces a column name that considers configured case-sensitivity of column names.
- * When case sensitivity is deactivated, it lower-cases the given column name and no-ops otherwise.
+   * Produces a column name that considers the configured case-sensitivity of column names. When case sensitivity is
+   * deactivated, it lower-cases the given column name; otherwise it returns the name unchanged.
*
- * @param columnName column name
- * @return case sensitive or insensitive column name
+ * @param columnName
+ * column name
+ * @return
+ * case sensitive or insensitive column name
*/
private[diff] def handleConfiguredCaseSensitivity(columnName: String): String =
if (SQLConf.get.caseSensitiveAnalysis) columnName else columnName.toLowerCase(Locale.ROOT)
-
implicit class CaseInsensitiveSeq(seq: Seq[String]) {
def containsCaseSensitivity(string: String): Boolean =
seq.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string))
@@ -372,7 +368,8 @@ package object diff {
implicit class CaseInsensitiveArray(array: Array[String]) {
def containsCaseSensitivity(string: String): Boolean =
array.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string))
- def filterIsInCaseSensitivity(other: Iterable[String]): Array[String] = array.toSeq.filterIsInCaseSensitivity(other).toArray
+ def filterIsInCaseSensitivity(other: Iterable[String]): Array[String] =
+ array.toSeq.filterIsInCaseSensitivity(other).toArray
def diffCaseSensitivity(other: Iterable[String]): Array[String] = array.toSeq.diffCaseSensitivity(other).toArray
}
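
As a companion to the reflowed Scaladoc above, a small end-to-end usage sketch of the `diff` extension; the second DataFrame and the local SparkSession setup are illustrative and not taken from the documentation's (elided) example output.

import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark.diff._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
val df2 = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")

// the first column "diff" holds "N", "C", "D" or "I"; then the id column, then the remaining columns
df1.diff(df2, "id").show()

// id 1 -> "N" (no change), id 2 -> "C" (changed), id 3 -> "D" (deleted), id 4 -> "I" (inserted)
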
diff --git a/src/main/scala/uk/co/gresearch/spark/group/package.scala b/src/main/scala/uk/co/gresearch/spark/group/package.scala
index 5607ee1d..41063b1a 100644
--- a/src/main/scala/uk/co/gresearch/spark/group/package.scala
+++ b/src/main/scala/uk/co/gresearch/spark/group/package.scala
@@ -23,54 +23,50 @@ import uk.co.gresearch.ExtendedAny
package object group {
/**
- * This is a Dataset of key-value tuples, that provide a flatMap function over the individual groups,
- * while providing a sorted iterator over group values.
+   * This is a Dataset of key-value tuples that provides a flatMap function over the individual groups, while providing
+   * a sorted iterator over the group values.
*
- * The key-value Dataset given the constructor has to be partitioned by the key
- * and sorted within partitions by the key and value.
+   * The key-value Dataset given to the constructor has to be partitioned by the key and sorted within partitions by
+   * the key and value.
*
- * @param ds the properly partitioned and sorted dataset
- * @tparam K type of the keys with ordering and encoder
- * @tparam V type of the values with encoder
+ * @param ds
+ * the properly partitioned and sorted dataset
+ * @tparam K
+ * type of the keys with ordering and encoder
+ * @tparam V
+ * type of the values with encoder
*/
- case class SortedGroupByDataset[K: Ordering : Encoder, V: Encoder] private (ds: Dataset[(K, V)]) {
+ case class SortedGroupByDataset[K: Ordering: Encoder, V: Encoder] private (ds: Dataset[(K, V)]) {
+
/**
- * (Scala-specific)
- * Applies the given function to each group of data. For each unique group, the function will
- * be passed the group key and a sorted iterator that contains all of the elements in the group.
- * The function can return an iterator containing elements of an arbitrary type which will be
- * returned as a new [[Dataset]].
+ * (Scala-specific) Applies the given function to each group of data. For each unique group, the function will be
+ * passed the group key and a sorted iterator that contains all of the elements in the group. The function can
+     * return an iterator containing elements of an arbitrary type, which will be returned as a new [[Dataset]].
*
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
+ * This function does not support partial aggregation, and as a result requires shuffling all the data in the
+ * [[Dataset]]. If an application intends to perform an aggregation over each key, it is best to use the reduce
+ * function or an `org.apache.spark.sql.expressions#Aggregator`.
*
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
+ * Internally, the implementation will spill to disk if any given group is too large to fit into memory. However,
+ * users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`)
+ * unless they are sure that this is possible given the memory constraints of their cluster.
*/
def flatMapSortedGroups[W: Encoder](func: (K, Iterator[V]) => TraversableOnce[W]): Dataset[W] =
ds.mapPartitions(new GroupedIterator(_).flatMap(v => func(v._1, v._2)))
/**
- * (Scala-specific)
- * Applies the given function to each group of data. For each unique group, the function s will
- * be passed the group key to create a state instance, while the function func will be passed
- * that state instance and group values in sequence according to the sort order in the groups.
- * The function func can return an iterator containing elements of an arbitrary type which will
- * be returned as a new [[Dataset]].
+     * (Scala-specific) Applies the given function to each group of data. For each unique group, the function `s` will
+     * be passed the group key to create a state instance, while the function `func` will be passed that state instance
+     * and the group values in sequence according to the sort order in the groups. The function `func` can return an
+     * iterator containing elements of an arbitrary type, which will be returned as a new [[Dataset]].
*
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
+ * This function does not support partial aggregation, and as a result requires shuffling all the data in the
+ * [[Dataset]]. If an application intends to perform an aggregation over each key, it is best to use the reduce
+ * function or an `org.apache.spark.sql.expressions#Aggregator`.
*
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
+ * Internally, the implementation will spill to disk if any given group is too large to fit into memory. However,
+ * users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`)
+ * unless they are sure that this is possible given the memory constraints of their cluster.
*/
def flatMapSortedGroups[S, W: Encoder](s: K => S)(func: (S, V) => TraversableOnce[W]): Dataset[W] = {
ds.mapPartitions(new GroupedIterator(_).flatMap { case (k, it) =>
@@ -81,10 +77,12 @@ package object group {
}
object SortedGroupByDataset {
- def apply[K: Ordering : Encoder, V](ds: Dataset[V],
- groupColumns: Seq[Column],
- orderColumns: Seq[Column],
- partitions: Option[Int]): SortedGroupByDataset[K, V] = {
+ def apply[K: Ordering: Encoder, V](
+ ds: Dataset[V],
+ groupColumns: Seq[Column],
+ orderColumns: Seq[Column],
+ partitions: Option[Int]
+ ): SortedGroupByDataset[K, V] = {
// make ds encoder implicitly available
implicit val valueEncoder: Encoder[V] = ds.encoder
@@ -115,11 +113,13 @@ package object group {
SortedGroupByDataset(grouped)
}
- def apply[K: Ordering : Encoder, V, O: Encoder](ds: Dataset[V],
- key: V => K,
- order: V => O,
- partitions: Option[Int],
- reverse: Boolean): SortedGroupByDataset[K, V] = {
+ def apply[K: Ordering: Encoder, V, O: Encoder](
+ ds: Dataset[V],
+ key: V => K,
+ order: V => O,
+ partitions: Option[Int],
+ reverse: Boolean
+ ): SortedGroupByDataset[K, V] = {
// prepare encoder needed for this exercise
val keyEncoder: Encoder[K] = implicitly[Encoder[K]]
implicit val valueEncoder: Encoder[V] = ds.encoder
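
A usage sketch of the sorted-group API above, assuming the `groupByKeySorted` extension method whose (reformatted) signatures appear further down in this diff; the `Event` case class, the example data, and the local SparkSession setup are illustrative.

import org.apache.spark.sql.{Dataset, SparkSession}
import uk.co.gresearch.spark._

case class Event(user: Int, time: Long, value: String)

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val events: Dataset[Event] =
  Seq(Event(1, 2L, "b"), Event(1, 1L, "a"), Event(2, 1L, "c")).toDS()

// group by user, iterate each user's events sorted by time,
// and emit one concatenated string per user
val perUser: Dataset[(Int, String)] = events
  .groupByKeySorted((e: Event) => e.user)((e: Event) => e.time)
  .flatMapSortedGroups((user: Int, it: Iterator[Event]) => Iterator(user -> it.map(_.value).mkString(",")))

// perUser contains (1, "a,b") and (2, "c")
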
diff --git a/src/main/scala/uk/co/gresearch/spark/package.scala b/src/main/scala/uk/co/gresearch/spark/package.scala
index aa644886..a1396f9d 100644
--- a/src/main/scala/uk/co/gresearch/spark/package.scala
+++ b/src/main/scala/uk/co/gresearch/spark/package.scala
@@ -32,8 +32,10 @@ package object spark extends Logging with SparkVersion with BuildVersion {
/**
* Provides a prefix that makes any string distinct w.r.t. the given strings.
- * @param existing strings
- * @return distinct prefix
+ * @param existing
+ * strings
+ * @return
+ * distinct prefix
*/
private[spark] def distinctPrefixFor(existing: Seq[String]): String = {
"_" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1)
@@ -41,8 +43,10 @@ package object spark extends Logging with SparkVersion with BuildVersion {
/**
* Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown.
- * @param prefix prefix string of temporary directory name
- * @return absolute path of temporary directory
+ * @param prefix
+ * prefix string of temporary directory name
+ * @return
+ * absolute path of temporary directory
*/
def createTemporaryDir(prefix: String): String = {
// SparkFiles.getRootDirectory() will be deleted on spark application shutdown
@@ -51,12 +55,15 @@ package object spark extends Logging with SparkVersion with BuildVersion {
// https://issues.apache.org/jira/browse/SPARK-40588
private[spark] def writePartitionedByRequiresCaching[T](ds: Dataset[T]): Boolean = {
- ds.sparkSession.conf.get(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key,
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString
- ).equalsIgnoreCase("true") && Some(ds.sparkSession.version).exists(ver =>
- Set("3.0.", "3.1.", "3.2.0", "3.2.1" ,"3.2.2", "3.3.0", "3.3.1").exists(pat =>
- if (pat.endsWith(".")) { ver.startsWith(pat) } else { ver.equals(pat) || ver.startsWith(pat + "-") }
+ ds.sparkSession.conf
+ .get(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key,
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString
+ )
+ .equalsIgnoreCase("true") && Some(ds.sparkSession.version).exists(ver =>
+ Set("3.0.", "3.1.", "3.2.0", "3.2.1", "3.2.2", "3.3.0", "3.3.1").exists(pat =>
+ if (pat.endsWith(".")) { ver.startsWith(pat) }
+ else { ver.equals(pat) || ver.startsWith(pat + "-") }
)
)
}
@@ -80,8 +87,10 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* col(backticks("a.column", "a.field")) // produces "`a.column`.`a.field`"
* }}}
*
- * @param string a string
- * @param strings more strings
+ * @param string
+ * a string
+ * @param strings
+ * more strings
*/
@scala.annotation.varargs
def backticks(string: String, strings: String*): String =
@@ -97,307 +106,308 @@ package object spark extends Logging with SparkVersion with BuildVersion {
private val unixEpochDotNetTicks: Long = 621355968000000000L
/**
- * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be
- * convertible to a number (e.g. string, int, long). The Spark timestamp type does not support
- * nanoseconds, so the the last digit of the timestamp (1/10 of a microsecond) is lost.
+   * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number
+   * (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the last digit of the
+   * timestamp (1/10 of a microsecond) is lost.
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToTimestamp($"ticks").as("timestamp")).show(false)
* }}}
*
- * +------------------+--------------------------+
- * |ticks |timestamp |
- * +------------------+--------------------------+
- * |638155413748959318|2023-03-27 21:16:14.895931|
- * +------------------+--------------------------+
+ * | ticks | timestamp |
+ * |:-------------------|:---------------------------|
+ * | 638155413748959318 | 2023-03-27 21:16:14.895931 |
*
- * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to
- * preserve the full precision of the tick timestamp.
+ * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full
+ * precision of the tick timestamp.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param tickColumn column with a tick value
- * @return result timestamp column
+ * @param tickColumn
+ * column with a tick value
+ * @return
+ * result timestamp column
*/
def dotNetTicksToTimestamp(tickColumn: Column): Column =
dotNetTicksToUnixEpoch(tickColumn).cast(TimestampType)
/**
- * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be
- * convertible to a number (e.g. string, int, long). The Spark timestamp type does not support
- * nanoseconds, so the the last digit of the timestamp (1/10 of a microsecond) is lost.
+   * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number
+   * (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the last digit of the
+   * timestamp (1/10 of a microsecond) is lost.
*
* {{{
* df.select($"ticks", dotNetTicksToTimestamp("ticks").as("timestamp")).show(false)
* }}}
*
- * +------------------+--------------------------+
- * |ticks |timestamp |
- * +------------------+--------------------------+
- * |638155413748959318|2023-03-27 21:16:14.895931|
- * +------------------+--------------------------+
+ * | ticks | timestamp |
+ * |:-------------------|:---------------------------|
+ * | 638155413748959318 | 2023-03-27 21:16:14.895931 |
*
- * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to
- * preserve the full precision of the tick timestamp.
+ * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full
+ * precision of the tick timestamp.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param tickColumnName name of a column with a tick value
- * @return result timestamp column
+ * @param tickColumnName
+ * name of a column with a tick value
+ * @return
+ * result timestamp column
*/
def dotNetTicksToTimestamp(tickColumnName: String): Column = dotNetTicksToTimestamp(col(tickColumnName))
/**
- * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch decimal. The input column must be
- * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp
- * is preserved (1/10 of a microsecond).
+ * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch decimal. The input column must be convertible to a number
+ * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpoch($"ticks").as("timestamp")).show(false)
* }}}
*
- * +------------------+--------------------+
- * |ticks |timestamp |
- * +------------------+--------------------+
- * |638155413748959318|1679944574.895931800|
- * +------------------+--------------------+
+ * | ticks | timestamp |
+ * |:-------------------|:---------------------|
+ * | 638155413748959318 | 1679944574.895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param tickColumn column with a tick value
- * @return result unix epoch seconds column as decimal
+ * @param tickColumn
+ * column with a tick value
+ * @return
+ * result unix epoch seconds column as decimal
*/
def dotNetTicksToUnixEpoch(tickColumn: Column): Column =
(tickColumn.cast(DecimalType(19, 0)) - unixEpochDotNetTicks) / dotNetTicksPerSecond
/**
- * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch seconds. The input column must be
- * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp
- * is preserved (1/10 of a microsecond).
+   * Convert a .Net `DateTime.Ticks` timestamp to Unix epoch seconds. The input column must be convertible to a number
+   * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpoch("ticks").as("timestamp")).show(false)
* }}}
*
- * +------------------+--------------------+
- * |ticks |timestamp |
- * +------------------+--------------------+
- * |638155413748959318|1679944574.895931800|
- * +------------------+--------------------+
+ * | ticks | timestamp |
+ * |:-------------------|:---------------------|
+ * | 638155413748959318 | 1679944574.895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param tickColumnName name of column with a tick value
- * @return result unix epoch seconds column as decimal
+ * @param tickColumnName
+ * name of column with a tick value
+ * @return
+ * result unix epoch seconds column as decimal
*/
def dotNetTicksToUnixEpoch(tickColumnName: String): Column = dotNetTicksToUnixEpoch(col(tickColumnName))
/**
- * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch seconds. The input column must be
- * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp
- * is preserved (1/10 of a microsecond).
+   * Convert a .Net `DateTime.Ticks` timestamp to Unix epoch nanoseconds. The input column must be convertible to a
+   * number (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpochNanos($"ticks").as("timestamp")).show(false)
* }}}
*
- * +------------------+-------------------+
- * |ticks |timestamp |
- * +------------------+-------------------+
- * |638155413748959318|1679944574895931800|
- * +------------------+-------------------+
+ * | ticks | timestamp |
+ * |:-------------------|:--------------------|
+ * | 638155413748959318 | 1679944574895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param tickColumn column with a tick value
- * @return result unix epoch nanoseconds column as long
+ * @param tickColumn
+ * column with a tick value
+ * @return
+ * result unix epoch nanoseconds column as long
*/
def dotNetTicksToUnixEpochNanos(tickColumn: Column): Column = {
- when(tickColumn <= 713589688368547758L, (tickColumn.cast(LongType) - unixEpochDotNetTicks) * nanoSecondsPerDotNetTick)
+ when(
+ tickColumn <= 713589688368547758L,
+ (tickColumn.cast(LongType) - unixEpochDotNetTicks) * nanoSecondsPerDotNetTick
+ )
}
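
A combined usage sketch for the three tick conversions above, using the example value from the Scaladoc; the local SparkSession setup is illustrative, and the shown results follow the documentation's example (the timestamp rendering depends on the session timezone).

import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(638155413748959318L).toDF("ticks")

df.select(
  $"ticks",
  dotNetTicksToTimestamp($"ticks").as("timestamp"),       // 2023-03-27 21:16:14.895931 (last tick digit lost)
  dotNetTicksToUnixEpoch($"ticks").as("epoch"),            // 1679944574.895931800
  dotNetTicksToUnixEpochNanos($"ticks").as("epoch_nanos")  // 1679944574895931800
).show(false)
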
/**
- * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch nanoseconds. The input column must be
- * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp
- * is preserved (1/10 of a microsecond).
+   * Convert a .Net `DateTime.Ticks` timestamp to Unix epoch nanoseconds. The input column must be convertible to a
+   * number (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond).
*
* Example:
* {{{
* df.select($"ticks", dotNetTicksToUnixEpochNanos("ticks").as("timestamp")).show(false)
* }}}
*
- * +------------------+-------------------+
- * |ticks |timestamp |
- * +------------------+-------------------+
- * |638155413748959318|1679944574895931800|
- * +------------------+-------------------+
+ * | ticks | timestamp |
+ * |:-------------------|:--------------------|
+ * | 638155413748959318 | 1679944574895931800 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param tickColumnName name of column with a tick value
- * @return result unix epoch nanoseconds column as long
+ * @param tickColumnName
+ * name of column with a tick value
+ * @return
+ * result unix epoch nanoseconds column as long
*/
def dotNetTicksToUnixEpochNanos(tickColumnName: String): Column = dotNetTicksToUnixEpochNanos(col(tickColumnName))
/**
- * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp.
- * The input column must be of TimestampType.
+ * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType.
*
* Example:
* {{{
* df.select($"timestamp", timestampToDotNetTicks($"timestamp").as("ticks")).show(false)
* }}}
*
- * +--------------------------+------------------+
- * |timestamp |ticks |
- * +--------------------------+------------------+
- * |2023-03-27 21:16:14.895931|638155413748959310|
- * +--------------------------+------------------+
+ * | timestamp | ticks |
+ * |:---------------------------|:-------------------|
+ * | 2023-03-27 21:16:14.895931 | 638155413748959310 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param timestampColumn column with a timestamp value
- * @return result tick value column
+ * @param timestampColumn
+ * column with a timestamp value
+ * @return
+ * result tick value column
*/
def timestampToDotNetTicks(timestampColumn: Column): Column =
unixEpochTenthMicrosToDotNetTicks(new Column(UnixMicros.unixMicros(timestampColumn.expr)) * 10)
/**
- * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp.
- * The input column must be of TimestampType.
+ * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType.
*
* Example:
* {{{
* df.select($"timestamp", timestampToDotNetTicks("timestamp").as("ticks")).show(false)
* }}}
*
- * +--------------------------+------------------+
- * |timestamp |ticks |
- * +--------------------------+------------------+
- * |2023-03-27 21:16:14.895931|638155413748959310|
- * +--------------------------+------------------+
+ * | timestamp | ticks |
+ * |:---------------------------|:-------------------|
+ * | 2023-03-27 21:16:14.895931 | 638155413748959310 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param timestampColumnName name of column with a timestamp value
- * @return result tick value column
+ * @param timestampColumnName
+ * name of column with a timestamp value
+ * @return
+ * result tick value column
*/
def timestampToDotNetTicks(timestampColumnName: String): Column = timestampToDotNetTicks(col(timestampColumnName))
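
And the reverse direction, again with the example values from the Scaladoc table; the SparkSession setup is illustrative, and the exact tick value assumes the session timezone matches the documentation's example.

import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq("2023-03-27 21:16:14.895931").toDF("s").select($"s".cast("timestamp").as("timestamp"))

// 2023-03-27 21:16:14.895931 -> 638155413748959310 (ticks have 1/10 microsecond resolution)
df.select($"timestamp", timestampToDotNetTicks($"timestamp").as("ticks")).show(false)
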
/**
- * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp.
- * The input column must represent a numerical unix epoch timestamp, e.g. long, double, string or decimal.
- * The input must not be of TimestampType, as that may be interpreted incorrectly.
- * Use `timestampToDotNetTicks` for TimestampType columns instead.
+ * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. The input column must represent a numerical
+ * unix epoch timestamp, e.g. long, double, string or decimal. The input must not be of TimestampType, as that may be
+ * interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead.
*
* Example:
* {{{
* df.select($"unix", unixEpochToDotNetTicks($"unix").as("ticks")).show(false)
* }}}
*
- * +-----------------------------+------------------+
- * |unix |ticks |
- * +-----------------------------+------------------+
- * |1679944574.895931234000000000|638155413748959312|
- * +-----------------------------+------------------+
+ * | unix | ticks |
+ * |:------------------------------|:-------------------|
+ * | 1679944574.895931234000000000 | 638155413748959312 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param unixTimeColumn column with a unix epoch timestamp value
- * @return result tick value column
+ * @param unixTimeColumn
+ * column with a unix epoch timestamp value
+ * @return
+ * result tick value column
*/
- def unixEpochToDotNetTicks(unixTimeColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(unixTimeColumn.cast(DecimalType(19, 7)) * 10000000)
+ def unixEpochToDotNetTicks(unixTimeColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(
+ unixTimeColumn.cast(DecimalType(19, 7)) * 10000000
+ )
/**
- * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp.
- * The input column must represent a numerical unix epoch timestamp, e.g. long, double, string or decimal.
- * The input must not be of TimestampType, as that may be interpreted incorrectly.
- * Use `timestampToDotNetTicks` for TimestampType columns instead.
+ * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. The input column must represent a numerical
+ * unix epoch timestamp, e.g. long, double, string or decimal. The input must not be of TimestampType, as that may be
+ * interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead.
*
* Example:
* {{{
* df.select($"unix", unixEpochToDotNetTicks("unix").as("ticks")).show(false)
* }}}
*
- * +-----------------------------+------------------+
- * |unix |ticks |
- * +-----------------------------+------------------+
- * |1679944574.895931234000000000|638155413748959312|
- * +-----------------------------+------------------+
+ * | unix | ticks |
+ * |:------------------------------|:-------------------|
+ * | 1679944574.895931234000000000 | 638155413748959312 |
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param unixTimeColumnName name of column with a unix epoch timestamp value
- * @return result tick value column
+ * @param unixTimeColumnName
+ * name of column with a unix epoch timestamp value
+ * @return
+ * result tick value column
*/
def unixEpochToDotNetTicks(unixTimeColumnName: String): Column = unixEpochToDotNetTicks(col(unixTimeColumnName))
/**
- * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp.
- * The .Net ticks timestamp does not support the two lowest nanosecond digits,
- * so only a 1/10 of a microsecond is the smallest resolution.
- * The input column must represent a numerical unix epoch nanoseconds timestamp,
- * e.g. long, double, string or decimal.
+   * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not
+   * support the two lowest nanosecond digits, so the smallest resolution is 1/10 of a microsecond. The input column
+   * must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal.
*
* Example:
* {{{
* df.select($"unix_nanos", unixEpochNanosToDotNetTicks($"unix_nanos").as("ticks")).show(false)
* }}}
*
- * +-------------------+------------------+
- * |unix_nanos |ticks |
- * +-------------------+------------------+
- * |1679944574895931234|638155413748959312|
- * +-------------------+------------------+
+ * | unix_nanos | ticks |
+ * |:--------------------|:-------------------|
+ * | 1679944574895931234 | 638155413748959312 |
*
* Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param unixNanosColumn column with a unix epoch timestamp value
- * @return result tick value column
+ * @param unixNanosColumn
+ * column with a unix epoch timestamp value
+ * @return
+ * result tick value column
*/
- def unixEpochNanosToDotNetTicks(unixNanosColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(unixNanosColumn.cast(DecimalType(21, 0)) / nanoSecondsPerDotNetTick)
+ def unixEpochNanosToDotNetTicks(unixNanosColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(
+ unixNanosColumn.cast(DecimalType(21, 0)) / nanoSecondsPerDotNetTick
+ )
/**
- * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp.
- * The .Net ticks timestamp does not support the two lowest nanosecond digits,
- * so only a 1/10 of a microsecond is the smallest resolution.
- * The input column must represent a numerical unix epoch nanoseconds timestamp,
- * e.g. long, double, string or decimal.
+   * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not
+   * support the two lowest nanosecond digits, so the smallest resolution is 1/10 of a microsecond. The input column
+   * must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal.
*
* Example:
* {{{
* df.select($"unix_nanos", unixEpochNanosToDotNetTicks($"unix_nanos").as("ticks")).show(false)
* }}}
*
- * +-------------------+------------------+
- * |unix_nanos |ticks |
- * +-------------------+------------------+
- * |1679944574895931234|638155413748959312|
- * +-------------------+------------------+
+ * | unix_nanos | ticks |
+ * |:--------------------|:-------------------|
+ * | 1679944574895931234 | 638155413748959312 |
*
* Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks.
*
* https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks
*
- * @param unixNanosColumnName name of column with a unix epoch timestamp value
- * @return result tick value column
+ * @param unixNanosColumnName
+ * name of column with a unix epoch timestamp value
+ * @return
+ * result tick value column
*/
- def unixEpochNanosToDotNetTicks(unixNanosColumnName: String): Column = unixEpochNanosToDotNetTicks(col(unixNanosColumnName))
+ def unixEpochNanosToDotNetTicks(unixNanosColumnName: String): Column = unixEpochNanosToDotNetTicks(
+ col(unixNanosColumnName)
+ )
- private def unixEpochTenthMicrosToDotNetTicks(unixNanosColumn: Column): Column = unixNanosColumn.cast(LongType) + unixEpochDotNetTicks
+ private def unixEpochTenthMicrosToDotNetTicks(unixNanosColumn: Column): Column =
+ unixNanosColumn.cast(LongType) + unixEpochDotNetTicks
/**
* Set the job description and return the earlier description. Only set the description if it is not set.
*
- * @param description job description
- * @param ifNotSet job description is only set if no description is set yet
- * @param context spark context
+ * @param description
+ * job description
+ * @param ifNotSet
+ * job description is only set if no description is set yet
+ * @param context
+ * spark context
* @return
*/
def setJobDescription(description: String, ifNotSet: Boolean = false)(implicit context: SparkContext): String = {
@@ -409,8 +419,8 @@ package object spark extends Logging with SparkVersion with BuildVersion {
}
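
A brief sketch of how `setJobDescription` can be used directly with a manual restore; the `withJobDescription` helper documented next wraps exactly this set-and-restore pattern. The SparkSession setup and the description text are illustrative.

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark._

val spark = SparkSession.builder().master("local[*]").getOrCreate()
implicit val context: SparkContext = spark.sparkContext

// set a description (only if none is set yet) and remember the earlier one
val earlier = setJobDescription("data ingestion", ifNotSet = true)
try {
  spark.range(100).count() // jobs started here carry the description
} finally {
  context.setJobDescription(earlier) // restore whatever was set before
}
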
/**
- * Adds a job description to all Spark jobs started within the given function.
- * The current Job description is restored after exit of the function.
+   * Adds a job description to all Spark jobs started within the given function. The current job description is
+   * restored after the function exits.
*
* Usage example:
*
@@ -427,16 +437,22 @@ package object spark extends Logging with SparkVersion with BuildVersion {
*
* With `ifNotSet == true`, the description is only set if no job description is set yet.
*
- * Any modification to the job description during execution of the function is reverted,
- * even if `ifNotSet == true`.
- *
- * @param description job description
- * @param ifNotSet job description is only set if no description is set yet
- * @param func code to execute while job description is set
- * @param session spark session
- * @tparam T return type of func
+ * Any modification to the job description during execution of the function is reverted, even if `ifNotSet == true`.
+ *
+ * @param description
+ * job description
+ * @param ifNotSet
+ * job description is only set if no description is set yet
+ * @param func
+ * code to execute while job description is set
+ * @param session
+ * spark session
+ * @tparam T
+ * return type of func
*/
- def withJobDescription[T](description: String, ifNotSet: Boolean = false)(func: => T)(implicit session: SparkSession): T = {
+ def withJobDescription[T](description: String, ifNotSet: Boolean = false)(
+ func: => T
+ )(implicit session: SparkSession): T = {
val earlierDescription = setJobDescription(description, ifNotSet)(session.sparkContext)
try {
func
@@ -448,9 +464,12 @@ package object spark extends Logging with SparkVersion with BuildVersion {
/**
* Append the job description and return the earlier description.
*
- * @param extraDescription job description
- * @param separator separator to join exiting and extra description with
- * @param context spark context
+ * @param extraDescription
+ * job description
+ * @param separator
+   *   separator to join the existing and the extra description with
+ * @param context
+ * spark context
* @return
*/
def appendJobDescription(extraDescription: String, separator: String, context: SparkContext): String = {
@@ -461,9 +480,9 @@ package object spark extends Logging with SparkVersion with BuildVersion {
}
/**
- * Appends a job description to all Spark jobs started within the given function.
- * The current Job description is extended by the separator and the extra description
- * on entering the function, and restored after exit of the function.
+   * Appends a job description to all Spark jobs started within the given function. The current job description is
+   * extended by the separator and the extra description on entering the function, and restored after the function
+   * exits.
*
* Usage example:
*
@@ -482,13 +501,20 @@ package object spark extends Logging with SparkVersion with BuildVersion {
*
* Any modification to the job description during execution of the function is reverted.
*
- * @param extraDescription job description to be appended
- * @param separator separator used when appending description
- * @param func code to execute while job description is set
- * @param session spark session
- * @tparam T return type of func
+ * @param extraDescription
+ * job description to be appended
+ * @param separator
+ * separator used when appending description
+ * @param func
+ * code to execute while job description is set
+ * @param session
+ * spark session
+ * @tparam T
+ * return type of func
*/
- def appendJobDescription[T](extraDescription: String, separator: String = " - ")(func: => T)(implicit session: SparkSession): T = {
+ def appendJobDescription[T](extraDescription: String, separator: String = " - ")(
+ func: => T
+ )(implicit session: SparkSession): T = {
val earlierDescription = appendJobDescription(extraDescription, separator, session.sparkContext)
try {
func
@@ -500,37 +526,61 @@ package object spark extends Logging with SparkVersion with BuildVersion {
/**
* Class to extend a Spark Dataset.
*
- * @param ds dataset
- * @tparam V inner type of dataset
+ * @param ds
+ * dataset
+ * @tparam V
+ * inner type of dataset
*/
- @deprecated("Constructor with encoder is deprecated, the encoder argument is ignored, ds.encoder is used instead.", since = "2.9.0")
+ @deprecated(
+ "Constructor with encoder is deprecated, the encoder argument is ignored, ds.encoder is used instead.",
+ since = "2.9.0"
+ )
class ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]) {
private val eds = ExtendedDatasetV2[V](ds)
def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =
eds.histogram(thresholds, valueColumn, aggregateColumns: _*)
- def writePartitionedBy(partitionColumns: Seq[Column],
- moreFileColumns: Seq[Column] = Seq.empty,
- moreFileOrder: Seq[Column] = Seq.empty,
- partitions: Option[Int] = None,
- writtenProjection: Option[Seq[Column]] = None,
- unpersistHandle: Option[UnpersistHandle] = None): DataFrameWriter[Row] =
- eds.writePartitionedBy(partitionColumns, moreFileColumns, moreFileOrder, partitions, writtenProjection, unpersistHandle)
+ def writePartitionedBy(
+ partitionColumns: Seq[Column],
+ moreFileColumns: Seq[Column] = Seq.empty,
+ moreFileOrder: Seq[Column] = Seq.empty,
+ partitions: Option[Int] = None,
+ writtenProjection: Option[Seq[Column]] = None,
+ unpersistHandle: Option[UnpersistHandle] = None
+ ): DataFrameWriter[Row] =
+ eds.writePartitionedBy(
+ partitionColumns,
+ moreFileColumns,
+ moreFileOrder,
+ partitions,
+ writtenProjection,
+ unpersistHandle
+ )
- def groupBySorted[K: Ordering : Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] =
+ def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] =
eds.groupBySorted(cols: _*)(order: _*)
- def groupBySorted[K: Ordering : Encoder](partitions: Int)(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] =
+ def groupBySorted[K: Ordering: Encoder](partitions: Int)(cols: Column*)(
+ order: Column*
+ ): SortedGroupByDataset[K, V] =
eds.groupBySorted(partitions)(cols: _*)(order: _*)
- def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O): SortedGroupByDataset[K, V] =
+ def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
+ order: V => O
+ ): SortedGroupByDataset[K, V] =
eds.groupByKeySorted(key, Some(partitions))(order)
- def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O, reverse: Boolean): SortedGroupByDataset[K, V] =
+ def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
+ order: V => O,
+ reverse: Boolean
+ ): SortedGroupByDataset[K, V] =
eds.groupByKeySorted(key, Some(partitions))(order, reverse)
- def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] =
+ def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)(
+ order: V => O,
+ reverse: Boolean = false
+ ): SortedGroupByDataset[K, V] =
eds.groupByKeySorted(key, partitions)(order, reverse)
def withRowNumbers(order: Column*): DataFrame =
@@ -545,72 +595,74 @@ package object spark extends Logging with SparkVersion with BuildVersion {
def withRowNumbers(unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
eds.withRowNumbers(unpersistHandle, order: _*)
- def withRowNumbers(rowNumberColumnName: String,
- storageLevel: StorageLevel,
- order: Column*): DataFrame =
+ def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame =
eds.withRowNumbers(rowNumberColumnName, storageLevel, order: _*)
- def withRowNumbers(rowNumberColumnName: String,
- unpersistHandle: UnpersistHandle,
- order: Column*): DataFrame =
+ def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
eds.withRowNumbers(rowNumberColumnName, unpersistHandle, order: _*)
- def withRowNumbers(storageLevel: StorageLevel,
- unpersistHandle: UnpersistHandle,
- order: Column*): DataFrame =
+ def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
eds.withRowNumbers(storageLevel, unpersistHandle, order: _*)
- def withRowNumbers(rowNumberColumnName: String,
- storageLevel: StorageLevel,
- unpersistHandle: UnpersistHandle,
- order: Column*): DataFrame =
+ def withRowNumbers(
+ rowNumberColumnName: String,
+ storageLevel: StorageLevel,
+ unpersistHandle: UnpersistHandle,
+ order: Column*
+ ): DataFrame =
eds.withRowNumbers(rowNumberColumnName, storageLevel, unpersistHandle, order: _*)
}
/**
* Class to extend a Spark Dataset.
*
- * @param ds dataset
- * @tparam V inner type of dataset
+ * @param ds
+ * dataset
+ * @tparam V
+ * inner type of dataset
*/
def ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]): ExtendedDataset[V] = new ExtendedDataset(ds, encoder)
/**
* Implicit class to extend a Spark Dataset.
*
- * @param ds dataset
- * @tparam V inner type of dataset
+ * @param ds
+ * dataset
+ * @tparam V
+ * inner type of dataset
*/
implicit class ExtendedDatasetV2[V](ds: Dataset[V]) {
private implicit val encoder: Encoder[V] = ds.encoder
/**
- * Compute the histogram of a column when aggregated by aggregate columns.
- * Thresholds are expected to be provided in ascending order.
- * The result dataframe contains the aggregate and histogram columns only.
- * For each threshold value in thresholds, there will be a column named s"≤threshold".
- * There will also be a final column called s">last_threshold", that counts the remaining
- * values that exceed the last threshold.
- *
- * @param thresholds sequence of thresholds, must implement <= and > operators w.r.t. valueColumn
- * @param valueColumn histogram is computed for values of this column
- * @param aggregateColumns histogram is computed against these columns
- * @tparam T type of histogram thresholds
- * @return dataframe with aggregate and histogram columns
+ * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
+ * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
+ * in thresholds, there will be a column named s"≤threshold". There will also be a final column called
+ * s">last_threshold", that counts the remaining values that exceed the last threshold.
+ *
+ * @param thresholds
+ * sequence of thresholds, must implement <= and > operators w.r.t. valueColumn
+ * @param valueColumn
+ * histogram is computed for values of this column
+ * @param aggregateColumns
+ * histogram is computed against these columns
+ * @tparam T
+ * type of histogram thresholds
+ * @return
+ * dataframe with aggregate and histogram columns
*/
def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =
Histogram.of(ds, thresholds, valueColumn, aggregateColumns: _*)
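For illustration, a minimal sketch of the histogram call, assuming a DataFrame `df` with an integer `value` column and an `id` aggregate column (thresholds mirror the HistogramSuite later in this diff), with the implicits from `uk.co.gresearch.spark._` in scope:

    import org.apache.spark.sql.functions.col
    import uk.co.gresearch.spark._

    // one bucket column per threshold (≤-200, ≤-100, ≤0, ≤100, ≤200) plus a final ">200" column,
    // counted separately for every id
    val buckets = df.histogram(Seq(-200, -100, 0, 100, 200), col("value"), col("id"))
    buckets.orderBy(col("id")).show()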
/**
- * Writes the Dataset / DataFrame via DataFrameWriter.partitionBy. In addition to partitionBy,
- * this method sorts the data to improve partition file size. Small partitions will contain few
- * files, large partitions contain more files. Partition ids are contained in a single partition
- * file per `partitionBy` partition only. Rows within the partition files are also sorted,
- * if partitionOrder is defined.
+ * Writes the Dataset / DataFrame via DataFrameWriter.partitionBy. In addition to partitionBy, this method sorts the
+ * data to improve partition file size. Small partitions will contain few files, large partitions contain more
+ * files. Partition ids are contained in a single partition file per `partitionBy` partition only. Rows within the
+ * partition files are also sorted, if partitionOrder is defined.
*
- * Note: With Spark 3.0, 3.1, 3.2 before 3.2.3, 3.3 before 3.3.2, and AQE enabled, an intermediate DataFrame is being
- * cached in order to guarantee sorted output files. See https://issues.apache.org/jira/browse/SPARK-40588.
- * That cached DataFrame can be unpersisted via an optional [[UnpersistHandle]] provided to this method.
+ * Note: With Spark 3.0, 3.1, 3.2 before 3.2.3, 3.3 before 3.3.2, and AQE enabled, an intermediate DataFrame is
+ * being cached in order to guarantee sorted output files. See https://issues.apache.org/jira/browse/SPARK-40588.
+ * That cached DataFrame can be unpersisted via an optional [[UnpersistHandle]] provided to this method.
*
* Calling:
* {{{
@@ -638,20 +690,29 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* cached.unpersist
* }}}
*
- * @param partitionColumns columns used for partitioning
- * @param moreFileColumns columns where individual values are written to a single file
- * @param moreFileOrder additional columns to sort partition files
- * @param partitions optional number of partition files
- * @param writtenProjection additional transformation to be applied before calling write
- * @param unpersistHandle handle to unpersist internally created DataFrame after writing
- * @return configured DataFrameWriter
+ * @param partitionColumns
+ * columns used for partitioning
+ * @param moreFileColumns
+ * columns where individual values are written to a single file
+ * @param moreFileOrder
+ * additional columns to sort partition files
+ * @param partitions
+ * optional number of partition files
+ * @param writtenProjection
+ * additional transformation to be applied before calling write
+ * @param unpersistHandle
+ * handle to unpersist internally created DataFrame after writing
+ * @return
+ * configured DataFrameWriter
*/
- def writePartitionedBy(partitionColumns: Seq[Column],
- moreFileColumns: Seq[Column] = Seq.empty,
- moreFileOrder: Seq[Column] = Seq.empty,
- partitions: Option[Int] = None,
- writtenProjection: Option[Seq[Column]] = None,
- unpersistHandle: Option[UnpersistHandle] = None): DataFrameWriter[Row] = {
+ def writePartitionedBy(
+ partitionColumns: Seq[Column],
+ moreFileColumns: Seq[Column] = Seq.empty,
+ moreFileOrder: Seq[Column] = Seq.empty,
+ partitions: Option[Int] = None,
+ writtenProjection: Option[Seq[Column]] = None,
+ unpersistHandle: Option[UnpersistHandle] = None
+ ): DataFrameWriter[Row] = {
if (partitionColumns.isEmpty)
throw new IllegalArgumentException(s"partition columns must not be empty")
@@ -661,15 +722,19 @@ package object spark extends Logging with SparkVersion with BuildVersion {
val requiresCaching = writePartitionedByRequiresCaching(ds)
(requiresCaching, unpersistHandle.isDefined) match {
case (true, false) =>
- warning("Partitioned-writing with AQE enabled and Spark 3.0, 3.1, 3.2 below 3.2.3, " +
- "and 3.3 below 3.3.2 requires caching an intermediate DataFrame, " +
- "which calling code has to unpersist once writing is done. " +
- "Please provide an UnpersistHandle to DataFrame.writePartitionedBy, or UnpersistHandle.Noop. " +
- "See https://issues.apache.org/jira/browse/SPARK-40588")
+ warning(
+ "Partitioned-writing with AQE enabled and Spark 3.0, 3.1, 3.2 below 3.2.3, " +
+ "and 3.3 below 3.3.2 requires caching an intermediate DataFrame, " +
+ "which calling code has to unpersist once writing is done. " +
+ "Please provide an UnpersistHandle to DataFrame.writePartitionedBy, or UnpersistHandle.Noop. " +
+ "See https://issues.apache.org/jira/browse/SPARK-40588"
+ )
case (false, true) if !unpersistHandle.get.isInstanceOf[NoopUnpersistHandle] =>
- info("UnpersistHandle provided to DataFrame.writePartitionedBy is not needed as " +
- "partitioned-writing with AQE disabled or Spark 3.2.3, 3.3.2 or 3.4 and above " +
- "does not require caching intermediate DataFrame.")
+ info(
+ "UnpersistHandle provided to DataFrame.writePartitionedBy is not needed as " +
+ "partitioned-writing with AQE disabled or Spark 3.2.3, 3.3.2 or 3.4 and above " +
+ "does not require caching intermediate DataFrame."
+ )
unpersistHandle.get.setDataFrame(ds.sparkSession.emptyDataFrame)
case _ =>
}
@@ -680,24 +745,28 @@ package object spark extends Logging with SparkVersion with BuildVersion {
val sortColumns = partitionColumnNames ++ moreFileColumns ++ moreFileOrder
ds.toDF
.call(ds => partitionColumnsMap.foldLeft(ds) { case (ds, (name, col)) => ds.withColumn(name, col) })
- .when(partitions.isEmpty).call(_.repartitionByRange(rangeColumns: _*))
- .when(partitions.isDefined).call(_.repartitionByRange(partitions.get, rangeColumns: _*))
+ .when(partitions.isEmpty)
+ .call(_.repartitionByRange(rangeColumns: _*))
+ .when(partitions.isDefined)
+ .call(_.repartitionByRange(partitions.get, rangeColumns: _*))
.sortWithinPartitions(sortColumns: _*)
- .when(writtenProjection.isDefined).call(_.select(writtenProjection.get: _*))
- .when(requiresCaching && unpersistHandle.isDefined).call(unpersistHandle.get.setDataFrame(_))
+ .when(writtenProjection.isDefined)
+ .call(_.select(writtenProjection.get: _*))
+ .when(requiresCaching && unpersistHandle.isDefined)
+ .call(unpersistHandle.get.setDataFrame(_))
.write
.partitionBy(partitionColumnsMap.keys.toSeq: _*)
}
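A brief usage sketch of `writePartitionedBy` (the `date` and `id` columns and the output path are placeholders). The returned `DataFrameWriter[Row]` is already configured with `partitionBy`, so the caller only picks format and path, then releases the handle once the write has finished. The `UnpersistHandle()` factory and call syntax are assumed to match the pattern shown in the `withRowNumbers` Scaladoc further down:

    import org.apache.spark.sql.functions.col
    import uk.co.gresearch.spark._

    val handle = UnpersistHandle()            // only needed for the Spark/AQE versions noted above
    df.writePartitionedBy(
      partitionColumns = Seq(col("date")),
      moreFileColumns = Seq(col("id")),       // each id value is written to a single file within its date partition
      unpersistHandle = Some(handle)
    ).parquet("/tmp/partitioned")             // placeholder output path
    handle()                                  // unpersist the cached intermediate DataFrame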
/**
- * (Scala-specific)
- * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
+ * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
*
- * @see `org.apache.spark.sql.Dataset.groupByKey(T => K)`
+ * @see
+ * `org.apache.spark.sql.Dataset.groupByKey(T => K)`
*
- * @note Calling this method should be preferred to `groupByKey(T => K)` because the
- * Catalyst query planner cannot exploit existing partitioning and ordering of
- * this Dataset with that function.
+ * @note
+ * Calling this method should be preferred to `groupByKey(T => K)` because the Catalyst query planner cannot
+ * exploit existing partitioning and ordering of this Dataset with that function.
*
* {{{
* ds.groupByKey[Int]($"age").flatMapGroups(...)
@@ -708,14 +777,14 @@ package object spark extends Logging with SparkVersion with BuildVersion {
ds.groupBy(column +: columns: _*).as[K, V]
/**
- * (Scala-specific)
- * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
+ * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
*
- * @see `org.apache.spark.sql.Dataset.groupByKey(T => K)`
+ * @see
+ * `org.apache.spark.sql.Dataset.groupByKey(T => K)`
*
- * @note Calling this method should be preferred to `groupByKey(T => K)` because the
- * Catalyst query planner cannot exploit existing partitioning and ordering of
- * this Dataset with that function.
+ * @note
+ * Calling this method should be preferred to `groupByKey(T => K)` because the Catalyst query planner cannot
+ * exploit existing partitioning and ordering of this Dataset with that function.
*
* {{{
* ds.groupByKey[Int]($"age").flatMapGroups(...)
@@ -726,9 +795,8 @@ package object spark extends Logging with SparkVersion with BuildVersion {
ds.groupBy(column, columns: _*).as[K, V]
/**
- * Groups the Dataset and sorts the groups using the specified columns, so we can run
- * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available
- * functions.
+ * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
+ * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
@@ -736,17 +804,18 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
- * @param cols grouping columns
- * @param order sort columns
+ * @param cols
+ * grouping columns
+ * @param order
+ * sort columns
*/
- def groupBySorted[K: Ordering : Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {
+ def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {
SortedGroupByDataset(ds, cols, order, None)
}
/**
- * Groups the Dataset and sorts the groups using the specified columns, so we can run
- * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available
- * functions.
+ * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
+ * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
@@ -754,18 +823,22 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
- * @param partitions number of partitions
- * @param cols grouping columns
- * @param order sort columns
+ * @param partitions
+ * number of partitions
+ * @param cols
+ * grouping columns
+ * @param order
+ * sort columns
*/
- def groupBySorted[K: Ordering : Encoder](partitions: Int)(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {
+ def groupBySorted[K: Ordering: Encoder](
+ partitions: Int
+ )(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = {
SortedGroupByDataset(ds, cols, order, Some(partitions))
}
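As a sketch, assuming a `Dataset[Val]` of `case class Val(id: Int, seq: Int, value: Double)` as used in the test suite later in this diff, with `spark.implicits._` in scope for the required encoders:

    import org.apache.spark.sql.functions.col

    // enumerate the rows of each id group in seq order
    val enumerated = ds
      .groupBySorted[Int](col("id"))(col("seq"))
      .flatMapSortedGroups((id, it) => it.zipWithIndex.map { case (v, idx) => (id, idx, v.value) })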
/**
- * Groups the Dataset and sorts the groups using the specified columns, so we can run
- * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available
- * functions.
+ * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
+ * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
@@ -773,17 +846,21 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
- * @param partitions number of partitions
- * @param key grouping key
- * @param order sort key
+ * @param partitions
+ * number of partitions
+ * @param key
+ * grouping key
+ * @param order
+ * sort key
*/
- def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O): SortedGroupByDataset[K, V] =
+ def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
+ order: V => O
+ ): SortedGroupByDataset[K, V] =
groupByKeySorted(key, Some(partitions))(order)
/**
- * Groups the Dataset and sorts the groups using the specified columns, so we can run
- * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available
- * functions.
+ * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
+ * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
@@ -791,18 +868,24 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
- * @param partitions number of partitions
- * @param key grouping key
- * @param order sort key
- * @param reverse sort reverse order
+ * @param partitions
+ * number of partitions
+ * @param key
+ * grouping key
+ * @param order
+ * sort key
+ * @param reverse
+ * sort reverse order
*/
- def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O, reverse: Boolean): SortedGroupByDataset[K, V] =
+ def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)(
+ order: V => O,
+ reverse: Boolean
+ ): SortedGroupByDataset[K, V] =
groupByKeySorted(key, Some(partitions))(order, reverse)
/**
- * Groups the Dataset and sorts the groups using the specified columns, so we can run
- * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available
- * functions.
+ * Groups the Dataset and sorts the groups using the specified columns, so we can further process the sorted
+ * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions.
*
* {{{
* // Enumerate elements in the sorted group
@@ -810,12 +893,19 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* .flatMapSortedGroups((key, it) => it.zipWithIndex)
* }}}
*
- * @param partitions optional number of partitions
- * @param key grouping key
- * @param order sort key
- * @param reverse sort reverse order
+ * @param partitions
+ * optional number of partitions
+ * @param key
+ * grouping key
+ * @param order
+ * sort key
+ * @param reverse
+ * sort reverse order
*/
- def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] = {
+ def groupByKeySorted[K: Ordering: Encoder, O: Encoder](
+ key: V => K,
+ partitions: Option[Int] = None
+ )(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] = {
SortedGroupByDataset(ds, key, order, partitions, reverse)
}
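Mirroring the calls in `GroupBySuite` below, a sketch that keeps only the latest row per key (field names as in the test suite, `spark.implicits._` assumed in scope):

    // sort each id group by seq in reverse and keep the first, i.e. latest, element
    val latest = ds
      .groupByKeySorted(v => v.id, partitions = Some(10))(v => v.seq, reverse = true)
      .flatMapSortedGroups((id, it) => it.take(1))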
@@ -856,34 +946,36 @@ package object spark extends Logging with SparkVersion with BuildVersion {
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
- def withRowNumbers(rowNumberColumnName: String,
- storageLevel: StorageLevel,
- order: Column*): DataFrame =
- RowNumbers.withRowNumberColumnName(rowNumberColumnName).withStorageLevel(storageLevel).withOrderColumns(order).of(ds)
+ def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame =
+ RowNumbers
+ .withRowNumberColumnName(rowNumberColumnName)
+ .withStorageLevel(storageLevel)
+ .withOrderColumns(order)
+ .of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
- def withRowNumbers(rowNumberColumnName: String,
- unpersistHandle: UnpersistHandle,
- order: Column*): DataFrame =
- RowNumbers.withRowNumberColumnName(rowNumberColumnName).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds)
+ def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
+ RowNumbers
+ .withRowNumberColumnName(rowNumberColumnName)
+ .withUnpersistHandle(unpersistHandle)
+ .withOrderColumns(order)
+ .of(ds)
/**
* Adds a global continuous row number starting at 1.
*
* See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details.
*/
- def withRowNumbers(storageLevel: StorageLevel,
- unpersistHandle: UnpersistHandle,
- order: Column*): DataFrame =
+ def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame =
RowNumbers.withStorageLevel(storageLevel).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds)
/**
- * Adds a global continuous row number starting at 1, after sorting rows by the given columns.
- * When no columns are given, the existing order is used.
+ * Adds a global continuous row number starting at 1, after sorting rows by the given columns. When no columns are
+ * given, the existing order is used.
*
* Hence, the following examples are equivalent:
* {{{
@@ -894,8 +986,8 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* The column name of the column with the row numbers can be set via the `rowNumberColumnName` argument.
*
* To avoid some known issues optimizing the query plan, this function has to internally call
- * `Dataset.persist(StorageLevel)` on an intermediate DataFrame. The storage level of that cached
- * DataFrame can be set via `storageLevel`, where the default is `StorageLevel.MEMORY_AND_DISK`.
+ * `Dataset.persist(StorageLevel)` on an intermediate DataFrame. The storage level of that cached DataFrame can be
+ * set via `storageLevel`, where the default is `StorageLevel.MEMORY_AND_DISK`.
*
* That cached intermediate DataFrame can be un-persisted / un-cached as follows:
* {{{
@@ -906,23 +998,36 @@ package object spark extends Logging with SparkVersion with BuildVersion {
* unpersist()
* }}}
*
- * @param rowNumberColumnName name of the row number column
- * @param storageLevel storage level of the cached intermediate DataFrame
- * @param unpersistHandle handle to un-persist intermediate DataFrame
- * @param order columns to order dataframe before assigning row numbers
- * @return dataframe with row numbers
+ * @param rowNumberColumnName
+ * name of the row number column
+ * @param storageLevel
+ * storage level of the cached intermediate DataFrame
+ * @param unpersistHandle
+ * handle to un-persist intermediate DataFrame
+ * @param order
+ * columns to order dataframe before assigning row numbers
+ * @return
+ * dataframe with row numbers
*/
- def withRowNumbers(rowNumberColumnName: String,
- storageLevel: StorageLevel,
- unpersistHandle: UnpersistHandle,
- order: Column*): DataFrame =
- RowNumbers.withRowNumberColumnName(rowNumberColumnName).withStorageLevel(storageLevel).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds)
+ def withRowNumbers(
+ rowNumberColumnName: String,
+ storageLevel: StorageLevel,
+ unpersistHandle: UnpersistHandle,
+ order: Column*
+ ): DataFrame =
+ RowNumbers
+ .withRowNumberColumnName(rowNumberColumnName)
+ .withStorageLevel(storageLevel)
+ .withUnpersistHandle(unpersistHandle)
+ .withOrderColumns(order)
+ .of(ds)
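A sketch of the full variant (row-number column name, storage level, and order column are illustrative), using the unpersist pattern from the Scaladoc above:

    import org.apache.spark.sql.functions.col
    import org.apache.spark.storage.StorageLevel
    import uk.co.gresearch.spark._

    val unpersist = UnpersistHandle()
    val numbered = df.withRowNumbers("row_number", StorageLevel.DISK_ONLY, unpersist, col("timestamp"))
    numbered.show()
    unpersist()    // release the cached intermediate DataFrame once it is no longer needed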
}
/**
* Class to extend a Spark Dataframe.
*
- * @param df dataframe
+ * @param df
+ * dataframe
*/
@deprecated("Implicit class ExtendedDataframe is deprecated, please recompile your source code.", since = "2.9.0")
class ExtendedDataframe(df: DataFrame) extends ExtendedDataset[Row](df, df.encoder)
@@ -930,14 +1035,16 @@ package object spark extends Logging with SparkVersion with BuildVersion {
/**
* Class to extend a Spark Dataframe.
*
- * @param df dataframe
+ * @param df
+ * dataframe
*/
def ExtendedDataframe(df: DataFrame): ExtendedDataframe = new ExtendedDataframe(df)
/**
* Implicit class to extend a Spark Dataframe, which is a Dataset[Row].
*
- * @param df dataframe
+ * @param df
+ * dataframe
*/
implicit class ExtendedDataframeV2(df: DataFrame) extends ExtendedDatasetV2[Row](df)
diff --git a/src/main/scala/uk/co/gresearch/spark/parquet/package.scala b/src/main/scala/uk/co/gresearch/spark/parquet/package.scala
index 8291d7d1..d1f91380 100644
--- a/src/main/scala/uk/co/gresearch/spark/parquet/package.scala
+++ b/src/main/scala/uk/co/gresearch/spark/parquet/package.scala
@@ -34,31 +34,35 @@ package object parquet {
/**
* Implicit class to extend a Spark DataFrameReader.
*
- * @param reader data frame reader
+ * @param reader
+ * data frame reader
*/
implicit class ExtendedDataFrameReader(reader: DataFrameReader) {
+
/**
* Read the metadata of Parquet files into a Dataframe.
*
- * The returned DataFrame has as many partitions as there are Parquet files,
- * at most `spark.sparkContext.defaultParallelism` partitions.
+ * The returned DataFrame has as many partitions as there are Parquet files, at most
+ * `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-file information:
- * - filename (string): The file name
- * - blocks (int): Number of blocks / RowGroups in the Parquet file
- * - compressedBytes (long): Number of compressed bytes of all blocks
- * - uncompressedBytes (long): Number of uncompressed bytes of all blocks
- * - rows (long): Number of rows in the file
- * - columns (int): Number of columns in the file
- * - values (long): Number of values in the file
- * - nulls (long): Number of null values in the file
- * - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file
- * - schema (string): The schema
- * - encryption (string): The encryption
- * - keyValues (string-to-string map): Key-value data of the file
+ * - filename (string): The file name
+ * - blocks (int): Number of blocks / RowGroups in the Parquet file
+ * - compressedBytes (long): Number of compressed bytes of all blocks
+ * - uncompressedBytes (long): Number of uncompressed bytes of all blocks
+ * - rows (long): Number of rows in the file
+ * - columns (int): Number of columns in the file
+ * - values (long): Number of values in the file
+ * - nulls (long): Number of null values in the file
+ * - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file
+ * - schema (string): The schema
+ * - encryption (string): The encryption
+ * - keyValues (string-to-string map): Key-value data of the file
*
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet metadata
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetMetadata(paths: String*): DataFrame = parquetMetadata(None, paths)
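A usage sketch (the path is a placeholder); `parquetSchema`, `parquetBlocks`, `parquetBlockColumns` and `parquetPartitions` below follow the same call pattern:

    import uk.co.gresearch.spark.parquet._

    // one row per Parquet file, with block, byte, row and value counts
    val meta = spark.read.parquetMetadata("/tmp/dataset.parquet")
    meta.select("filename", "blocks", "compressedBytes", "rows").show()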
@@ -69,22 +73,25 @@ package object parquet {
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-file information:
- * - filename (string): The file name
- * - blocks (int): Number of blocks / RowGroups in the Parquet file
- * - compressedBytes (long): Number of compressed bytes of all blocks
- * - uncompressedBytes (long): Number of uncompressed bytes of all blocks
- * - rows (long): Number of rows in the file
- * - columns (int): Number of columns in the file
- * - values (long): Number of values in the file
- * - nulls (long): Number of null values in the file
- * - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file
- * - schema (string): The schema
- * - encryption (string): The encryption
- * - keyValues (string-to-string map): Key-value data of the file
+ * - filename (string): The file name
+ * - blocks (int): Number of blocks / RowGroups in the Parquet file
+ * - compressedBytes (long): Number of compressed bytes of all blocks
+ * - uncompressedBytes (long): Number of uncompressed bytes of all blocks
+ * - rows (long): Number of rows in the file
+ * - columns (int): Number of columns in the file
+ * - values (long): Number of values in the file
+ * - nulls (long): Number of null values in the file
+ * - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file
+ * - schema (string): The schema
+ * - encryption (string): The encryption
+ * - keyValues (string-to-string map): Key-value data of the file
*
- * @param parallelism number of partitions of returned DataFrame
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet metadata
+ * @param parallelism
+ * number of partitions of returned DataFrame
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetMetadata(parallelism: Int, paths: String*): DataFrame = parquetMetadata(Some(parallelism), paths)
@@ -94,64 +101,70 @@ package object parquet {
import files.sparkSession.implicits._
- files.flatMap { case (_, file) =>
- readFooters(file).map { footer =>
- (
- footer.getFile.toString,
- footer.getParquetMetadata.getBlocks.size(),
- footer.getParquetMetadata.getBlocks.asScala.map(_.getCompressedSize).sum,
- footer.getParquetMetadata.getBlocks.asScala.map(_.getTotalByteSize).sum,
- footer.getParquetMetadata.getBlocks.asScala.map(_.getRowCount).sum,
- footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.size(),
- footer.getParquetMetadata.getBlocks.asScala.map(_.getColumns.map(_.getValueCount).sum).sum,
- // when all columns have statistics, count the null values
- Option(footer.getParquetMetadata.getBlocks.asScala.flatMap(_.getColumns.map(c => Option(c.getStatistics))))
- .filter(_.forall(_.isDefined))
- .map(_.map(_.get.getNumNulls).sum),
- footer.getParquetMetadata.getFileMetaData.getCreatedBy,
- footer.getParquetMetadata.getFileMetaData.getSchema.toString,
- FileMetaDataUtil.getEncryptionType(footer.getParquetMetadata.getFileMetaData),
- footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData.asScala,
- )
+ files
+ .flatMap { case (_, file) =>
+ readFooters(file).map { footer =>
+ (
+ footer.getFile.toString,
+ footer.getParquetMetadata.getBlocks.size(),
+ footer.getParquetMetadata.getBlocks.asScala.map(_.getCompressedSize).sum,
+ footer.getParquetMetadata.getBlocks.asScala.map(_.getTotalByteSize).sum,
+ footer.getParquetMetadata.getBlocks.asScala.map(_.getRowCount).sum,
+ footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.size(),
+ footer.getParquetMetadata.getBlocks.asScala.map(_.getColumns.map(_.getValueCount).sum).sum,
+ // when all columns have statistics, count the null values
+ Option(
+ footer.getParquetMetadata.getBlocks.asScala.flatMap(_.getColumns.map(c => Option(c.getStatistics)))
+ )
+ .filter(_.forall(_.isDefined))
+ .map(_.map(_.get.getNumNulls).sum),
+ footer.getParquetMetadata.getFileMetaData.getCreatedBy,
+ footer.getParquetMetadata.getFileMetaData.getSchema.toString,
+ FileMetaDataUtil.getEncryptionType(footer.getParquetMetadata.getFileMetaData),
+ footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData.asScala,
+ )
+ }
}
- }.toDF(
- "filename",
- "blocks",
- "compressedBytes",
- "uncompressedBytes",
- "rows",
- "columns",
- "values",
- "nulls",
- "createdBy",
- "schema",
- "encryption",
- "keyValues"
- )
+ .toDF(
+ "filename",
+ "blocks",
+ "compressedBytes",
+ "uncompressedBytes",
+ "rows",
+ "columns",
+ "values",
+ "nulls",
+ "createdBy",
+ "schema",
+ "encryption",
+ "keyValues"
+ )
}
/**
* Read the schema of Parquet files into a Dataframe.
*
- * The returned DataFrame has as many partitions as there are Parquet files,
- * at most `spark.sparkContext.defaultParallelism` partitions.
+ * The returned DataFrame has as many partitions as there are Parquet files, at most
+ * `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-file information:
- * - filename (string): The Parquet file name
- * - columnName (string): The column name
- * - columnPath (string array): The column path
- * - repetition (string): The repetition
- * - type (string): The data type
- * - length (int): The length of the type
- * - originalType (string): The original type
- * - isPrimitive (boolean: True if type is primitive
- * - primitiveType (string: The primitive type
- * - primitiveOrder (string: The order of the primitive type
- * - maxDefinitionLevel (int): The max definition level
- * - maxRepetitionLevel (int): The max repetition level
+ * - filename (string): The Parquet file name
+ * - columnName (string): The column name
+ * - columnPath (string array): The column path
+ * - repetition (string): The repetition
+ * - type (string): The data type
+ * - length (int): The length of the type
+ * - originalType (string): The original type
+ * - isPrimitive (boolean): True if type is primitive
+ * - primitiveType (string): The primitive type
+ * - primitiveOrder (string): The order of the primitive type
+ * - maxDefinitionLevel (int): The max definition level
+ * - maxRepetitionLevel (int): The max repetition level
*
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet metadata
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetSchema(paths: String*): DataFrame = parquetSchema(None, paths)
@@ -162,22 +175,25 @@ package object parquet {
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-file information:
- * - filename (string): The Parquet file name
- * - columnName (string): The column name
- * - columnPath (string array): The column path
- * - repetition (string): The repetition
- * - type (string): The data type
- * - length (int): The length of the type
- * - originalType (string): The original type
- * - isPrimitive (boolean: True if type is primitive
- * - primitiveType (string: The primitive type
- * - primitiveOrder (string: The order of the primitive type
- * - maxDefinitionLevel (int): The max definition level
- * - maxRepetitionLevel (int): The max repetition level
+ * - filename (string): The Parquet file name
+ * - columnName (string): The column name
+ * - columnPath (string array): The column path
+ * - repetition (string): The repetition
+ * - type (string): The data type
+ * - length (int): The length of the type
+ * - originalType (string): The original type
+ * - isPrimitive (boolean): True if type is primitive
+ * - primitiveType (string): The primitive type
+ * - primitiveOrder (string): The order of the primitive type
+ * - maxDefinitionLevel (int): The max definition level
+ * - maxRepetitionLevel (int): The max repetition level
*
- * @param parallelism number of partitions of returned DataFrame
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet metadata
+ * @param parallelism
+ * number of partitions of returned DataFrame
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet metadata
*/
@scala.annotation.varargs
def parquetSchema(parallelism: Int, paths: String*): DataFrame = parquetSchema(Some(parallelism), paths)
@@ -187,62 +203,66 @@ package object parquet {
import files.sparkSession.implicits._
- files.flatMap { case (_, file) =>
- readFooters(file).flatMap { footer =>
- footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.map { column =>
- (
- footer.getFile.toString,
- Option(column.getPrimitiveType).map(_.getName),
- column.getPath,
- Option(column.getPrimitiveType).flatMap(v => Option(v.getRepetition)).map(_.name),
- Option(column.getPrimitiveType).flatMap(v => Option(v.getPrimitiveTypeName)).map(_.name),
- Option(column.getPrimitiveType).map(_.getTypeLength),
- Option(column.getPrimitiveType).flatMap(v => Option(v.getOriginalType)).map(_.name),
- Option(column.getPrimitiveType).flatMap(PrimitiveTypeUtil.getLogicalTypeAnnotation),
- column.getPrimitiveType.isPrimitive,
- Option(column.getPrimitiveType).map(_.getPrimitiveTypeName.name),
- Option(column.getPrimitiveType).flatMap(v => Option(v.columnOrder)).map(_.getColumnOrderName.name),
- column.getMaxDefinitionLevel,
- column.getMaxRepetitionLevel,
- )
+ files
+ .flatMap { case (_, file) =>
+ readFooters(file).flatMap { footer =>
+ footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.map { column =>
+ (
+ footer.getFile.toString,
+ Option(column.getPrimitiveType).map(_.getName),
+ column.getPath,
+ Option(column.getPrimitiveType).flatMap(v => Option(v.getRepetition)).map(_.name),
+ Option(column.getPrimitiveType).flatMap(v => Option(v.getPrimitiveTypeName)).map(_.name),
+ Option(column.getPrimitiveType).map(_.getTypeLength),
+ Option(column.getPrimitiveType).flatMap(v => Option(v.getOriginalType)).map(_.name),
+ Option(column.getPrimitiveType).flatMap(PrimitiveTypeUtil.getLogicalTypeAnnotation),
+ column.getPrimitiveType.isPrimitive,
+ Option(column.getPrimitiveType).map(_.getPrimitiveTypeName.name),
+ Option(column.getPrimitiveType).flatMap(v => Option(v.columnOrder)).map(_.getColumnOrderName.name),
+ column.getMaxDefinitionLevel,
+ column.getMaxRepetitionLevel,
+ )
+ }
}
}
- }.toDF(
- "filename",
- "columnName",
- "columnPath",
- "repetition",
- "type",
- "length",
- "originalType",
- "logicalType",
- "isPrimitive",
- "primitiveType",
- "primitiveOrder",
- "maxDefinitionLevel",
- "maxRepetitionLevel",
- )
+ .toDF(
+ "filename",
+ "columnName",
+ "columnPath",
+ "repetition",
+ "type",
+ "length",
+ "originalType",
+ "logicalType",
+ "isPrimitive",
+ "primitiveType",
+ "primitiveOrder",
+ "maxDefinitionLevel",
+ "maxRepetitionLevel",
+ )
}
/**
* Read the metadata of Parquet blocks into a Dataframe.
*
- * The returned DataFrame has as many partitions as there are Parquet files,
- * at most `spark.sparkContext.defaultParallelism` partitions.
+ * The returned DataFrame has as many partitions as there are Parquet files, at most
+ * `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-block information:
- * - filename (string): The file name
- * - block (int): Block / RowGroup number starting at 1
- * - blockStart (long): Start position of the block in the Parquet file
- * - compressedBytes (long): Number of compressed bytes in block
- * - uncompressedBytes (long): Number of uncompressed bytes in block
- * - rows (long): Number of rows in block
- * - columns (int): Number of columns in block
- * - values (long): Number of values in block
- * - nulls (long): Number of null values in block
+ * - filename (string): The file name
+ * - block (int): Block / RowGroup number starting at 1
+ * - blockStart (long): Start position of the block in the Parquet file
+ * - compressedBytes (long): Number of compressed bytes in block
+ * - uncompressedBytes (long): Number of uncompressed bytes in block
+ * - rows (long): Number of rows in block
+ * - columns (int): Number of columns in block
+ * - values (long): Number of values in block
+ * - nulls (long): Number of null values in block
*
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet block metadata
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlocks(paths: String*): DataFrame = parquetBlocks(None, paths)
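For example (placeholder path), the per-block view makes skewed RowGroups easy to spot:

    import org.apache.spark.sql.functions.col
    import uk.co.gresearch.spark.parquet._

    val blocks = spark.read.parquetBlocks("/tmp/dataset.parquet")
    blocks
      .orderBy(col("compressedBytes").desc)
      .select("filename", "block", "rows", "compressedBytes")
      .show()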
@@ -253,19 +273,22 @@ package object parquet {
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-block information:
- * - filename (string): The file name
- * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)
- * - blockStart (long): Start position of the block in the Parquet file
- * - compressedBytes (long): Number of compressed bytes in block
- * - uncompressedBytes (long): Number of uncompressed bytes in block
- * - rows (long): Number of rows in block
- * - columns (int): Number of columns in block
- * - values (long): Number of values in block
- * - nulls (long): Number of null values in block
+ * - filename (string): The file name
+ * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)
+ * - blockStart (long): Start position of the block in the Parquet file
+ * - compressedBytes (long): Number of compressed bytes in block
+ * - uncompressedBytes (long): Number of uncompressed bytes in block
+ * - rows (long): Number of rows in block
+ * - columns (int): Number of columns in block
+ * - values (long): Number of values in block
+ * - nulls (long): Number of null values in block
*
- * @param parallelism number of partitions of returned DataFrame
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet block metadata
+ * @param parallelism
+ * number of partitions of returned DataFrame
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlocks(parallelism: Int, paths: String*): DataFrame = parquetBlocks(Some(parallelism), paths)
@@ -275,61 +298,65 @@ package object parquet {
import files.sparkSession.implicits._
- files.flatMap { case (_, file) =>
- readFooters(file).flatMap { footer =>
- footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.map { case (block, idx) =>
- (
- footer.getFile.toString,
- BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,
- block.getStartingPos,
- block.getCompressedSize,
- block.getTotalByteSize,
- block.getRowCount,
- block.getColumns.asScala.size,
- block.getColumns.asScala.map(_.getValueCount).sum,
- // when all columns have statistics, count the null values
- Option(block.getColumns.asScala.map(c => Option(c.getStatistics)))
- .filter(_.forall(_.isDefined))
- .map(_.map(_.get.getNumNulls).sum),
- )
+ files
+ .flatMap { case (_, file) =>
+ readFooters(file).flatMap { footer =>
+ footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.map { case (block, idx) =>
+ (
+ footer.getFile.toString,
+ BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,
+ block.getStartingPos,
+ block.getCompressedSize,
+ block.getTotalByteSize,
+ block.getRowCount,
+ block.getColumns.asScala.size,
+ block.getColumns.asScala.map(_.getValueCount).sum,
+ // when all columns have statistics, count the null values
+ Option(block.getColumns.asScala.map(c => Option(c.getStatistics)))
+ .filter(_.forall(_.isDefined))
+ .map(_.map(_.get.getNumNulls).sum),
+ )
+ }
}
}
- }.toDF(
- "filename",
- "block",
- "blockStart",
- "compressedBytes",
- "uncompressedBytes",
- "rows",
- "columns",
- "values",
- "nulls"
- )
+ .toDF(
+ "filename",
+ "block",
+ "blockStart",
+ "compressedBytes",
+ "uncompressedBytes",
+ "rows",
+ "columns",
+ "values",
+ "nulls"
+ )
}
/**
* Read the metadata of Parquet block columns into a Dataframe.
*
- * The returned DataFrame has as many partitions as there are Parquet files,
- * at most `spark.sparkContext.defaultParallelism` partitions.
+ * The returned DataFrame has as many partitions as there are Parquet files, at most
+ * `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-block-column information:
- * - filename (string): The file name
- * - block (int): Block / RowGroup number starting at 1
- * - column (string): Block / RowGroup column name
- * - codec (string): The coded used to compress the block column values
- * - type (string): The data type of the block column
- * - encodings (string): Encodings of the block column
- * - minValue (string): Minimum value of this column in this block
- * - maxValue (string): Maximum value of this column in this block
- * - columnStart (long): Start position of the block column in the Parquet file
- * - compressedBytes (long): Number of compressed bytes of this block column
- * - uncompressedBytes (long): Number of uncompressed bytes of this block column
- * - values (long): Number of values in this block column
- * - nulls (long): Number of null values in block
+ * - filename (string): The file name
+ * - block (int): Block / RowGroup number starting at 1
+ * - column (string): Block / RowGroup column name
+ * - codec (string): The codec used to compress the block column values
+ * - type (string): The data type of the block column
+ * - encodings (string): Encodings of the block column
+ * - minValue (string): Minimum value of this column in this block
+ * - maxValue (string): Maximum value of this column in this block
+ * - columnStart (long): Start position of the block column in the Parquet file
+ * - compressedBytes (long): Number of compressed bytes of this block column
+ * - uncompressedBytes (long): Number of uncompressed bytes of this block column
+ * - values (long): Number of values in this block column
+ * - nulls (long): Number of null values in block
*
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet block metadata
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlockColumns(paths: String*): DataFrame = parquetBlockColumns(None, paths)
@@ -340,23 +367,26 @@ package object parquet {
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-block-column information:
- * - filename (string): The file name
- * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)
- * - column (string): Block / RowGroup column name
- * - codec (string): The coded used to compress the block column values
- * - type (string): The data type of the block column
- * - encodings (string): Encodings of the block column
- * - minValue (string): Minimum value of this column in this block
- * - maxValue (string): Maximum value of this column in this block
- * - columnStart (long): Start position of the block column in the Parquet file
- * - compressedBytes (long): Number of compressed bytes of this block column
- * - uncompressedBytes (long): Number of uncompressed bytes of this block column
- * - values (long): Number of values in this block column
- * - nulls (long): Number of null values in block
+ * - filename (string): The file name
+ * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1)
+ * - column (string): Block / RowGroup column name
+ * - codec (string): The codec used to compress the block column values
+ * - type (string): The data type of the block column
+ * - encodings (string): Encodings of the block column
+ * - minValue (string): Minimum value of this column in this block
+ * - maxValue (string): Maximum value of this column in this block
+ * - columnStart (long): Start position of the block column in the Parquet file
+ * - compressedBytes (long): Number of compressed bytes of this block column
+ * - uncompressedBytes (long): Number of uncompressed bytes of this block column
+ * - values (long): Number of values in this block column
+ * - nulls (long): Number of null values in block
*
- * @param parallelism number of partitions of returned DataFrame
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Parquet block metadata
+ * @param parallelism
+ * number of partitions of returned DataFrame
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Parquet block metadata
*/
@scala.annotation.varargs
def parquetBlockColumns(parallelism: Int, paths: String*): DataFrame = parquetBlockColumns(Some(parallelism), paths)
@@ -366,67 +396,71 @@ package object parquet {
import files.sparkSession.implicits._
- files.flatMap { case (_, file) =>
- readFooters(file).flatMap { footer =>
- footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.flatMap { case (block, idx) =>
- block.getColumns.asScala.map { column =>
- (
- footer.getFile.toString,
- BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,
- column.getPath.toSeq,
- column.getCodec.toString,
- column.getPrimitiveType.toString,
- column.getEncodings.asScala.toSeq.map(_.toString).sorted,
- Option(column.getStatistics).map(_.minAsString),
- Option(column.getStatistics).map(_.maxAsString),
- column.getStartingPos,
- column.getTotalSize,
- column.getTotalUncompressedSize,
- column.getValueCount,
- Option(column.getStatistics).map(_.getNumNulls),
- )
+ files
+ .flatMap { case (_, file) =>
+ readFooters(file).flatMap { footer =>
+ footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.flatMap { case (block, idx) =>
+ block.getColumns.asScala.map { column =>
+ (
+ footer.getFile.toString,
+ BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1,
+ column.getPath.toSeq,
+ column.getCodec.toString,
+ column.getPrimitiveType.toString,
+ column.getEncodings.asScala.toSeq.map(_.toString).sorted,
+ Option(column.getStatistics).map(_.minAsString),
+ Option(column.getStatistics).map(_.maxAsString),
+ column.getStartingPos,
+ column.getTotalSize,
+ column.getTotalUncompressedSize,
+ column.getValueCount,
+ Option(column.getStatistics).map(_.getNumNulls),
+ )
+ }
}
}
}
- }.toDF(
- "filename",
- "block",
- "column",
- "codec",
- "type",
- "encodings",
- "minValue",
- "maxValue",
- "columnStart",
- "compressedBytes",
- "uncompressedBytes",
- "values",
- "nulls"
- )
+ .toDF(
+ "filename",
+ "block",
+ "column",
+ "codec",
+ "type",
+ "encodings",
+ "minValue",
+ "maxValue",
+ "columnStart",
+ "compressedBytes",
+ "uncompressedBytes",
+ "values",
+ "nulls"
+ )
}
/**
* Read the metadata of how Spark partitions Parquet files into a Dataframe.
*
- * The returned DataFrame has as many partitions as there are Parquet files,
- * at most `spark.sparkContext.defaultParallelism` partitions.
+ * The returned DataFrame has as many partitions as there are Parquet files, at most
+ * `spark.sparkContext.defaultParallelism` partitions.
*
* This provides the following per-partition information:
- * - partition (int): The Spark partition id
- * - start (long): The start position of the partition
- * - end (long): The end position of the partition
- * - length (long): The length of the partition
- * - blocks (int): The number of Parquet blocks / RowGroups in this partition
- * - compressedBytes (long): The number of compressed bytes in this partition
- * - uncompressedBytes (long): The number of uncompressed bytes in this partition
- * - rows (long): The number of rows in this partition
- * - columns (int): Number of columns in the file
- * - values (long): The number of values in this partition
- * - filename (string): The Parquet file name
- * - fileLength (long): The length of the Parquet file
+ * - partition (int): The Spark partition id
+ * - start (long): The start position of the partition
+ * - end (long): The end position of the partition
+ * - length (long): The length of the partition
+ * - blocks (int): The number of Parquet blocks / RowGroups in this partition
+ * - compressedBytes (long): The number of compressed bytes in this partition
+ * - uncompressedBytes (long): The number of uncompressed bytes in this partition
+ * - rows (long): The number of rows in this partition
+ * - columns (int): Number of columns in the file
+ * - values (long): The number of values in this partition
+ * - filename (string): The Parquet file name
+ * - fileLength (long): The length of the Parquet file
*
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Spark Parquet partition metadata
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Spark Parquet partition metadata
*/
@scala.annotation.varargs
def parquetPartitions(paths: String*): DataFrame = parquetPartitions(None, paths)
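A sketch (placeholder path) showing how Spark maps the Parquet files onto its partitions; the `parallelism` overload below works the same way with an explicit partition count:

    import uk.co.gresearch.spark.parquet._

    val partitions = spark.read.parquetPartitions("/tmp/dataset.parquet")
    partitions.select("partition", "filename", "start", "end", "rows").show()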
@@ -437,22 +471,25 @@ package object parquet {
* The returned DataFrame has as many partitions as specified via `parallelism`.
*
* This provides the following per-partition information:
- * - partition (int): The Spark partition id
- * - start (long): The start position of the partition
- * - end (long): The end position of the partition
- * - length (long): The length of the partition
- * - blocks (int): The number of Parquet blocks / RowGroups in this partition
- * - compressedBytes (long): The number of compressed bytes in this partition
- * - uncompressedBytes (long): The number of uncompressed bytes in this partition
- * - rows (long): The number of rows in this partition
- * - columns (int): Number of columns in the file
- * - values (long): The number of values in this partition
- * - filename (string): The Parquet file name
- * - fileLength (long): The length of the Parquet file
+ * - partition (int): The Spark partition id
+ * - start (long): The start position of the partition
+ * - end (long): The end position of the partition
+ * - length (long): The length of the partition
+ * - blocks (int): The number of Parquet blocks / RowGroups in this partition
+ * - compressedBytes (long): The number of compressed bytes in this partition
+ * - uncompressedBytes (long): The number of uncompressed bytes in this partition
+ * - rows (long): The number of rows in this partition
+ * - columns (int): Number of columns in the file
+ * - values (long): The number of values in this partition
+ * - filename (string): The Parquet file name
+ * - fileLength (long): The length of the Parquet file
*
- * @param parallelism number of partitions of returned DataFrame
- * @param paths one or more paths to Parquet files or directories
- * @return dataframe with Spark Parquet partition metadata
+ * @param parallelism
+ * number of partitions of returned DataFrame
+ * @param paths
+ * one or more paths to Parquet files or directories
+ * @return
+ * dataframe with Spark Parquet partition metadata
*/
@scala.annotation.varargs
def parquetPartitions(parallelism: Int, paths: String*): DataFrame = parquetPartitions(Some(parallelism), paths)
@@ -462,56 +499,69 @@ package object parquet {
import files.sparkSession.implicits._
- files.flatMap { case (part, file) =>
- readFooters(file)
- .map(footer => (footer, getBlocks(footer, file.start, file.length)))
- .map { case (footer, blocks) => (
- part,
- file.start,
- file.start + file.length,
- file.length,
- blocks.size,
- blocks.map(_.getCompressedSize).sum,
- blocks.map(_.getTotalByteSize).sum,
- blocks.map(_.getRowCount).sum,
- blocks.map(_.getColumns.map(_.getPath.mkString(".")).toSet).foldLeft(Set.empty[String])((left, right) => left.union(right)).size,
- blocks.map(_.getColumns.asScala.map(_.getValueCount).sum).sum,
- // when all columns have statistics, count the null values
- Option(blocks.flatMap(_.getColumns.asScala.map(c => Option(c.getStatistics))))
- .filter(_.forall(_.isDefined))
- .map(_.map(_.get.getNumNulls).sum),
- footer.getFile.toString,
- file.fileSize,
- )}
- }.toDF(
- "partition",
- "start",
- "end",
- "length",
- "blocks",
- "compressedBytes",
- "uncompressedBytes",
- "rows",
- "columns",
- "values",
- "nulls",
- "filename",
- "fileLength"
- )
+ files
+ .flatMap { case (part, file) =>
+ readFooters(file)
+ .map(footer => (footer, getBlocks(footer, file.start, file.length)))
+ .map { case (footer, blocks) =>
+ (
+ part,
+ file.start,
+ file.start + file.length,
+ file.length,
+ blocks.size,
+ blocks.map(_.getCompressedSize).sum,
+ blocks.map(_.getTotalByteSize).sum,
+ blocks.map(_.getRowCount).sum,
+ blocks
+ .map(_.getColumns.map(_.getPath.mkString(".")).toSet)
+ .foldLeft(Set.empty[String])((left, right) => left.union(right))
+ .size,
+ blocks.map(_.getColumns.asScala.map(_.getValueCount).sum).sum,
+ // when all columns have statistics, count the null values
+ Option(blocks.flatMap(_.getColumns.asScala.map(c => Option(c.getStatistics))))
+ .filter(_.forall(_.isDefined))
+ .map(_.map(_.get.getNumNulls).sum),
+ footer.getFile.toString,
+ file.fileSize,
+ )
+ }
+ }
+ .toDF(
+ "partition",
+ "start",
+ "end",
+ "length",
+ "blocks",
+ "compressedBytes",
+ "uncompressedBytes",
+ "rows",
+ "columns",
+ "values",
+ "nulls",
+ "filename",
+ "fileLength"
+ )
}
private def getFiles(parallelism: Option[Int], paths: Seq[String]): Dataset[(Int, SplitFile)] = {
val df = reader.parquet(paths: _*)
- val parts = df.rdd.partitions.flatMap(part =>
- part.asInstanceOf[FilePartition]
- .files
- .map(file => (part.index, SplitFile(file)))
- ).toSeq.distinct
+ val parts = df.rdd.partitions
+ .flatMap(part =>
+ part
+ .asInstanceOf[FilePartition]
+ .files
+ .map(file => (part.index, SplitFile(file)))
+ )
+ .toSeq
+ .distinct
import df.sparkSession.implicits._
- parts.toDS()
- .when(parallelism.isDefined).call(_.repartition(parallelism.get))
+ parts
+ .toDS()
+ .when(parallelism.isDefined)
+ .call(_.repartition(parallelism.get))
}
}
diff --git a/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala b/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala
index 5a7dca51..d473965a 100644
--- a/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala
@@ -92,7 +92,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession {
testGroupByIdSortBySeq(ds.groupByKeySorted(v => v.id)(v => (v.seq, v.value)))
testGroupByIdSortBySeqDesc(ds.groupByKeySorted(v => v.id)(v => (v.seq, v.value), reverse = true))
testGroupByIdSortBySeqWithPartitionNum(ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value)))
- testGroupByIdSortBySeqDescWithPartitionNum(ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value), reverse = true))
+ testGroupByIdSortBySeqDescWithPartitionNum(
+ ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value), reverse = true)
+ )
testGroupByIdSeqSortByValue(ds.groupByKeySorted(v => (v.id, v.seq))(v => v.value))
}
@@ -106,14 +108,19 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession {
describe("df.groupByKeySorted") {
testGroupByIdSortBySeq(df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2))))
- testGroupByIdSortBySeqDesc(df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2)), reverse = true))
- testGroupByIdSortBySeqWithPartitionNum(df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2))))
- testGroupByIdSortBySeqDescWithPartitionNum(df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)), reverse = true))
+ testGroupByIdSortBySeqDesc(
+ df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2)), reverse = true)
+ )
+ testGroupByIdSortBySeqWithPartitionNum(
+ df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)))
+ )
+ testGroupByIdSortBySeqDescWithPartitionNum(
+ df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)), reverse = true)
+ )
testGroupByIdSeqSortByValue(df.groupByKeySorted(v => (v.getInt(0), v.getInt(1)))(v => v.getDouble(2)))
}
- def testGroupByIdSortBySeq[T](ds: SortedGroupByDataset[Int, T])
- (implicit asTuple: T => (Int, Int, Double)): Unit = {
+ def testGroupByIdSortBySeq[T](ds: SortedGroupByDataset[Int, T])(implicit asTuple: T => (Int, Int, Double)): Unit = {
it("should flatMapSortedGroups") {
val actual = ds
@@ -167,8 +174,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession {
}
- def testGroupByIdSortBySeqDesc[T](ds: SortedGroupByDataset[Int, T])
- (implicit asTuple: T => (Int, Int, Double)): Unit = {
+ def testGroupByIdSortBySeqDesc[T](
+ ds: SortedGroupByDataset[Int, T]
+ )(implicit asTuple: T => (Int, Int, Double)): Unit = {
it("should flatMapSortedGroups reverse") {
val actual = ds
.flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1))))
@@ -196,8 +204,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession {
}
- def testGroupByIdSortBySeqWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)
- (implicit asTuple: T => (Int, Int, Double)): Unit = {
+ def testGroupByIdSortBySeqWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)(implicit
+ asTuple: T => (Int, Int, Double)
+ ): Unit = {
it("should flatMapSortedGroups with partition num") {
val grouped = ds
@@ -230,8 +239,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession {
}
- def testGroupByIdSortBySeqDescWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)
- (implicit asTuple: T => (Int, Int, Double)): Unit = {
+ def testGroupByIdSortBySeqDescWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)(implicit
+ asTuple: T => (Int, Int, Double)
+ ): Unit = {
it("should flatMapSortedGroups with partition num and reverse") {
val grouped = ds
.flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1))))
@@ -262,8 +272,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession {
}
}
- def testGroupByIdSeqSortByValue[T](ds: SortedGroupByDataset[(Int, Int), T])
- (implicit asTuple: T => (Int, Int, Double)): Unit = {
+ def testGroupByIdSeqSortByValue[T](
+ ds: SortedGroupByDataset[(Int, Int), T]
+ )(implicit asTuple: T => (Int, Int, Double)): Unit = {
it("should flatMapSortedGroups with tuple key") {
val actual = ds
@@ -323,7 +334,6 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession {
}
-
object GroupBySortedSuite {
implicit def valueToTuple(value: Val): (Int, Int, Double) = (value.id, value.seq, value.value)
implicit def valueRowToTuple(value: Row): (Int, Int, Double) = (value.getInt(0), value.getInt(1), value.getDouble(2))
diff --git a/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala b/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala
index 0cd1d4ac..04084dcb 100644
--- a/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala
@@ -63,34 +63,40 @@ class HistogramSuite extends AnyFunSuite with SparkTestSession {
Seq(3, 0, 1, 0, 1, 0, 1),
Seq(4, 0, 0, 1, 0, 0, 0)
)
- val expectedSchema: StructType = StructType(Seq(
- StructField("id", IntegerType, nullable = false),
- StructField("≤-200", LongType, nullable = true),
- StructField("≤-100", LongType, nullable = true),
- StructField("≤0", LongType, nullable = true),
- StructField("≤100", LongType, nullable = true),
- StructField("≤200", LongType, nullable = true),
- StructField(">200", LongType, nullable = true)
- ))
- val expectedSchema2: StructType = StructType(Seq(
- StructField("id", IntegerType, nullable = false),
- StructField("title", StringType, nullable = true),
- StructField("≤-200", LongType, nullable = true),
- StructField("≤-100", LongType, nullable = true),
- StructField("≤0", LongType, nullable = true),
- StructField("≤100", LongType, nullable = true),
- StructField("≤200", LongType, nullable = true),
- StructField(">200", LongType, nullable = true)
- ))
- val expectedDoubleSchema: StructType = StructType(Seq(
- StructField("id", IntegerType, nullable = false),
- StructField("≤-200.0", LongType, nullable = true),
- StructField("≤-100.0", LongType, nullable = true),
- StructField("≤0.0", LongType, nullable = true),
- StructField("≤100.0", LongType, nullable = true),
- StructField("≤200.0", LongType, nullable = true),
- StructField(">200.0", LongType, nullable = true)
- ))
+ val expectedSchema: StructType = StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("≤-200", LongType, nullable = true),
+ StructField("≤-100", LongType, nullable = true),
+ StructField("≤0", LongType, nullable = true),
+ StructField("≤100", LongType, nullable = true),
+ StructField("≤200", LongType, nullable = true),
+ StructField(">200", LongType, nullable = true)
+ )
+ )
+ val expectedSchema2: StructType = StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("title", StringType, nullable = true),
+ StructField("≤-200", LongType, nullable = true),
+ StructField("≤-100", LongType, nullable = true),
+ StructField("≤0", LongType, nullable = true),
+ StructField("≤100", LongType, nullable = true),
+ StructField("≤200", LongType, nullable = true),
+ StructField(">200", LongType, nullable = true)
+ )
+ )
+ val expectedDoubleSchema: StructType = StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("≤-200.0", LongType, nullable = true),
+ StructField("≤-100.0", LongType, nullable = true),
+ StructField("≤0.0", LongType, nullable = true),
+ StructField("≤100.0", LongType, nullable = true),
+ StructField("≤200.0", LongType, nullable = true),
+ StructField(">200.0", LongType, nullable = true)
+ )
+ )
test("histogram with no aggregate columns") {
val histogram = ints.histogram(intThresholds, $"value")
@@ -110,12 +116,14 @@ class HistogramSuite extends AnyFunSuite with SparkTestSession {
val histogram = ints.histogram(intThresholds, $"value", $"id", $"title")
val actual = histogram.orderBy($"id").collect().toSeq.map(_.toSeq)
assert(histogram.schema === expectedSchema2)
- assert(actual === Seq(
- Seq(1, "one", 0, 0, 0, 3, 0, 0),
- Seq(2, "two", 0, 0, 0, 4, 0, 0),
- Seq(3, "three", 0, 1, 0, 1, 0, 1),
- Seq(4, "four", 0, 0, 1, 0, 0, 0)
- ))
+ assert(
+ actual === Seq(
+ Seq(1, "one", 0, 0, 0, 3, 0, 0),
+ Seq(2, "two", 0, 0, 0, 4, 0, 0),
+ Seq(3, "three", 0, 1, 0, 1, 0, 1),
+ Seq(4, "four", 0, 0, 1, 0, 0, 0)
+ )
+ )
}
test("histogram with int values") {
@@ -156,18 +164,23 @@ class HistogramSuite extends AnyFunSuite with SparkTestSession {
test("histogram with one threshold") {
val histogram = ints.histogram(Seq(0), $"value", $"id")
val actual = histogram.orderBy($"id").collect().toSeq.map(_.toSeq)
- assert(histogram.schema === StructType(Seq(
- StructField("id", IntegerType, nullable = false),
- StructField("≤0", LongType, nullable = true),
- StructField(">0", LongType, nullable = true)
- ))
+ assert(
+ histogram.schema === StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("≤0", LongType, nullable = true),
+ StructField(">0", LongType, nullable = true)
+ )
+ )
+ )
+ assert(
+ actual === Seq(
+ Seq(1, 0, 3),
+ Seq(2, 0, 4),
+ Seq(3, 1, 2),
+ Seq(4, 1, 0)
+ )
)
- assert(actual === Seq(
- Seq(1, 0, 3),
- Seq(2, 0, 4),
- Seq(3, 1, 2),
- Seq(4, 1, 0)
- ))
}
test("histogram with duplicate thresholds") {
diff --git a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
index fbb12e40..78c90b48 100644
--- a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
@@ -39,7 +39,11 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
val emptyDataFrame: DataFrame = spark.createDataFrame(Seq.empty[Value])
test("Get Spark version") {
- assert(VersionString.contains(s"-$BuildSparkCompatVersionString-") || VersionString.endsWith(s"-$BuildSparkCompatVersionString"))
+ assert(
+ VersionString.contains(s"-$BuildSparkCompatVersionString-") || VersionString.endsWith(
+ s"-$BuildSparkCompatVersionString"
+ )
+ )
assert(spark.version.startsWith(s"$BuildSparkCompatVersionString."))
assert(SparkVersion === BuildSparkVersion)
@@ -121,7 +125,9 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
test("UnpersistHandle throws on unpersist with blocking if no DataFrame is set") {
val unpersist = UnpersistHandle()
- assert(intercept[IllegalStateException] { unpersist(blocking = true) }.getMessage === s"DataFrame has to be set first")
+ assert(intercept[IllegalStateException] {
+ unpersist(blocking = true)
+ }.getMessage === s"DataFrame has to be set first")
}
test("SilentUnpersistHandle does not throw on unpersist if no DataFrame is set") {
@@ -145,7 +151,14 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
test("count_null") {
val df = Seq(
- (1, "some"), (2, "text"), (3, "and"), (4, "some"), (5, "null"), (6, "values"), (7, null), (8, null)
+ (1, "some"),
+ (2, "text"),
+ (3, "and"),
+ (4, "some"),
+ (5, "null"),
+ (6, "values"),
+ (7, null),
+ (8, null)
).toDF("id", "str")
val actual =
df.select(
@@ -153,7 +166,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
count($"str").as("strs"),
count_null($"id").as("null ids"),
count_null($"str").as("null strs")
- ).collect().head
+ ).collect()
+ .head
assert(actual === Row(8, 6, 0, 2))
}
@@ -260,7 +274,6 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
assertIsDataset[Value](spark.createDataFrame(Seq.empty[Value]).call(_.as[Value]))
}
-
Seq(true, false).foreach { condition =>
test(s"call on $condition condition dataset-to-dataset transformation") {
assertIsGenericType[Dataset[Value]](
@@ -298,10 +311,10 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
)
}
-
test(s"call on $condition condition either dataset-to-dataset transformation") {
assertIsGenericType[Dataset[Value]](
- spark.emptyDataset[Value]
+ spark
+ .emptyDataset[Value]
.transform(
_.on(condition)
.either(_.sort())
@@ -312,7 +325,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
test(s"call on $condition condition either dataset-to-dataframe transformation") {
assertIsGenericType[DataFrame](
- spark.emptyDataset[Value]
+ spark
+ .emptyDataset[Value]
.transform(
_.on(condition)
.either(_.drop("string"))
@@ -323,7 +337,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
test(s"call on $condition condition either dataframe-to-dataset transformation") {
assertIsGenericType[Dataset[Value]](
- spark.createDataFrame(Seq.empty[Value])
+ spark
+ .createDataFrame(Seq.empty[Value])
.transform(
_.on(condition)
.either(_.as[Value])
@@ -334,7 +349,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
test(s"call on $condition condition either dataframe-to-dataframe transformation") {
assertIsGenericType[DataFrame](
- spark.createDataFrame(Seq.empty[Value])
+ spark
+ .createDataFrame(Seq.empty[Value])
.transform(
_.on(condition)
.either(_.drop("string"))
@@ -344,7 +360,6 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
}
}
-
test("on true condition call either writer-to-writer methods") {
assertIsGenericType[DataFrameWriter[Value]](
spark
@@ -396,7 +411,7 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
}
test("global row number preserves order") {
- doTestWithRowNumbers()(){ df =>
+ doTestWithRowNumbers()() { df =>
assert(df.columns === Seq("id", "rand", "row_number"))
}
}
@@ -415,7 +430,9 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, OFF_HEAP, NONE).foreach { level =>
test(s"global row number with $level") {
- if (level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)) {
+ if (
+ level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)
+ ) {
assertThrows[IllegalArgumentException] {
doTestWithRowNumbers(storageLevel = level)($"id")()
}
@@ -427,7 +444,9 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, OFF_HEAP, NONE).foreach { level =>
test(s"global row number allows to unpersist with $level") {
- if (level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)) {
+ if (
+ level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)
+ ) {
assertThrows[IllegalArgumentException] {
doTestWithRowNumbers(storageLevel = level)($"id")()
}
@@ -447,59 +466,60 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
test("global row number with existing row_number column") {
// this overwrites the existing column 'row_number' (formerly 'rand') with the row numbers
- doTestWithRowNumbers { df => df.withColumnRenamed("rand", "row_number") }(){ df =>
+ doTestWithRowNumbers { df => df.withColumnRenamed("rand", "row_number") }() { df =>
assert(df.columns === Seq("id", "row_number"))
}
}
test("global row number with custom row_number column") {
// this puts the row numbers in the column "row", which is not the default column name
- doTestWithRowNumbers(df => df.withColumnRenamed("rand", "row_number"),
- rowNumberColumnName = "row" )(){ df =>
+ doTestWithRowNumbers(df => df.withColumnRenamed("rand", "row_number"), rowNumberColumnName = "row")() { df =>
assert(df.columns === Seq("id", "row_number", "row"))
}
}
test("global row number with internal column names") {
- val cols = Seq("mono_id", "partition_id", "local_row_number", "max_local_row_number",
- "cum_row_numbers", "partition_offset")
+ val cols =
+ Seq("mono_id", "partition_id", "local_row_number", "max_local_row_number", "cum_row_numbers", "partition_offset")
var prefix: String = null
doTestWithRowNumbers { df =>
prefix = distinctPrefixFor(df.columns)
- cols.foldLeft(df){ (df, name) => df.withColumn(prefix + name, rand()) }
- }(){ df =>
+ cols.foldLeft(df) { (df, name) => df.withColumn(prefix + name, rand()) }
+ }() { df =>
assert(df.columns === Seq("id", "rand") ++ cols.map(prefix + _) :+ "row_number")
}
}
- def doTestWithRowNumbers(transform: DataFrame => DataFrame = identity,
- rowNumberColumnName: String = "row_number",
- storageLevel: StorageLevel = MEMORY_AND_DISK,
- unpersistHandle: UnpersistHandle = UnpersistHandle.Noop)
- (columns: Column*)
- (handle: DataFrame => Unit = identity[DataFrame]): Unit = {
+ def doTestWithRowNumbers(
+ transform: DataFrame => DataFrame = identity,
+ rowNumberColumnName: String = "row_number",
+ storageLevel: StorageLevel = MEMORY_AND_DISK,
+ unpersistHandle: UnpersistHandle = UnpersistHandle.Noop
+ )(columns: Column*)(handle: DataFrame => Unit = identity[DataFrame]): Unit = {
val partitions = 10
val rowsPerPartition = 1000
val rows = partitions * rowsPerPartition
assert(partitions > 1)
assert(rowsPerPartition > 1)
- val df = spark.range(1, rows + 1, 1, partitions)
+ val df = spark
+ .range(1, rows + 1, 1, partitions)
.withColumn("rand", rand())
.transform(transform)
.withRowNumbers(
- rowNumberColumnName=rowNumberColumnName,
- storageLevel=storageLevel,
- unpersistHandle=unpersistHandle,
- columns: _*)
+ rowNumberColumnName = rowNumberColumnName,
+ storageLevel = storageLevel,
+ unpersistHandle = unpersistHandle,
+ columns: _*
+ )
.cache()
try {
// testing with descending order is only supported for a single column
val desc = columns.map(_.expr) match {
case Seq(SortOrder(_, Descending, _, _)) => true
- case _ => false
+ case _ => false
}
// assert row numbers are correct
@@ -519,7 +539,7 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
}
val correctRowNumbers = df.where(expect).count()
- val incorrectRowNumbers = df.where(! expect).count()
+ val incorrectRowNumbers = df.where(!expect).count()
assert(correctRowNumbers === rows)
assert(incorrectRowNumbers === 0)
}
@@ -536,50 +556,68 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
(7, 3155378975999999999L)
).toDF("id", "ts")
- val plan = df.select(
- $"id",
- dotNetTicksToTimestamp($"ts"),
- dotNetTicksToTimestamp("ts"),
- dotNetTicksToUnixEpoch($"ts"),
- dotNetTicksToUnixEpoch("ts"),
- dotNetTicksToUnixEpochNanos($"ts"),
- dotNetTicksToUnixEpochNanos("ts")
- ).orderBy($"id")
- assert(plan.schema.fields.map(_.dataType) === Seq(
- IntegerType, TimestampType, TimestampType, DecimalType(29, 9), DecimalType(29, 9), LongType, LongType
- ))
+ val plan = df
+ .select(
+ $"id",
+ dotNetTicksToTimestamp($"ts"),
+ dotNetTicksToTimestamp("ts"),
+ dotNetTicksToUnixEpoch($"ts"),
+ dotNetTicksToUnixEpoch("ts"),
+ dotNetTicksToUnixEpochNanos($"ts"),
+ dotNetTicksToUnixEpochNanos("ts")
+ )
+ .orderBy($"id")
+ assert(
+ plan.schema.fields.map(_.dataType) === Seq(
+ IntegerType,
+ TimestampType,
+ TimestampType,
+ DecimalType(29, 9),
+ DecimalType(29, 9),
+ LongType,
+ LongType
+ )
+ )
val actual = plan.collect()
- assert(actual.map(_.getTimestamp(1)) === Seq(
- Timestamp.from(Instant.parse("1900-01-01T00:00:00Z")),
- Timestamp.from(Instant.parse("1970-01-01T00:00:00Z")),
- Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")),
- Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")),
- Timestamp.from(Instant.parse("2023-03-27T19:16:14.895931Z")),
- // largest possible unix epoch nanos
- Timestamp.from(Instant.parse("2262-04-11T23:47:16.854775Z")),
- Timestamp.from(Instant.parse("9999-12-31T23:59:59.999999Z")),
- ))
+ assert(
+ actual.map(_.getTimestamp(1)) === Seq(
+ Timestamp.from(Instant.parse("1900-01-01T00:00:00Z")),
+ Timestamp.from(Instant.parse("1970-01-01T00:00:00Z")),
+ Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")),
+ Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")),
+ Timestamp.from(Instant.parse("2023-03-27T19:16:14.895931Z")),
+ // largest possible unix epoch nanos
+ Timestamp.from(Instant.parse("2262-04-11T23:47:16.854775Z")),
+ Timestamp.from(Instant.parse("9999-12-31T23:59:59.999999Z")),
+ )
+ )
assert(actual.map(_.getTimestamp(2)) === actual.map(_.getTimestamp(1)))
- assert(actual.map(_.getDecimal(3)).map(BigDecimal(_)) === Array(
- BigDecimal(-2208988800000000000L, 9),
- BigDecimal(0, 9),
- BigDecimal(1679944574895930800L, 9),
- BigDecimal(1679944574895930900L, 9),
- BigDecimal(1679944574895931000L, 9),
- // largest possible unix epoch nanos
- BigDecimal(9223372036854775800L, 9),
- BigDecimal(2534023007999999999L, 7).setScale(9),
- ))
+ assert(
+ actual.map(_.getDecimal(3)).map(BigDecimal(_)) === Array(
+ BigDecimal(-2208988800000000000L, 9),
+ BigDecimal(0, 9),
+ BigDecimal(1679944574895930800L, 9),
+ BigDecimal(1679944574895930900L, 9),
+ BigDecimal(1679944574895931000L, 9),
+ // largest possible unix epoch nanos
+ BigDecimal(9223372036854775800L, 9),
+ BigDecimal(2534023007999999999L, 7).setScale(9),
+ )
+ )
assert(actual.map(_.getDecimal(4)) === actual.map(_.getDecimal(3)))
- assert(actual.map(row =>
- if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) row.getLong(5) else null
- ) === actual.map(row =>
- if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) row.getDecimal(3).multiply(new java.math.BigDecimal(1000000000)).longValue() else null
- ))
+ assert(
+ actual.map(row =>
+ if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) row.getLong(5) else null
+ ) === actual.map(row =>
+ if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9))
+ row.getDecimal(3).multiply(new java.math.BigDecimal(1000000000)).longValue()
+ else null
+ )
+ )
assert(actual.map(_.get(6)) === actual.map(_.get(5)))
}
@@ -596,24 +634,32 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
df.select(timestampToDotNetTicks($"ts"))
}
} else {
- val plan = df.select(
- $"id",
- timestampToDotNetTicks($"ts"),
- timestampToDotNetTicks("ts"),
- ).orderBy($"id")
+ val plan = df
+ .select(
+ $"id",
+ timestampToDotNetTicks($"ts"),
+ timestampToDotNetTicks("ts"),
+ )
+ .orderBy($"id")
- assert(plan.schema.fields.map(_.dataType) === Seq(
- IntegerType, LongType, LongType
- ))
+ assert(
+ plan.schema.fields.map(_.dataType) === Seq(
+ IntegerType,
+ LongType,
+ LongType
+ )
+ )
val actual = plan.collect()
- assert(actual.map(_.getLong(1)) === Seq(
- 599266080000000000L,
- 621355968000000000L,
- 638155413748959310L,
- 3155378975999999990L
- ))
+ assert(
+ actual.map(_.getLong(1)) === Seq(
+ 599266080000000000L,
+ 621355968000000000L,
+ 638155413748959310L,
+ 3155378975999999990L
+ )
+ )
assert(actual.map(_.getLong(2)) === actual.map(_.getLong(1)))
val message = intercept[AnalysisException] {
@@ -621,17 +667,29 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
}.getMessage
if (SparkMajorVersion == 3 && SparkMinorVersion == 1) {
- assert(message.startsWith("cannot resolve 'unix_micros(`ts`)' due to data type mismatch: argument 1 requires timestamp type, however, '`ts`' is of bigint type.;"))
+ assert(
+ message.startsWith(
+ "cannot resolve 'unix_micros(`ts`)' due to data type mismatch: argument 1 requires timestamp type, however, '`ts`' is of bigint type.;"
+ )
+ )
} else if (SparkMajorVersion == 3 && SparkMinorVersion < 4) {
- assert(message.startsWith("cannot resolve 'unix_micros(ts)' due to data type mismatch: argument 1 requires timestamp type, however, 'ts' is of bigint type.;"))
+ assert(
+ message.startsWith(
+ "cannot resolve 'unix_micros(ts)' due to data type mismatch: argument 1 requires timestamp type, however, 'ts' is of bigint type.;"
+ )
+ )
} else if (SparkMajorVersion == 3 && SparkMinorVersion >= 4 || SparkMajorVersion > 3) {
- assert(message.startsWith("[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve \"unix_micros(ts)\" due to data type mismatch: Parameter 1 requires the \"TIMESTAMP\" type, however \"ts\" has the type \"BIGINT\"."))
+ assert(
+ message.startsWith(
+ "[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve \"unix_micros(ts)\" due to data type mismatch: Parameter 1 requires the \"TIMESTAMP\" type, however \"ts\" has the type \"BIGINT\"."
+ )
+ )
}
}
}
test("Unix epoch to .Net ticks") {
- def df[T : Encoder](v: T): DataFrame =
+ def df[T: Encoder](v: T): DataFrame =
spark.createDataset(Seq(v)).withColumnRenamed("value", "ts")
Seq(
@@ -667,7 +725,7 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
}
test("Unix epoch nanos to .Net ticks") {
- def df[T : Encoder](v: T): DataFrame =
+ def df[T: Encoder](v: T): DataFrame =
spark.createDataset(Seq(v)).withColumnRenamed("value", "ts")
Seq(
@@ -713,7 +771,9 @@ object SparkSuite {
import spark.implicits._
spark
.range(0, 3, 1, 3)
- .mapPartitions(it => it.map(id => (id, TaskContext.get().partitionId(), TaskContext.get().getLocalProperty("spark.job.description"))))
+ .mapPartitions(it =>
+ it.map(id => (id, TaskContext.get().partitionId(), TaskContext.get().getLocalProperty("spark.job.description")))
+ )
.as[(Long, Long, String)]
.sort()
.collect()
diff --git a/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala b/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala
index 57f7fcf7..62db45eb 100644
--- a/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala
@@ -44,7 +44,6 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
Value(4, Date.valueOf("2020-07-01"), "four")
).toDS()
-
test("write partitionedBy requires caching with AQE enabled") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
Some(spark.version)
@@ -108,7 +107,9 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
test("write with one partition") {
withTempPath { dir =>
withUnpersist() { handle =>
- values.writePartitionedBy(Seq($"id"), Seq($"date"), partitions = Some(1), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath)
+ values
+ .writePartitionedBy(Seq($"id"), Seq($"date"), partitions = Some(1), unpersistHandle = Some(handle))
+ .csv(dir.getAbsolutePath)
}
val partitions = dir.list().filter(_.startsWith("id=")).sorted
@@ -123,7 +124,9 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
test("write with partition order") {
withTempPath { dir =>
withUnpersist() { handle =>
- values.writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date"), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath)
+ values
+ .writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date"), unpersistHandle = Some(handle))
+ .csv(dir.getAbsolutePath)
}
val partitions = dir.list().filter(_.startsWith("id=")).sorted
@@ -134,26 +137,40 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
assert(files.length === 1)
val source = Source.fromFile(new File(file, files(0)))
- val lines = try source.getLines().toList finally source.close()
+ val lines =
+ try source.getLines().toList
+ finally source.close()
partition match {
- case "id=1" => assert(lines === Seq(
- "2020-07-01,one",
- "2020-07-02,One",
- "2020-07-03,ONE",
- "2020-07-04,one"
- ))
- case "id=2" => assert(lines === Seq(
- "2020-07-01,two",
- "2020-07-02,Two",
- "2020-07-03,TWO",
- "2020-07-04,two"
- ))
- case "id=3" => assert(lines === Seq(
- "2020-07-01,three"
- ))
- case "id=4" => assert(lines === Seq(
- "2020-07-01,four"
- ))
+ case "id=1" =>
+ assert(
+ lines === Seq(
+ "2020-07-01,one",
+ "2020-07-02,One",
+ "2020-07-03,ONE",
+ "2020-07-04,one"
+ )
+ )
+ case "id=2" =>
+ assert(
+ lines === Seq(
+ "2020-07-01,two",
+ "2020-07-02,Two",
+ "2020-07-03,TWO",
+ "2020-07-04,two"
+ )
+ )
+ case "id=3" =>
+ assert(
+ lines === Seq(
+ "2020-07-01,three"
+ )
+ )
+ case "id=4" =>
+ assert(
+ lines === Seq(
+ "2020-07-01,four"
+ )
+ )
}
}
}
@@ -162,7 +179,9 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
test("write with desc partition order") {
withTempPath { dir =>
withUnpersist() { handle =>
- values.writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date".desc), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath)
+ values
+ .writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date".desc), unpersistHandle = Some(handle))
+ .csv(dir.getAbsolutePath)
}
val partitions = dir.list().filter(_.startsWith("id=")).sorted
@@ -173,26 +192,40 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
assert(files.length === 1)
val source = Source.fromFile(new File(file, files(0)))
- val lines = try source.getLines().toList finally source.close()
+ val lines =
+ try source.getLines().toList
+ finally source.close()
partition match {
- case "id=1" => assert(lines === Seq(
- "2020-07-04,one",
- "2020-07-03,ONE",
- "2020-07-02,One",
- "2020-07-01,one"
- ))
- case "id=2" => assert(lines === Seq(
- "2020-07-04,two",
- "2020-07-03,TWO",
- "2020-07-02,Two",
- "2020-07-01,two"
- ))
- case "id=3" => assert(lines === Seq(
- "2020-07-01,three"
- ))
- case "id=4" => assert(lines === Seq(
- "2020-07-01,four"
- ))
+ case "id=1" =>
+ assert(
+ lines === Seq(
+ "2020-07-04,one",
+ "2020-07-03,ONE",
+ "2020-07-02,One",
+ "2020-07-01,one"
+ )
+ )
+ case "id=2" =>
+ assert(
+ lines === Seq(
+ "2020-07-04,two",
+ "2020-07-03,TWO",
+ "2020-07-02,Two",
+ "2020-07-01,two"
+ )
+ )
+ case "id=3" =>
+ assert(
+ lines === Seq(
+ "2020-07-01,three"
+ )
+ )
+ case "id=4" =>
+ assert(
+ lines === Seq(
+ "2020-07-01,four"
+ )
+ )
}
}
}
@@ -202,7 +235,15 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
val projection = Some(Seq(col("id"), reverse(col("value"))))
withTempPath { path =>
withUnpersist() { handle =>
- values.writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date"), writtenProjection = projection, unpersistHandle = Some(handle)).csv(path.getAbsolutePath)
+ values
+ .writePartitionedBy(
+ Seq($"id"),
+ Seq.empty,
+ Seq($"date"),
+ writtenProjection = projection,
+ unpersistHandle = Some(handle)
+ )
+ .csv(path.getAbsolutePath)
}
val partitions = path.list().filter(_.startsWith("id=")).sorted
@@ -214,7 +255,8 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession {
val lines = files.flatMap { file =>
val source = Source.fromFile(new File(dir, file))
- try source.getLines().toList finally source.close()
+ try source.getLines().toList
+ finally source.close()
}
partition match {
diff --git a/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala
index 855460b6..4fc9e48a 100644
--- a/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala
@@ -22,7 +22,6 @@ import uk.co.gresearch.spark.SparkTestSession
import java.io.File
-
class AppSuite extends AnyFunSuite with SparkTestSession {
import spark.implicits._
@@ -38,15 +37,21 @@ class AppSuite extends AnyFunSuite with SparkTestSession {
// launch app
val jsonPath = new File(path, "diff.json").getAbsolutePath
- App.main(Array(
- "--left-format", "csv",
- "--left-schema", "id int, value string",
- "--output-format", "json",
- "--id", "id",
- leftPath,
- "right_parquet",
- jsonPath
- ))
+ App.main(
+ Array(
+ "--left-format",
+ "csv",
+ "--left-schema",
+ "id int, value string",
+ "--output-format",
+ "json",
+ "--id",
+ "id",
+ leftPath,
+ "right_parquet",
+ jsonPath
+ )
+ )
// assert written diff
val actual = spark.read.json(jsonPath)
@@ -67,14 +72,18 @@ class AppSuite extends AnyFunSuite with SparkTestSession {
// launch app
val outputPath = new File(path, "diff.parquet").getAbsolutePath
- App.main(Array(
- "--format", "parquet",
- "--id", "id",
- ) ++ filter.toSeq.flatMap(f => Array("--filter", f)) ++ Array(
- leftPath,
- rightPath,
- outputPath
- ))
+ App.main(
+ Array(
+ "--format",
+ "parquet",
+ "--id",
+ "id",
+ ) ++ filter.toSeq.flatMap(f => Array("--filter", f)) ++ Array(
+ leftPath,
+ rightPath,
+ outputPath
+ )
+ )
// assert written diff
val actual = spark.read.parquet(outputPath).orderBy($"id").collect()
@@ -98,14 +107,19 @@ class AppSuite extends AnyFunSuite with SparkTestSession {
// launch app
val outputPath = new File(path, "diff.parquet").getAbsolutePath
assertThrows[RuntimeException](
- App.main(Array(
- "--format", "parquet",
- "--id", "id",
- "--filter", "A",
- leftPath,
- rightPath,
- outputPath
- ))
+ App.main(
+ Array(
+ "--format",
+ "parquet",
+ "--id",
+ "id",
+ "--filter",
+ "A",
+ leftPath,
+ rightPath,
+ outputPath
+ )
+ )
)
}
}
@@ -122,14 +136,18 @@ class AppSuite extends AnyFunSuite with SparkTestSession {
// launch app
val outputPath = new File(path, "diff.parquet").getAbsolutePath
- App.main(Array(
- "--format", "parquet",
- "--statistics",
- "--id", "id",
- leftPath,
- rightPath,
- outputPath
- ))
+ App.main(
+ Array(
+ "--format",
+ "parquet",
+ "--statistics",
+ "--id",
+ "id",
+ leftPath,
+ rightPath,
+ outputPath
+ )
+ )
// assert written diff
val actual = spark.read.parquet(outputPath).as[(String, Long)].collect().toMap
diff --git a/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala
index c2401bdd..f69ac4ef 100644
--- a/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala
@@ -24,13 +24,25 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.scalatest.funsuite.AnyFunSuite
import uk.co.gresearch.spark.SparkTestSession
-import uk.co.gresearch.spark.diff.DiffComparatorSuite.{decimalEnc, optionsWithRelaxedComparators, optionsWithTightComparators}
+import uk.co.gresearch.spark.diff.DiffComparatorSuite.{
+ decimalEnc,
+ optionsWithRelaxedComparators,
+ optionsWithTightComparators
+}
import uk.co.gresearch.spark.diff.comparator._
import java.sql.{Date, Timestamp}
import java.time.Duration
-case class Numbers(id: Int, longValue: Long, floatValue: Float, doubleValue: Double, decimalValue: Decimal, someInt: Option[Int], someLong: Option[Long])
+case class Numbers(
+ id: Int,
+ longValue: Long,
+ floatValue: Float,
+ doubleValue: Double,
+ decimalValue: Decimal,
+ someInt: Option[Int],
+ someLong: Option[Long]
+)
case class Strings(id: Int, string: String)
case class Dates(id: Int, date: Date)
case class Times(id: Int, time: Timestamp)
@@ -123,10 +135,12 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
Maps(7, Map(3 -> 4L, 2 -> 2L, 1 -> 1L)),
).toDS()
- def doTest(optionsWithTightComparators: DiffOptions,
- optionsWithRelaxedComparators: DiffOptions,
- left: DataFrame = this.left.toDF(),
- right: DataFrame = this.right.toDF()): Unit = {
+ def doTest(
+ optionsWithTightComparators: DiffOptions,
+ optionsWithRelaxedComparators: DiffOptions,
+ left: DataFrame = this.left.toDF(),
+ right: DataFrame = this.right.toDF()
+ ): Unit = {
// left and right numbers have some differences
val actualWithoutComparators = left.diff(right, "id").orderBy($"id")
@@ -178,22 +192,44 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
.withComparator((left: Decimal, right: Decimal) => left.abs == right.abs),
"default any equiv" -> DiffOptions.default
.withDefaultComparator((_: Any, _: Any) => true),
-
"typed diff comparator" -> DiffOptions.default
.withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs))
- .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), LongType, FloatType, DoubleType, DecimalType(38, 18)),
+ .withComparator(
+ (left: Column, right: Column) => abs(left) <=> abs(right),
+ LongType,
+ FloatType,
+ DoubleType,
+ DecimalType(38, 18)
+ ),
"typed diff comparator for type" -> DiffOptions.default
// only works if data type is equal to input type of typed diff comparator
.withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs), IntegerType)
- .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), LongType, FloatType, DoubleType, DecimalType(38, 18)),
-
+ .withComparator(
+ (left: Column, right: Column) => abs(left) <=> abs(right),
+ LongType,
+ FloatType,
+ DoubleType,
+ DecimalType(38, 18)
+ ),
"diff comparator for type" -> DiffOptions.default
.withComparator((left: Column, right: Column) => abs(left) <=> abs(right), IntegerType)
- .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), LongType, FloatType, DoubleType, DecimalType(38, 18)),
+ .withComparator(
+ (left: Column, right: Column) => abs(left) <=> abs(right),
+ LongType,
+ FloatType,
+ DoubleType,
+ DecimalType(38, 18)
+ ),
"diff comparator for name" -> DiffOptions.default
.withComparator((left: Column, right: Column) => abs(left) <=> abs(right), "someInt")
- .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), "longValue", "floatValue", "doubleValue", "someLong", "decimalValue"),
-
+ .withComparator(
+ (left: Column, right: Column) => abs(left) <=> abs(right),
+ "longValue",
+ "floatValue",
+ "doubleValue",
+ "someLong",
+ "decimalValue"
+ ),
"encoder equiv" -> DiffOptions.default
.withComparator((left: Int, right: Int) => left.abs == right.abs)
.withComparator((left: Long, right: Long) => left.abs == right.abs)
@@ -211,12 +247,14 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
.withComparator((left: Long, right: Long) => left.abs == right.abs, Encoders.scalaLong, "longValue", "someLong")
.withComparator((left: Float, right: Float) => left.abs == right.abs, Encoders.scalaFloat, "floatValue")
.withComparator((left: Double, right: Double) => left.abs == right.abs, Encoders.scalaDouble, "doubleValue")
- .withComparator((left: Decimal, right: Decimal) => left.abs == right.abs, ExpressionEncoder[Decimal](), "decimalValue"),
-
+ .withComparator(
+ (left: Decimal, right: Decimal) => left.abs == right.abs,
+ ExpressionEncoder[Decimal](),
+ "decimalValue"
+ ),
"typed equiv for type" -> DiffOptions.default
.withComparator((left: Int, right: Int) => left.abs == right.abs, IntegerType)
.withComparator(alwaysTrueEquiv, LongType, FloatType, DoubleType, DecimalType(38, 18)),
-
"any equiv for column name" -> DiffOptions.default
.withComparator(alwaysTrueEquiv, "someInt")
.withComparator(alwaysTrueEquiv, "longValue", "floatValue", "doubleValue", "someLong", "decimalValue")
@@ -228,7 +266,8 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
val allValuesEqual = Set("default any equiv", "any equiv for type", "any equiv for column name").contains(label)
val unchangedIds = if (allValuesEqual) Seq(2, 3) else Seq(2)
- val expected = diffWithoutComparators.withColumn("diff", when($"id".isin(unchangedIds: _*), lit("N")).otherwise($"diff"))
+ val expected =
+ diffWithoutComparators.withColumn("diff", when($"id".isin(unchangedIds: _*), lit("N")).otherwise($"diff"))
assert(expected.where($"diff" === "C").count() === 3 - unchangedIds.size)
val actual = left.diff(rightSign, options, "id").orderBy($"id").collect()
@@ -240,44 +279,45 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
test("null-aware comparator") {
val options = DiffOptions.default.withComparator(
// only if this method is called with nulls, the expected result can occur
- (x: Column, y: Column) => x.isNull || y.isNull || x === y, StringType)
+ (x: Column, y: Column) => x.isNull || y.isNull || x === y,
+ StringType
+ )
val diff = leftStrings.diff(rightStrings, options, "id").orderBy($"id").collect()
- assert(diff === Seq(
- Row("N", 1, "1", "1"),
- Row("N", 2, null, "2"),
- Row("N", 3, "3", null),
- Row("N", 4, null, null),
- ))
+ assert(
+ diff === Seq(
+ Row("N", 1, "1", "1"),
+ Row("N", 2, null, "2"),
+ Row("N", 3, "3", null),
+ Row("N", 4, null, null),
+ )
+ )
}
Seq(
"diff comparator" -> (DiffOptions.default
.withDefaultComparator((_: Column, _: Column) => lit(1)),
- Seq(
- "'(1 AND 1)' requires boolean type, not int", // until Spark 3.3
- "\"(1 AND 1)\" due to data type mismatch: " + // Spark 3.4 and beyond
- "the binary operator requires the input type \"BOOLEAN\", not \"INT\"."
- )
- ),
+ Seq(
+ "'(1 AND 1)' requires boolean type, not int", // until Spark 3.3
+ "\"(1 AND 1)\" due to data type mismatch: " + // Spark 3.4 and beyond
+ "the binary operator requires the input type \"BOOLEAN\", not \"INT\"."
+ )),
"encoder equiv" -> (DiffOptions.default
.withDefaultComparator((_: Int, _: Int) => true),
- Seq(
- "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1
- "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3
- "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond
- "the binary operator requires the input type \"INT\", not \"BIGINT\"."
- )
- ),
+ Seq(
+ "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1
+ "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3
+ "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond
+ "the binary operator requires the input type \"INT\", not \"BIGINT\"."
+ )),
"typed equiv" -> (DiffOptions.default
.withDefaultComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs, IntegerType)),
- Seq(
- "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1
- "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3
- "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond
- "the binary operator requires the input type \"INT\", not \"BIGINT\"."
- )
- )
+ Seq(
+ "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1
+ "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3
+ "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond
+ "the binary operator requires the input type \"INT\", not \"BIGINT\"."
+ ))
).foreach { case (label, (options, expecteds)) =>
test(s"with comparator of incompatible type - $label") {
val exception = intercept[AnalysisException] {
@@ -289,34 +329,45 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
}
test("absolute epsilon comparator (inclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.5).asAbsolute().asInclusive())
- val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asInclusive())
+ val optionsWithTightComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.5).asAbsolute().asInclusive())
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asInclusive())
doTest(optionsWithTightComparator, optionsWithRelaxedComparator)
}
test("absolute epsilon comparator (exclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asExclusive())
- val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.001).asAbsolute().asExclusive())
+ val optionsWithTightComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asExclusive())
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.001).asAbsolute().asExclusive())
doTest(optionsWithTightComparator, optionsWithRelaxedComparator)
}
test("relative epsilon comparator (inclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.1).asRelative().asInclusive())
- val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1/3.0).asRelative().asInclusive())
+ val optionsWithTightComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.1).asRelative().asInclusive())
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0).asRelative().asInclusive())
doTest(optionsWithTightComparator, optionsWithRelaxedComparator)
}
test("relative epsilon comparator (exclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1/3.0).asRelative().asExclusive())
- val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1/3.0 + .001).asRelative().asExclusive())
+ val optionsWithTightComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0).asRelative().asExclusive())
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0 + .001).asRelative().asExclusive())
doTest(optionsWithTightComparator, optionsWithRelaxedComparator)
}
test("whitespace agnostic string comparator") {
val left = Seq(Strings(1, "one"), Strings(2, "two spaces "), Strings(3, "three"), Strings(4, "four")).toDF()
- val right = Seq(Strings(1, "one"), Strings(2, " two \t\nspaces"), Strings(3, "three\nspaces"), Strings(5, "five")).toDF()
- val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = false))
- val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = true))
+ val right =
+ Seq(Strings(1, "one"), Strings(2, " two \t\nspaces"), Strings(3, "three\nspaces"), Strings(5, "five")).toDF()
+ val optionsWithTightComparator =
+ DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = false))
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = true))
doTest(optionsWithTightComparator, optionsWithRelaxedComparator, left, right)
}
@@ -331,26 +382,34 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
}
} else {
test("duration comparator with date (inclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(23)).asInclusive(), "date")
- val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asInclusive(), "date")
+ val optionsWithTightComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(23)).asInclusive(), "date")
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asInclusive(), "date")
doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftDates.toDF, rightDates.toDF)
}
test("duration comparator with date (exclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asExclusive(), "date")
- val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(25)).asExclusive(), "date")
+ val optionsWithTightComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asExclusive(), "date")
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(25)).asExclusive(), "date")
doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftDates.toDF, rightDates.toDF)
}
test("duration comparator with time (inclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(59)).asInclusive(), "time")
- val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asInclusive(), "time")
+ val optionsWithTightComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(59)).asInclusive(), "time")
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asInclusive(), "time")
doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftTimes.toDF, rightTimes.toDF)
}
test("duration comparator with time (exclusive)") {
- val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asExclusive(), "time")
- val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(61)).asExclusive(), "time")
+ val optionsWithTightComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asExclusive(), "time")
+ val optionsWithRelaxedComparator =
+ DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(61)).asExclusive(), "time")
doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftTimes.toDF, rightTimes.toDF)
}
}
@@ -365,12 +424,15 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
val options = DiffOptions.default.withComparator(DiffComparators.map[Int, Long](sensitive), "map")
val actual = leftMaps.diff(rightMaps, options, "id").orderBy($"id").collect()
- val diffs = Seq((1, "N"), (2, "C"), (3, "C"), (4, "D"), (5, "I"), (6, if (sensitive) "C" else "N"), (7, "C")).toDF("id", "diff")
- val expected = leftMaps.withColumnRenamed("map", "left_map")
+ val diffs = Seq((1, "N"), (2, "C"), (3, "C"), (4, "D"), (5, "I"), (6, if (sensitive) "C" else "N"), (7, "C"))
+ .toDF("id", "diff")
+ val expected = leftMaps
+ .withColumnRenamed("map", "left_map")
.join(rightMaps.withColumnRenamed("map", "right_map"), Seq("id"), "fullouter")
.join(diffs, "id")
.select($"diff", $"id", $"left_map", $"right_map")
- .orderBy($"id").collect()
+ .orderBy($"id")
+ .collect()
assert(actual === expected)
}
}
@@ -386,16 +448,27 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession {
}
val diffComparatorMethodTests: Seq[(String, (() => DiffComparator, DiffComparator))] =
- if(DurationDiffComparator.isSupportedBySpark) {
- Seq("duration" -> (() => DiffComparators.duration(Duration.ofSeconds(1)).asExclusive(), DurationDiffComparator(Duration.ofSeconds(1), inclusive = false)))
- } else { Seq.empty } ++ Seq(
- "default" -> (() => DiffComparators.default(), DefaultDiffComparator),
- "nullSafeEqual" -> (() => DiffComparators.nullSafeEqual(), NullSafeEqualDiffComparator),
- "equiv with encoder" -> (() => DiffComparators.equiv(IntEquiv), EquivDiffComparator(IntEquiv)),
- "equiv with type" -> (() => DiffComparators.equiv(IntEquiv, IntegerType), EquivDiffComparator(IntEquiv, IntegerType)),
- "equiv with any" -> (() => DiffComparators.equiv(AnyEquiv), EquivDiffComparator(AnyEquiv)),
- "epsilon" -> (() => DiffComparators.epsilon(1.0).asAbsolute().asExclusive(), EpsilonDiffComparator(1.0, relative = false, inclusive = false))
- )
+ if (DurationDiffComparator.isSupportedBySpark) {
+ Seq(
+ "duration" -> (() => DiffComparators.duration(Duration.ofSeconds(1)).asExclusive(), DurationDiffComparator(
+ Duration.ofSeconds(1),
+ inclusive = false
+ ))
+ )
+ } else
+ { Seq.empty } ++ Seq(
+ "default" -> (() => DiffComparators.default(), DefaultDiffComparator),
+ "nullSafeEqual" -> (() => DiffComparators.nullSafeEqual(), NullSafeEqualDiffComparator),
+ "equiv with encoder" -> (() => DiffComparators.equiv(IntEquiv), EquivDiffComparator(IntEquiv)),
+ "equiv with type" -> (() =>
+ DiffComparators.equiv(IntEquiv, IntegerType), EquivDiffComparator(IntEquiv, IntegerType)),
+ "equiv with any" -> (() => DiffComparators.equiv(AnyEquiv), EquivDiffComparator(AnyEquiv)),
+ "epsilon" -> (() => DiffComparators.epsilon(1.0).asAbsolute().asExclusive(), EpsilonDiffComparator(
+ 1.0,
+ relative = false,
+ inclusive = false
+ ))
+ )
diffComparatorMethodTests.foreach { case (label, (method, expected)) =>
test(s"DiffComparator.$label") {
@@ -414,9 +487,12 @@ object DiffComparatorSuite {
val tightIntComparator: EquivDiffComparator[Int] = EquivDiffComparator((x: Int, y: Int) => math.abs(x - y) < 1)
val tightLongComparator: EquivDiffComparator[Long] = EquivDiffComparator((x: Long, y: Long) => math.abs(x - y) < 1)
- val tightFloatComparator: EquivDiffComparator[Float] = EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) < 0.001)
- val tightDoubleComparator: EquivDiffComparator[Double] = EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) < 0.001)
- val tightDecimalComparator: EquivDiffComparator[Decimal] = EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs < Decimal(0.001))
+ val tightFloatComparator: EquivDiffComparator[Float] =
+ EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) < 0.001)
+ val tightDoubleComparator: EquivDiffComparator[Double] =
+ EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) < 0.001)
+ val tightDecimalComparator: EquivDiffComparator[Decimal] =
+ EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs < Decimal(0.001))
val optionsWithTightComparators: DiffOptions = DiffOptions.default
.withComparator(tightIntComparator, IntegerType)
@@ -427,9 +503,12 @@ object DiffComparatorSuite {
val relaxedIntComparator: EquivDiffComparator[Int] = EquivDiffComparator((x: Int, y: Int) => math.abs(x - y) <= 1)
val relaxedLongComparator: EquivDiffComparator[Long] = EquivDiffComparator((x: Long, y: Long) => math.abs(x - y) <= 1)
- val relaxedFloatComparator: EquivDiffComparator[Float] = EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) <= 0.001)
- val relaxedDoubleComparator: EquivDiffComparator[Double] = EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) <= 0.001)
- val relaxedDecimalComparator: EquivDiffComparator[Decimal] = EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs <= Decimal(0.001))
+ val relaxedFloatComparator: EquivDiffComparator[Float] =
+ EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) <= 0.001)
+ val relaxedDoubleComparator: EquivDiffComparator[Double] =
+ EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) <= 0.001)
+ val relaxedDecimalComparator: EquivDiffComparator[Decimal] =
+ EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs <= Decimal(0.001))
val optionsWithRelaxedComparators: DiffOptions = DiffOptions.default
.withComparator(relaxedIntComparator, IntegerType)
diff --git a/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala
index 64629f93..191375ab 100644
--- a/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala
@@ -38,17 +38,19 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession {
test("diff options left and right prefixes") {
// test the copy method (constructor), not the fluent methods
val default = DiffOptions.default
- doTestRequirement(default.copy(leftColumnPrefix = ""),
- "Left column prefix must not be empty")
- doTestRequirement(default.copy(rightColumnPrefix = ""),
- "Right column prefix must not be empty")
+ doTestRequirement(default.copy(leftColumnPrefix = ""), "Left column prefix must not be empty")
+ doTestRequirement(default.copy(rightColumnPrefix = ""), "Right column prefix must not be empty")
val prefix = "prefix"
- doTestRequirement(default.copy(leftColumnPrefix = prefix, rightColumnPrefix = prefix),
- s"Left and right column prefix must be distinct: $prefix")
+ doTestRequirement(
+ default.copy(leftColumnPrefix = prefix, rightColumnPrefix = prefix),
+ s"Left and right column prefix must be distinct: $prefix"
+ )
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
- doTestRequirement(default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase),
- s"Left and right column prefix must be distinct: $prefix")
+ doTestRequirement(
+ default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase),
+ s"Left and right column prefix must be distinct: $prefix"
+ )
}
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase)
@@ -69,18 +71,30 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession {
assert(emptyNochangeDiffValueOpts.nochangeDiffValue.isEmpty)
Seq("value", "").foreach { value =>
- doTestRequirement(default.copy(insertDiffValue = value, changeDiffValue = value),
- s"Diff values must be distinct: List($value, $value, D, N)")
- doTestRequirement(default.copy(insertDiffValue = value, deleteDiffValue = value),
- s"Diff values must be distinct: List($value, C, $value, N)")
- doTestRequirement(default.copy(insertDiffValue = value, nochangeDiffValue = value),
- s"Diff values must be distinct: List($value, C, D, $value)")
- doTestRequirement(default.copy(changeDiffValue = value, deleteDiffValue = value),
- s"Diff values must be distinct: List(I, $value, $value, N)")
- doTestRequirement(default.copy(changeDiffValue = value, nochangeDiffValue = value),
- s"Diff values must be distinct: List(I, $value, D, $value)")
- doTestRequirement(default.copy(deleteDiffValue = value, nochangeDiffValue = value),
- s"Diff values must be distinct: List(I, C, $value, $value)")
+ doTestRequirement(
+ default.copy(insertDiffValue = value, changeDiffValue = value),
+ s"Diff values must be distinct: List($value, $value, D, N)"
+ )
+ doTestRequirement(
+ default.copy(insertDiffValue = value, deleteDiffValue = value),
+ s"Diff values must be distinct: List($value, C, $value, N)"
+ )
+ doTestRequirement(
+ default.copy(insertDiffValue = value, nochangeDiffValue = value),
+ s"Diff values must be distinct: List($value, C, D, $value)"
+ )
+ doTestRequirement(
+ default.copy(changeDiffValue = value, deleteDiffValue = value),
+ s"Diff values must be distinct: List(I, $value, $value, N)"
+ )
+ doTestRequirement(
+ default.copy(changeDiffValue = value, nochangeDiffValue = value),
+ s"Diff values must be distinct: List(I, $value, D, $value)"
+ )
+ doTestRequirement(
+ default.copy(deleteDiffValue = value, nochangeDiffValue = value),
+ s"Diff values must be distinct: List(I, C, $value, $value)"
+ )
}
}
@@ -177,11 +191,16 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession {
DiffOptions.default
.withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs), LongType, FloatType)
}
- assert(exceptionMulti.getMessage.contains("Comparator with input type int cannot be used for data type bigint, float"))
+ assert(
+ exceptionMulti.getMessage.contains("Comparator with input type int cannot be used for data type bigint, float")
+ )
}
test("fluent methods of diff options") {
- assert(DiffMode.Default != DiffMode.LeftSide, "test assumption on default diff mode must hold, otherwise test is trivial")
+ assert(
+ DiffMode.Default != DiffMode.LeftSide,
+ "test assumption on default diff mode must hold, otherwise test is trivial"
+ )
val cmp1 = new DiffComparator {
override def equiv(left: Column, right: Column): Column = lit(true)
@@ -211,7 +230,21 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession {
val dexpectedDefCmp = cmp1
val expectedDtCmps = Map(IntegerType.asInstanceOf[DataType] -> cmp2)
val expectedColCmps = Map("col1" -> cmp3)
- val expected = DiffOptions("d", "l", "r", "i", "c", "d", "n", Some("change"), DiffMode.LeftSide, sparseMode = true, dexpectedDefCmp, expectedDtCmps, expectedColCmps)
+ val expected = DiffOptions(
+ "d",
+ "l",
+ "r",
+ "i",
+ "c",
+ "d",
+ "n",
+ Some("change"),
+ DiffMode.LeftSide,
+ sparseMode = true,
+ dexpectedDefCmp,
+ expectedDtCmps,
+ expectedColCmps
+ )
assert(options === expected)
}
diff --git a/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala
index c47551c0..1fd8f2c4 100644
--- a/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala
@@ -38,60 +38,50 @@ case class Value9up(ID: Int, SEQ: Option[Int], VALUE: Option[String], INFO: Opti
case class ValueLeft(left_id: Int, value: Option[String])
case class ValueRight(right_id: Int, value: Option[String])
-case class DiffAs(diff: String,
- id: Int,
- left_value: Option[String],
- right_value: Option[String])
-case class DiffAs8(diff: String,
- id: Int,
- seq: Option[Int],
- left_value: Option[String],
- right_value: Option[String],
- left_meta: Option[String],
- right_meta: Option[String])
-case class DiffAs8SideBySide(diff: String,
- id: Int,
- seq: Option[Int],
- left_value: Option[String],
- left_meta: Option[String],
- right_value: Option[String],
- right_meta: Option[String])
-case class DiffAs8OneSide(diff: String,
- id: Int,
- seq: Option[Int],
- value: Option[String],
- meta: Option[String])
-case class DiffAs8changes(diff: String,
- changed: Array[String],
- id: Int,
- seq: Option[Int],
- left_value: Option[String],
- right_value: Option[String],
- left_meta: Option[String],
- right_meta: Option[String])
-case class DiffAs8and9(diff: String,
- id: Int,
- seq: Option[Int],
- left_value: Option[String],
- right_value: Option[String],
- left_meta: Option[String],
- right_info: Option[String])
-
-case class DiffAsCustom(action: String,
- id: Int,
- before_value: Option[String],
- after_value: Option[String])
-case class DiffAsSubset(diff: String,
- id: Int,
- left_value: Option[String])
-case class DiffAsExtra(diff: String,
- id: Int,
- left_value: Option[String],
- right_value: Option[String],
- extra: String)
-case class DiffAsOneSide(diff: String,
- id: Int,
- value: Option[String])
+case class DiffAs(diff: String, id: Int, left_value: Option[String], right_value: Option[String])
+case class DiffAs8(
+ diff: String,
+ id: Int,
+ seq: Option[Int],
+ left_value: Option[String],
+ right_value: Option[String],
+ left_meta: Option[String],
+ right_meta: Option[String]
+)
+case class DiffAs8SideBySide(
+ diff: String,
+ id: Int,
+ seq: Option[Int],
+ left_value: Option[String],
+ left_meta: Option[String],
+ right_value: Option[String],
+ right_meta: Option[String]
+)
+case class DiffAs8OneSide(diff: String, id: Int, seq: Option[Int], value: Option[String], meta: Option[String])
+case class DiffAs8changes(
+ diff: String,
+ changed: Array[String],
+ id: Int,
+ seq: Option[Int],
+ left_value: Option[String],
+ right_value: Option[String],
+ left_meta: Option[String],
+ right_meta: Option[String]
+)
+case class DiffAs8and9(
+ diff: String,
+ id: Int,
+ seq: Option[Int],
+ left_value: Option[String],
+ right_value: Option[String],
+ left_meta: Option[String],
+ right_info: Option[String]
+)
+
+case class DiffAsCustom(action: String, id: Int, before_value: Option[String], after_value: Option[String])
+case class DiffAsSubset(diff: String, id: Int, left_value: Option[String])
+case class DiffAsExtra(diff: String, id: Int, left_value: Option[String], right_value: Option[String], extra: String)
+case class DiffAsOneSide(diff: String, id: Int, value: Option[String])
object DiffSuite {
def left(spark: SparkSession): Dataset[Value] = {
@@ -170,7 +160,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Value8(3, None, None, None)
).toDS()
- lazy val right9: Dataset[Value9] = right8.withColumn("info", regexp_replace($"meta", "user", "info")).drop("meta").as[Value9]
+ lazy val right9: Dataset[Value9] =
+ right8.withColumn("info", regexp_replace($"meta", "user", "info")).drop("meta").as[Value9]
lazy val expectedDiffColumns: Seq[String] = Seq("diff", "id", "left_value", "right_value")
@@ -183,9 +174,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Row("D", 4, "four", null)
)
- lazy val expectedDiffAs: Seq[DiffAs] = expectedDiff.map(r =>
- DiffAs(r.getString(0), r.getInt(1), Option(r.getString(2)), Option(r.getString(3)))
- )
+ lazy val expectedDiffAs: Seq[DiffAs] =
+ expectedDiff.map(r => DiffAs(r.getString(0), r.getInt(1), Option(r.getString(2)), Option(r.getString(3))))
lazy val expectedDiff7: Seq[Row] = Seq(
Row("C", 1, "one", "One", "one label", "one label"),
@@ -200,9 +190,13 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Row("I", 10, null, null, null, null)
)
- lazy val expectedSideBySideDiff7: Seq[Row] = expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5)))
- lazy val expectedLeftSideDiff7: Seq[Row] = expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4)))
- lazy val expectedRightSideDiff7: Seq[Row] = expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5)))
+ lazy val expectedSideBySideDiff7: Seq[Row] = expectedDiff7.map(row =>
+ Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5))
+ )
+ lazy val expectedLeftSideDiff7: Seq[Row] =
+ expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4)))
+ lazy val expectedRightSideDiff7: Seq[Row] =
+ expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5)))
lazy val expectedSparseDiff7: Seq[Row] = Seq(
Row("C", 1, "one", "One", null, null),
@@ -217,9 +211,13 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Row("I", 10, null, null, null, null)
)
- lazy val expectedSideBySideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5)))
- lazy val expectedLeftSideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4)))
- lazy val expectedRightSideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5)))
+ lazy val expectedSideBySideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row =>
+ Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5))
+ )
+ lazy val expectedLeftSideSparseDiff7: Seq[Row] =
+ expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4)))
+ lazy val expectedRightSideSparseDiff7: Seq[Row] =
+ expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5)))
lazy val expectedDiff7WithChanges: Seq[Row] = Seq(
Row("C", Seq("value"), 1, "one", "One", "one label", "one label"),
@@ -256,9 +254,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Row("N", 3, null, null, null, null, null)
)
- lazy val expectedSideBySideDiff8: Seq[Row] = expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6)))
- lazy val expectedLeftSideDiff8: Seq[Row] = expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5)))
- lazy val expectedRightSideDiff8: Seq[Row] = expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6)))
+ lazy val expectedSideBySideDiff8: Seq[Row] =
+ expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6)))
+ lazy val expectedLeftSideDiff8: Seq[Row] =
+ expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5)))
+ lazy val expectedRightSideDiff8: Seq[Row] =
+ expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6)))
lazy val expectedSparseDiff8: Seq[Row] = Seq(
Row("N", 1, 1, null, null, "user1", "user2"),
@@ -271,50 +272,66 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Row("N", 3, null, null, null, null, null)
)
- lazy val expectedSideBySideSparseDiff8: Seq[Row] = expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6)))
- lazy val expectedLeftSideSparseDiff8: Seq[Row] = expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5)))
- lazy val expectedRightSideSparseDiff8: Seq[Row] = expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6)))
+ lazy val expectedSideBySideSparseDiff8: Seq[Row] =
+ expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6)))
+ lazy val expectedLeftSideSparseDiff8: Seq[Row] =
+ expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5)))
+ lazy val expectedRightSideSparseDiff8: Seq[Row] =
+ expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6)))
lazy val expectedDiffAs8: Seq[DiffAs8] = expectedDiff8.map(r =>
- DiffAs8(r.getString(0),
- r.getInt(1), Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)),
- Option(r.getString(3)), Option(r.getString(4)),
- Option(r.getString(5)), Option(r.getString(6))
+ DiffAs8(
+ r.getString(0),
+ r.getInt(1),
+ Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)),
+ Option(r.getString(3)),
+ Option(r.getString(4)),
+ Option(r.getString(5)),
+ Option(r.getString(6))
)
)
lazy val expectedDiff8WithChanges: Seq[Row] = expectedDiff8.map(r =>
- Row(r.get(0),
+ Row(
+ r.get(0),
r.get(0) match {
case "N" => Seq.empty
case "I" => null
case "C" => Seq("value")
case "D" => null
},
- r.get(1), r.get(2),
- r.getString(3), r.getString(4),
- r.getString(5), r.getString(6)
+ r.get(1),
+ r.get(2),
+ r.getString(3),
+ r.getString(4),
+ r.getString(5),
+ r.getString(6)
)
)
lazy val expectedDiffAs8and9: Seq[DiffAs8and9] = expectedDiff8and9.map(r =>
- DiffAs8and9(r.getString(0),
- r.getInt(1), Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)),
- Option(r.getString(3)), Option(r.getString(4)),
- Option(r.getString(5)), Option(r.getString(6))
+ DiffAs8and9(
+ r.getString(0),
+ r.getInt(1),
+ Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)),
+ Option(r.getString(3)),
+ Option(r.getString(4)),
+ Option(r.getString(5)),
+ Option(r.getString(6))
)
)
- lazy val expectedDiffWith8and9: Seq[(String, Value8, Value9)] = expectedDiffAs8and9.map(v => (
- v.diff,
- if (v.diff == "I") null else Value8(v.id, v.seq, v.left_value, v.left_meta),
- if (v.diff == "D") null else Value9(v.id, v.seq, v.right_value, v.right_info)
- ))
-
- lazy val expectedDiffWith8and9up: Seq[(String, Value8, Value9up)] = expectedDiffWith8and9.map(t =>
- t.copy(_3 = Option(t._3).map(v => Value9up(v.id, v.seq, v.value, v.info)).orNull)
+ lazy val expectedDiffWith8and9: Seq[(String, Value8, Value9)] = expectedDiffAs8and9.map(v =>
+ (
+ v.diff,
+ if (v.diff == "I") null else Value8(v.id, v.seq, v.left_value, v.left_meta),
+ if (v.diff == "D") null else Value9(v.id, v.seq, v.right_value, v.right_info)
+ )
)
+ lazy val expectedDiffWith8and9up: Seq[(String, Value8, Value9up)] =
+ expectedDiffWith8and9.map(t => t.copy(_3 = Option(t._3).map(v => Value9up(v.id, v.seq, v.value, v.info)).orNull))
+
test("distinct prefix for") {
assert(distinctPrefixFor(Seq.empty[String]) === "_")
assert(distinctPrefixFor(Seq("a")) === "_")
@@ -328,9 +345,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
test("diff dataframe with duplicate columns") {
val df = Seq(1).toDF("id").select($"id", $"id")
- doTestRequirement(df.diff(df, "id"),
+ doTestRequirement(
+ df.diff(df, "id"),
"The datasets have duplicate columns.\n" +
- "Left column names: id, id\nRight column names: id, id")
+ "Left column names: id, id\nRight column names: id, id"
+ )
}
test("diff with no id column") {
@@ -383,8 +402,7 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
test("diff with one id column case-sensitive") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
- doTestRequirement(left.diff(right, "ID"),
- "Some id columns do not exist: ID missing among id, value")
+ doTestRequirement(left.diff(right, "ID"), "Some id columns do not exist: ID missing among id, value")
val actual = left.diff(right, "id").orderBy("id")
val reverse = right.diff(left, "id").orderBy("id")
@@ -526,10 +544,14 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
)
val expectedColumns = Seq(
"diff",
- "id", "seq",
- "left_value1", "right_value1",
- "left_value2", "right_value2",
- "left_value3", "right_value3"
+ "id",
+ "seq",
+ "left_value1",
+ "right_value1",
+ "left_value2",
+ "right_value2",
+ "left_value3",
+ "right_value3"
)
val actual = left.diff(right, "id", "seq").orderBy("id", "seq")
@@ -548,10 +570,14 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
)
val expectedColumns = Seq(
"diff",
- "id", "seq",
- "left_value2", "right_value2",
- "left_value3", "right_value3",
- "left_value1", "right_value1"
+ "id",
+ "seq",
+ "left_value2",
+ "right_value2",
+ "left_value3",
+ "right_value3",
+ "left_value1",
+ "right_value1"
)
val actual = right.diff(left, "id", "seq").orderBy("id", "seq")
@@ -570,7 +596,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Row("I", "val2.2.1", 2, "val2.2.2", 2, "val2.2.3")
)
val expectedColumns = Seq(
- "diff", "value1", "id", "value2", "seq", "value3"
+ "diff",
+ "value1",
+ "id",
+ "value2",
+ "seq",
+ "value3"
)
val actual = left.diff(right).orderBy("id", "seq", "diff")
@@ -589,7 +620,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
Row("I", "val2.2.1", 2, "val2.2.2", 2, "val2.2.3")
)
val expectedColumns = Seq(
- "diff", "value1", "id", "value2", "seq", "value3"
+ "diff",
+ "value1",
+ "id",
+ "value2",
+ "seq",
+ "value3"
)
val actual = left.diff(right).orderBy("id", "seq", "diff")
@@ -617,9 +653,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val expectedColumns = Seq(
"diff",
"id",
- "left_left_value", "right_left_value",
- "left_right_value", "right_right_value",
- "left_value", "right_value"
+ "left_left_value",
+ "right_left_value",
+ "left_right_value",
+ "right_right_value",
+ "left_value",
+ "right_value"
)
val expectedDiff = Seq(
Row("C", 1, "left", "Left", "right", "Right", "value", "Value")
@@ -633,21 +672,25 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = Seq(Value4(1, "diff")).toDS()
val right = Seq(Value4(1, "Diff")).toDS()
- doTestRequirement(left.diff(right),
- "The id columns must not contain the diff column name 'diff': id, diff")
- doTestRequirement(left.diff(right, "diff"),
- "The id columns must not contain the diff column name 'diff': diff")
- doTestRequirement(left.diff(right, "diff", "id"),
- "The id columns must not contain the diff column name 'diff': diff, id")
+ doTestRequirement(left.diff(right), "The id columns must not contain the diff column name 'diff': id, diff")
+ doTestRequirement(left.diff(right, "diff"), "The id columns must not contain the diff column name 'diff': diff")
+ doTestRequirement(
+ left.diff(right, "diff", "id"),
+ "The id columns must not contain the diff column name 'diff': diff, id"
+ )
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
- doTestRequirement(left.withColumnRenamed("diff", "Diff")
- .diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id"),
- "The id columns must not contain the diff column name 'diff': Diff, id")
+ doTestRequirement(
+ left
+ .withColumnRenamed("diff", "Diff")
+ .diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id"),
+ "The id columns must not contain the diff column name 'diff': Diff, id"
+ )
}
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
- left.withColumnRenamed("diff", "Diff")
+ left
+ .withColumnRenamed("diff", "Diff")
.diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id")
}
}
@@ -660,7 +703,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val expectedColumns = Seq(
"diff",
"id",
- "left_diff", "right_diff"
+ "left_diff",
+ "right_diff"
)
val expectedDiff = Seq(
Row("C", 1, "diff", "Diff")
@@ -676,9 +720,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
.withLeftColumnPrefix("a")
.withRightColumnPrefix("b")
- doTestRequirement(left.diff(right, options, "id"),
+ doTestRequirement(
+ left.diff(right, options, "id"),
"The column prefixes 'a' and 'b', together with these non-id columns " +
- "must not produce the diff column name 'a_value': value")
+ "must not produce the diff column name 'a_value': value"
+ )
}
test("diff with left-side mode where non-id column would produce diff column name") {
@@ -708,9 +754,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
.withLeftColumnPrefix("A")
.withRightColumnPrefix("B")
- doTestRequirement(left.diff(right, options, "id"),
+ doTestRequirement(
+ left.diff(right, options, "id"),
"The column prefixes 'A' and 'B', together with these non-id columns " +
- "must not produce the diff column name 'a_value': value")
+ "must not produce the diff column name 'a_value': value"
+ )
}
}
@@ -747,7 +795,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val actual = left.diff(right, options, "id").orderBy("id")
val expectedColumns = Seq(
- "a_value", "id", "A_value", "B_value"
+ "a_value",
+ "id",
+ "A_value",
+ "B_value"
)
assert(actual.columns === expectedColumns)
@@ -761,9 +812,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
.withLeftColumnPrefix("a")
.withRightColumnPrefix("b")
- doTestRequirement(left.diff(right, options, "id"),
+ doTestRequirement(
+ left.diff(right, options, "id"),
"The column prefixes 'a' and 'b', together with these non-id columns " +
- "must not produce the change column name 'a_value': value")
+ "must not produce the change column name 'a_value': value"
+ )
}
test("diff where case-insensitive non-id column produces change column name") {
@@ -773,9 +826,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
.withLeftColumnPrefix("A")
.withRightColumnPrefix("B")
- doTestRequirement(left.diff(right, options, "id"),
+ doTestRequirement(
+ left.diff(right, options, "id"),
"The column prefixes 'A' and 'B', together with these non-id columns " +
- "must not produce the change column name 'a_value': value")
+ "must not produce the change column name 'a_value': value"
+ )
}
}
@@ -788,7 +843,13 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val actual = left7.diff(right7, options, "id").orderBy("id")
val expectedColumns = Seq(
- "diff", "a_value", "id", "A_value", "B_value", "A_label", "B_label"
+ "diff",
+ "a_value",
+ "id",
+ "A_value",
+ "B_value",
+ "A_label",
+ "B_label"
)
assert(actual.columns === expectedColumns)
@@ -804,9 +865,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = Seq(Value5(1, "value")).toDS()
val right = Seq(Value5(1, "Value")).toDS()
- doTestRequirement(left.diff(right, options, "first_id"),
+ doTestRequirement(
+ left.diff(right, options, "first_id"),
"The column prefixes 'first' and 'second', together with these non-id columns " +
- "must not produce any id column name 'first_id': id")
+ "must not produce any id column name 'first_id': id"
+ )
}
test("diff where case-insensitive non-id column produces id column name") {
@@ -818,9 +881,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = Seq(Value5(1, "value")).toDS()
val right = Seq(Value5(1, "Value")).toDS()
- doTestRequirement(left.diff(right, options, "first_id"),
+ doTestRequirement(
+ left.diff(right, options, "first_id"),
"The column prefixes 'FIRST' and 'SECOND', together with these non-id columns " +
- "must not produce any id column name 'first_id': id")
+ "must not produce any id column name 'first_id': id"
+ )
}
}
@@ -835,7 +900,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val actual = left.diff(right, options, "first_id")
val expectedColumns = Seq(
- "diff", "first_id", "FIRST_id", "SECOND_id"
+ "diff",
+ "first_id",
+ "FIRST_id",
+ "SECOND_id"
)
assert(actual.columns === expectedColumns)
@@ -870,8 +938,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = Seq((1, "info")).toDF("id", "info")
val right = Seq((1, "meta")).toDF("id", "meta")
- doTestRequirement(left.diff(right, Seq.empty, Seq("id", "info", "meta")),
- "The schema except ignored columns must not be empty")
+ doTestRequirement(
+ left.diff(right, Seq.empty, Seq("id", "info", "meta")),
+ "The schema except ignored columns must not be empty"
+ )
}
test("diff with different types") {
@@ -879,10 +949,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = Seq((1, "str")).toDF("id", "value")
val right = Seq((1, 2)).toDF("id", "value")
- doTestRequirement(left.diff(right),
+ doTestRequirement(
+ left.diff(right),
"The datasets do not have the same schema.\n" +
"Left extra columns: value (StringType)\n" +
- "Right extra columns: value (IntegerType)")
+ "Right extra columns: value (IntegerType)"
+ )
}
test("diff with ignored columns of different types") {
@@ -891,12 +963,16 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val right = Seq((1, 2)).toDF("id", "value")
val actual = left.diff(right, Seq.empty, Seq("value"))
- assert(ignoreNullable(actual.schema) === StructType(Seq(
- StructField("diff", StringType),
- StructField("id", IntegerType),
- StructField("left_value", StringType),
- StructField("right_value", IntegerType),
- )))
+ assert(
+ ignoreNullable(actual.schema) === StructType(
+ Seq(
+ StructField("diff", StringType),
+ StructField("id", IntegerType),
+ StructField("left_value", StringType),
+ StructField("right_value", IntegerType),
+ )
+ )
+ )
assert(actual.collect() === Seq(Row("N", 1, "str", 2)))
}
@@ -922,10 +998,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = Seq((1, "str")).toDF("id", "value")
val right = Seq((1, "str")).toDF("id", "comment")
- doTestRequirement(left.diff(right, "id"),
+ doTestRequirement(
+ left.diff(right, "id"),
"The datasets do not have the same schema.\n" +
"Left extra columns: value (StringType)\n" +
- "Right extra columns: comment (StringType)")
+ "Right extra columns: comment (StringType)"
+ )
}
test("diff with case-insensitive column names") {
@@ -950,16 +1028,20 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val right = this.right.toDF("ID", "VaLuE")
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
- doTestRequirement(left.diff(right, "id"),
+ doTestRequirement(
+ left.diff(right, "id"),
"The datasets do not have the same schema.\n" +
"Left extra columns: id (IntegerType), value (StringType)\n" +
- "Right extra columns: ID (IntegerType), VaLuE (StringType)")
+ "Right extra columns: ID (IntegerType), VaLuE (StringType)"
+ )
}
}
test("diff of non-existing id column") {
- doTestRequirement(left.diff(right, "does not exists"),
- "Some id columns do not exist: does not exists missing among id, value")
+ doTestRequirement(
+ left.diff(right, "does not exists"),
+ "Some id columns do not exist: does not exists missing among id, value"
+ )
}
test("diff with different number of columns") {
@@ -967,20 +1049,24 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = Seq((1, "str")).toDF("id", "value")
val right = Seq((1, 1, "str")).toDF("id", "seq", "value")
- doTestRequirement(left.diff(right, "id"),
+ doTestRequirement(
+ left.diff(right, "id"),
"The number of columns doesn't match.\n" +
"Left column names (2): id, value\n" +
- "Right column names (3): id, seq, value")
+ "Right column names (3): id, seq, value"
+ )
}
test("diff similar with ignored column and different number of columns") {
val left = Seq((1, "str", "meta")).toDF("id", "value", "meta")
val right = Seq((1, 1, "str")).toDF("id", "seq", "value")
- doTestRequirement(left.diff(right, Seq("id"), Seq("meta")),
+ doTestRequirement(
+ left.diff(right, Seq("id"), Seq("meta")),
"The number of columns doesn't match.\n" +
"Left column names except ignored columns (2): id, value\n" +
- "Right column names except ignored columns (3): id, seq, value")
+ "Right column names except ignored columns (3): id, seq, value"
+ )
}
test("diff as U") {
@@ -1010,7 +1096,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
(DiffOptions.default.nochangeDiffValue, "eq")
).toDF("diff", "action")
- val expected = expectedDiffAs.toDS()
+ val expected = expectedDiffAs
+ .toDS()
.join(actions, "diff")
.select($"action", $"id", $"left_value".as("before_value"), $"right_value".as("after_value"))
.as[DiffAsCustom]
@@ -1032,8 +1119,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
}
test("diff as U with extra column") {
- doTestRequirement(left.diffAs[DiffAsExtra](right, "id"),
- "Diff encoder's columns must be part of the diff result schema, these columns are unexpected: extra")
+ doTestRequirement(
+ left.diffAs[DiffAsExtra](right, "id"),
+ "Diff encoder's columns must be part of the diff result schema, these columns are unexpected: extra"
+ )
}
test("diff with change column") {
@@ -1041,15 +1130,19 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val actual = left7.diff(right7, options, "id").orderBy("id")
assert(actual.columns === Seq("diff", "changes", "id", "left_value", "right_value", "left_label", "right_label"))
- assert(actual.schema === StructType(Seq(
- StructField("diff", StringType, nullable = false),
- StructField("changes", ArrayType(StringType, containsNull = false), nullable = true),
- StructField("id", IntegerType, nullable = true),
- StructField("left_value", StringType, nullable = true),
- StructField("right_value", StringType, nullable = true),
- StructField("left_label", StringType, nullable = true),
- StructField("right_label", StringType, nullable = true)
- )))
+ assert(
+ actual.schema === StructType(
+ Seq(
+ StructField("diff", StringType, nullable = false),
+ StructField("changes", ArrayType(StringType, containsNull = false), nullable = true),
+ StructField("id", IntegerType, nullable = true),
+ StructField("left_value", StringType, nullable = true),
+ StructField("right_value", StringType, nullable = true),
+ StructField("left_label", StringType, nullable = true),
+ StructField("right_label", StringType, nullable = true)
+ )
+ )
+ )
assert(actual.collect() === expectedDiff7WithChanges)
}
@@ -1058,15 +1151,21 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val actual = left7.diff(right7, options)
assert(actual.columns === Seq("diff", "changes", "id", "value", "label"))
- assert(actual.schema === StructType(Seq(
- StructField("diff", StringType, nullable = false),
- StructField("changes", ArrayType(StringType, containsNull = false), nullable = true),
- StructField("id", IntegerType, nullable = true),
- StructField("value", StringType, nullable = true),
- StructField("label", StringType, nullable = true)
- )))
- assert(actual.select($"diff", $"changes").distinct().orderBy($"diff").collect() ===
- Seq(Row("D", null), Row("I", null), Row("N", Seq.empty[String])))
+ assert(
+ actual.schema === StructType(
+ Seq(
+ StructField("diff", StringType, nullable = false),
+ StructField("changes", ArrayType(StringType, containsNull = false), nullable = true),
+ StructField("id", IntegerType, nullable = true),
+ StructField("value", StringType, nullable = true),
+ StructField("label", StringType, nullable = true)
+ )
+ )
+ )
+ assert(
+ actual.select($"diff", $"changes").distinct().orderBy($"diff").collect() ===
+ Seq(Row("D", null), Row("I", null), Row("N", Seq.empty[String]))
+ )
}
test("diff with change column name in non-id columns") {
@@ -1079,8 +1178,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
test("diff with change column name in id columns") {
val options = DiffOptions.default.withChangeColumn("value")
- doTestRequirement(left.diff(right, options, "id", "value"),
- "The id columns must not contain the change column name 'value': id, value")
+ doTestRequirement(
+ left.diff(right, options, "id", "value"),
+ "The id columns must not contain the change column name 'value': id, value"
+ )
}
test("diff with column-by-column diff mode") {
@@ -1145,26 +1246,34 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
test("diff with left-side diff mode and diff column name in value columns") {
val options = DiffOptions.default.withDiffColumn("value").withDiffMode(DiffMode.LeftSide)
- doTestRequirement(left.diff(right, options, "id"),
- "The left non-id columns must not contain the diff column name 'value': value")
+ doTestRequirement(
+ left.diff(right, options, "id"),
+ "The left non-id columns must not contain the diff column name 'value': value"
+ )
}
test("diff with right-side diff mode and diff column name in value columns") {
val options = DiffOptions.default.withDiffColumn("value").withDiffMode(DiffMode.RightSide)
- doTestRequirement(right.diff(right, options, "id"),
- "The right non-id columns must not contain the diff column name 'value': value")
+ doTestRequirement(
+ right.diff(right, options, "id"),
+ "The right non-id columns must not contain the diff column name 'value': value"
+ )
}
test("diff with left-side diff mode and change column name in value columns") {
val options = DiffOptions.default.withChangeColumn("value").withDiffMode(DiffMode.LeftSide)
- doTestRequirement(left.diff(right, options, "id"),
- "The left non-id columns must not contain the change column name 'value': value")
+ doTestRequirement(
+ left.diff(right, options, "id"),
+ "The left non-id columns must not contain the change column name 'value': value"
+ )
}
test("diff with right-side diff mode and change column name in value columns") {
val options = DiffOptions.default.withChangeColumn("value").withDiffMode(DiffMode.RightSide)
- doTestRequirement(right.diff(right, options, "id"),
- "The right non-id columns must not contain the change column name 'value': value")
+ doTestRequirement(
+ right.diff(right, options, "id"),
+ "The right non-id columns must not contain the change column name 'value': value"
+ )
}
test("diff with dots in diff column") {
@@ -1283,10 +1392,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
schema.copy(fields =
schema.fields
.map(_.copy(nullable = true))
- .map(field => field.dataType match {
- case a: ArrayType => field.copy(dataType = a.copy(containsNull = false))
- case _ => field
- })
+ .map(field =>
+ field.dataType match {
+ case a: ArrayType => field.copy(dataType = a.copy(containsNull = false))
+ case _ => field
+ }
+ )
)
}
@@ -1297,24 +1408,63 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
}
test("diff with ignored columns") {
- assertIgnoredColumns(left8.diff(right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema)
- assertIgnoredColumns(Diff.of(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema)
- assertIgnoredColumns(Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema)
-
- assertIgnoredColumns[DiffAs8](left8.diffAs(right8, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, Encoders.product[DiffAs8].schema)
- assertIgnoredColumns[DiffAs8](Diff.ofAs(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, Encoders.product[DiffAs8].schema)
- assertIgnoredColumns[DiffAs8](Diff.default.diffAs(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, Encoders.product[DiffAs8].schema)
-
- val expected = expectedDiff8.map(row => (
- row.getString(0),
- Value8(row.getInt(1), Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3)), Option(row.getString(5))),
- Value8(row.getInt(1), Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(4)), Option(row.getString(6)))
- )).map { case (diff, left, right) => (
- diff,
- if (diff == "I") null else left,
- if (diff == "D") null else right
+ assertIgnoredColumns(
+ left8.diff(right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiff8,
+ Encoders.product[DiffAs8].schema
)
- }
+ assertIgnoredColumns(
+ Diff.of(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiff8,
+ Encoders.product[DiffAs8].schema
+ )
+ assertIgnoredColumns(
+ Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiff8,
+ Encoders.product[DiffAs8].schema
+ )
+
+ assertIgnoredColumns[DiffAs8](
+ left8.diffAs(right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiffAs8,
+ Encoders.product[DiffAs8].schema
+ )
+ assertIgnoredColumns[DiffAs8](
+ Diff.ofAs(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiffAs8,
+ Encoders.product[DiffAs8].schema
+ )
+ assertIgnoredColumns[DiffAs8](
+ Diff.default.diffAs(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiffAs8,
+ Encoders.product[DiffAs8].schema
+ )
+
+ val expected = expectedDiff8
+ .map(row =>
+ (
+ row.getString(0),
+ Value8(
+ row.getInt(1),
+ Option(row.get(2)).map(_.asInstanceOf[Int]),
+ Option(row.getString(3)),
+ Option(row.getString(5))
+ ),
+ Value8(
+ row.getInt(1),
+ Option(row.get(2)).map(_.asInstanceOf[Int]),
+ Option(row.getString(4)),
+ Option(row.getString(6))
+ )
+ )
+ )
+ .map { case (diff, left, right) =>
+ (
+ diff,
+ if (diff == "I") null else left,
+ if (diff == "D") null else right
+ )
+ }
assertDiffWith(left8.diffWith(right8, Seq("id", "seq"), Seq("meta")).collect(), expected)
assertDiffWith(Diff.ofWith(left8, right8, Seq("id", "seq"), Seq("meta")).collect(), expected)
@@ -1325,112 +1475,230 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val options = DiffOptions.default.withChangeColumn("changed")
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedDiff8WithChanges, Encoders.product[DiffAs8changes].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8WithChanges, Encoders.product[DiffAs8changes].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedDiff8WithChanges,
+ Encoders.product[DiffAs8changes].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiff8WithChanges,
+ Encoders.product[DiffAs8changes].schema
+ )
}
test("diff with ignored columns and column-by-column diff mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedDiff8,
+ Encoders.product[DiffAs8].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedDiff8,
+ Encoders.product[DiffAs8].schema
+ )
}
test("diff with ignored columns and side-by-side diff mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedSideBySideDiff8, Encoders.product[DiffAs8SideBySide].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedSideBySideDiff8, Encoders.product[DiffAs8SideBySide].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedSideBySideDiff8,
+ Encoders.product[DiffAs8SideBySide].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedSideBySideDiff8,
+ Encoders.product[DiffAs8SideBySide].schema
+ )
}
test("diff with ignored columns and left-side diff mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedLeftSideDiff8, Encoders.product[DiffAs8OneSide].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedLeftSideDiff8, Encoders.product[DiffAs8OneSide].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedLeftSideDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedLeftSideDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
}
test("diff with ignored columns and right-side diff mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.RightSide)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedRightSideDiff8, Encoders.product[DiffAs8OneSide].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedRightSideDiff8, Encoders.product[DiffAs8OneSide].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedRightSideDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedRightSideDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
}
test("diff with ignored columns, column-by-column diff and sparse mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn).withSparseMode(true)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedSparseDiff8, Encoders.product[DiffAs8].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedSparseDiff8, Encoders.product[DiffAs8].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedSparseDiff8,
+ Encoders.product[DiffAs8].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedSparseDiff8,
+ Encoders.product[DiffAs8].schema
+ )
}
test("diff with ignored columns, side-by-side diff and sparse mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide).withSparseMode(true)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedSideBySideSparseDiff8, Encoders.product[DiffAs8SideBySide].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedSideBySideSparseDiff8, Encoders.product[DiffAs8SideBySide].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedSideBySideSparseDiff8,
+ Encoders.product[DiffAs8SideBySide].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedSideBySideSparseDiff8,
+ Encoders.product[DiffAs8SideBySide].schema
+ )
}
test("diff with ignored columns, left-side diff and sparse mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide).withSparseMode(true)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedLeftSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedLeftSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedLeftSideSparseDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedLeftSideSparseDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
}
test("diff with ignored columns, right-side diff and sparse mode") {
val options = DiffOptions.default.withDiffMode(DiffMode.RightSide).withSparseMode(true)
val differ = new Differ(options)
- assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedRightSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema)
- assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedRightSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema)
+ assertIgnoredColumns(
+ left8.diff(right8, options, Seq("id", "seq"), Seq("meta")),
+ expectedRightSideSparseDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
+ assertIgnoredColumns(
+ differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")),
+ expectedRightSideSparseDiff8,
+ Encoders.product[DiffAs8OneSide].schema
+ )
}
test("diff similar with ignored columns") {
- val expectedSchema = StructType(Seq(
- StructField("diff", StringType),
- StructField("id", IntegerType),
- StructField("seq", IntegerType),
- StructField("left_value", StringType),
- StructField("right_value", StringType),
- StructField("left_meta", StringType),
- StructField("right_info", StringType),
- ))
+ val expectedSchema = StructType(
+ Seq(
+ StructField("diff", StringType),
+ StructField("id", IntegerType),
+ StructField("seq", IntegerType),
+ StructField("left_value", StringType),
+ StructField("right_value", StringType),
+ StructField("left_meta", StringType),
+ StructField("right_info", StringType),
+ )
+ )
assertIgnoredColumns(left8.diff(right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiff8and9, expectedSchema)
- assertIgnoredColumns(Diff.of(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiff8and9, expectedSchema)
- assertIgnoredColumns(Diff.default.diff(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiff8and9, expectedSchema)
-
- assertIgnoredColumns[DiffAs8and9](left8.diffAs(right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema)
- assertIgnoredColumns[DiffAs8and9](Diff.ofAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema)
- assertIgnoredColumns[DiffAs8and9](Diff.default.diffAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema)
-
- val expectedSchemaWith = StructType(Seq(
- StructField("_1", StringType),
- StructField("_2", StructType(Seq(
- StructField("id", IntegerType, nullable = true),
- StructField("seq", IntegerType, nullable = true),
- StructField("value", StringType, nullable = true),
- StructField("meta", StringType, nullable = true)
- ))),
- StructField("_3", StructType(Seq(
- StructField("id", IntegerType, nullable = true),
- StructField("seq", IntegerType, nullable = true),
- StructField("value", StringType, nullable = true),
- StructField("info", StringType, nullable = true)
- ))),
- ))
-
- assertDiffWithSchema(left8.diffWith(right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffWith8and9, expectedSchemaWith)
- assertDiffWithSchema(Diff.ofWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffWith8and9, expectedSchemaWith)
- assertDiffWithSchema(Diff.default.diffWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffWith8and9, expectedSchemaWith)
+ assertIgnoredColumns(
+ Diff.of(left8, right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiff8and9,
+ expectedSchema
+ )
+ assertIgnoredColumns(
+ Diff.default.diff(left8, right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiff8and9,
+ expectedSchema
+ )
+
+ assertIgnoredColumns[DiffAs8and9](
+ left8.diffAs(right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiffAs8and9,
+ expectedSchema
+ )
+ assertIgnoredColumns[DiffAs8and9](
+ Diff.ofAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiffAs8and9,
+ expectedSchema
+ )
+ assertIgnoredColumns[DiffAs8and9](
+ Diff.default.diffAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiffAs8and9,
+ expectedSchema
+ )
+
+ val expectedSchemaWith = StructType(
+ Seq(
+ StructField("_1", StringType),
+ StructField(
+ "_2",
+ StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = true),
+ StructField("seq", IntegerType, nullable = true),
+ StructField("value", StringType, nullable = true),
+ StructField("meta", StringType, nullable = true)
+ )
+ )
+ ),
+ StructField(
+ "_3",
+ StructType(
+ Seq(
+ StructField("id", IntegerType, nullable = true),
+ StructField("seq", IntegerType, nullable = true),
+ StructField("value", StringType, nullable = true),
+ StructField("info", StringType, nullable = true)
+ )
+ )
+ ),
+ )
+ )
+
+ assertDiffWithSchema(
+ left8.diffWith(right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiffWith8and9,
+ expectedSchemaWith
+ )
+ assertDiffWithSchema(
+ Diff.ofWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiffWith8and9,
+ expectedSchemaWith
+ )
+ assertDiffWithSchema(
+ Diff.default.diffWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiffWith8and9,
+ expectedSchemaWith
+ )
}
test("diff similar with ignored columns of different type") {
@@ -1443,23 +1711,45 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val right = right8.toDF("ID", "SEQ", "VALUE", "META")
def expectedSchema(id: String, seq: String): StructType =
- StructType(Seq(
- StructField("diff", StringType),
- StructField(id, IntegerType),
- StructField(seq, IntegerType),
- StructField("left_value", StringType),
- StructField("right_VALUE", StringType),
- StructField("left_meta", StringType),
- StructField("right_META", StringType),
- ))
+ StructType(
+ Seq(
+ StructField("diff", StringType),
+ StructField(id, IntegerType),
+ StructField(seq, IntegerType),
+ StructField("left_value", StringType),
+ StructField("right_VALUE", StringType),
+ StructField("left_meta", StringType),
+ StructField("right_META", StringType),
+ )
+ )
assertIgnoredColumns(left.diff(right, Seq("iD", "sEq"), Seq("MeTa")), expectedDiff8, expectedSchema("iD", "sEq"))
- assertIgnoredColumns(Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA")), expectedDiff8, expectedSchema("Id", "SeQ"))
- assertIgnoredColumns(Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META")), expectedDiff8, expectedSchema("ID", "SEQ"))
+ assertIgnoredColumns(
+ Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA")),
+ expectedDiff8,
+ expectedSchema("Id", "SeQ")
+ )
+ assertIgnoredColumns(
+ Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META")),
+ expectedDiff8,
+ expectedSchema("ID", "SEQ")
+ )
- assertIgnoredColumns[DiffAs8](left.diffAs(right, Seq("id", "seq"), Seq("MeTa")), expectedDiffAs8, expectedSchema("id", "seq"))
- assertIgnoredColumns[DiffAs8](Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA")), expectedDiffAs8, expectedSchema("id", "seq"))
- assertIgnoredColumns[DiffAs8](Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, expectedSchema("id", "seq"))
+ assertIgnoredColumns[DiffAs8](
+ left.diffAs(right, Seq("id", "seq"), Seq("MeTa")),
+ expectedDiffAs8,
+ expectedSchema("id", "seq")
+ )
+ assertIgnoredColumns[DiffAs8](
+ Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA")),
+ expectedDiffAs8,
+ expectedSchema("id", "seq")
+ )
+ assertIgnoredColumns[DiffAs8](
+ Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta")),
+ expectedDiffAs8,
+ expectedSchema("id", "seq")
+ )
// TODO: add diffWith
}
@@ -1470,26 +1760,44 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = left8.toDF("id", "seq", "value", "meta")
val right = right8.toDF("ID", "SEQ", "VALUE", "META")
- doTestRequirement(left.diff(right, Seq("Id", "SeQ"), Seq("MeTa")),
- "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)")
- doTestRequirement(Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa")),
- "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)")
- doTestRequirement(Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa")),
- "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)")
-
- doTestRequirement(left8.diff(right8, Seq("Id", "SeQ"), Seq("MeTa")),
- "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta")
- doTestRequirement(Diff.of(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")),
- "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta")
- doTestRequirement(Diff.default.diff(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")),
- "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta")
-
- doTestRequirement(left8.diff(right8, Seq("id", "seq"), Seq("MeTa")),
- "Some ignore columns do not exist: MeTa missing among id, meta, seq, value")
- doTestRequirement(Diff.of(left8, right8, Seq("id", "seq"), Seq("MeTa")),
- "Some ignore columns do not exist: MeTa missing among id, meta, seq, value")
- doTestRequirement(Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("MeTa")),
- "Some ignore columns do not exist: MeTa missing among id, meta, seq, value")
+ doTestRequirement(
+ left.diff(right, Seq("Id", "SeQ"), Seq("MeTa")),
+ "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)"
+ )
+ doTestRequirement(
+ Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa")),
+ "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)"
+ )
+ doTestRequirement(
+ Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa")),
+ "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)"
+ )
+
+ doTestRequirement(
+ left8.diff(right8, Seq("Id", "SeQ"), Seq("MeTa")),
+ "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta"
+ )
+ doTestRequirement(
+ Diff.of(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")),
+ "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta"
+ )
+ doTestRequirement(
+ Diff.default.diff(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")),
+ "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta"
+ )
+
+ doTestRequirement(
+ left8.diff(right8, Seq("id", "seq"), Seq("MeTa")),
+ "Some ignore columns do not exist: MeTa missing among id, meta, seq, value"
+ )
+ doTestRequirement(
+ Diff.of(left8, right8, Seq("id", "seq"), Seq("MeTa")),
+ "Some ignore columns do not exist: MeTa missing among id, meta, seq, value"
+ )
+ doTestRequirement(
+ Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("MeTa")),
+ "Some ignore columns do not exist: MeTa missing among id, meta, seq, value"
+ )
}
}
@@ -1499,45 +1807,97 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val right = right9.toDF("ID", "SEQ", "VALUE", "INFO").as[Value9up]
def expectedSchema(id: String, seq: String): StructType =
- StructType(Seq(
- StructField("diff", StringType),
- StructField(id, IntegerType),
- StructField(seq, IntegerType),
- StructField("left_value", StringType),
- StructField("right_VALUE", StringType),
- StructField("left_meta", StringType),
- StructField("right_INFO", StringType),
- ))
-
- assertIgnoredColumns(left.diff(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")), expectedDiff8and9, expectedSchema("iD", "sEq"))
- assertIgnoredColumns(Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")), expectedDiff8and9, expectedSchema("Id", "SeQ"))
- assertIgnoredColumns(Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")), expectedDiff8and9, expectedSchema("ID", "SEQ"))
+ StructType(
+ Seq(
+ StructField("diff", StringType),
+ StructField(id, IntegerType),
+ StructField(seq, IntegerType),
+ StructField("left_value", StringType),
+ StructField("right_VALUE", StringType),
+ StructField("left_meta", StringType),
+ StructField("right_INFO", StringType),
+ )
+ )
+
+ assertIgnoredColumns(
+ left.diff(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")),
+ expectedDiff8and9,
+ expectedSchema("iD", "sEq")
+ )
+ assertIgnoredColumns(
+ Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")),
+ expectedDiff8and9,
+ expectedSchema("Id", "SeQ")
+ )
+ assertIgnoredColumns(
+ Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")),
+ expectedDiff8and9,
+ expectedSchema("ID", "SEQ")
+ )
// TODO: remove generic type
- assertIgnoredColumns[DiffAs8and9](left.diffAs(right, Seq("id", "seq"), Seq("MeTa", "InFo")), expectedDiffAs8and9, expectedSchema("id", "seq"))
- assertIgnoredColumns[DiffAs8and9](Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA", "iNfO")), expectedDiffAs8and9, expectedSchema("id", "seq"))
- assertIgnoredColumns[DiffAs8and9](Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema("id", "seq"))
+ assertIgnoredColumns[DiffAs8and9](
+ left.diffAs(right, Seq("id", "seq"), Seq("MeTa", "InFo")),
+ expectedDiffAs8and9,
+ expectedSchema("id", "seq")
+ )
+ assertIgnoredColumns[DiffAs8and9](
+ Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA", "iNfO")),
+ expectedDiffAs8and9,
+ expectedSchema("id", "seq")
+ )
+ assertIgnoredColumns[DiffAs8and9](
+ Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta", "info")),
+ expectedDiffAs8and9,
+ expectedSchema("id", "seq")
+ )
def expectedSchemaWith(id: String, seq: String): StructType =
- StructType(Seq(
- StructField("_1", StringType, nullable = false),
- StructField("_2", StructType(Seq(
- StructField(id, IntegerType),
- StructField(seq, IntegerType),
- StructField("value", StringType),
- StructField("meta", StringType)
- )), nullable = true),
- StructField("_3", StructType(Seq(
- StructField(id, IntegerType),
- StructField(seq, IntegerType),
- StructField("VALUE", StringType),
- StructField("INFO", StringType)
- )), nullable = true),
- ))
-
- assertIgnoredColumns[(String, Value8, Value9up)](left.diffWith(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")), expectedDiffWith8and9up, expectedSchemaWith("iD", "sEq"))
- assertIgnoredColumns[(String, Value8, Value9up)](Diff.ofWith(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")), expectedDiffWith8and9up, expectedSchemaWith("Id", "SeQ"))
- assertIgnoredColumns[(String, Value8, Value9up)](Diff.default.diffWith(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")), expectedDiffWith8and9up, expectedSchemaWith("ID", "SEQ"))
+ StructType(
+ Seq(
+ StructField("_1", StringType, nullable = false),
+ StructField(
+ "_2",
+ StructType(
+ Seq(
+ StructField(id, IntegerType),
+ StructField(seq, IntegerType),
+ StructField("value", StringType),
+ StructField("meta", StringType)
+ )
+ ),
+ nullable = true
+ ),
+ StructField(
+ "_3",
+ StructType(
+ Seq(
+ StructField(id, IntegerType),
+ StructField(seq, IntegerType),
+ StructField("VALUE", StringType),
+ StructField("INFO", StringType)
+ )
+ ),
+ nullable = true
+ ),
+ )
+ )
+
+ assertIgnoredColumns[(String, Value8, Value9up)](
+ left.diffWith(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")),
+ expectedDiffWith8and9up,
+ expectedSchemaWith("iD", "sEq")
+ )
+ assertIgnoredColumns[(String, Value8, Value9up)](
+ Diff.ofWith(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")),
+ expectedDiffWith8and9up,
+ expectedSchemaWith("Id", "SeQ")
+ )
+ assertIgnoredColumns[(String, Value8, Value9up)](
+ Diff.default.diffWith(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")),
+ expectedDiffWith8and9up,
+ expectedSchemaWith("ID", "SEQ")
+ )
}
}
@@ -1546,26 +1906,44 @@ class DiffSuite extends AnyFunSuite with SparkTestSession {
val left = left8.toDF("id", "seq", "value", "meta").as[Value8]
val right = right9.toDF("ID", "SEQ", "VALUE", "INFO").as[Value9up]
- doTestRequirement(left.diff(right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
- "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)")
- doTestRequirement(Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
- "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)")
- doTestRequirement(Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
- "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)")
-
- doTestRequirement(left8.diff(right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
- "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)")
- doTestRequirement(Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
- "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)")
- doTestRequirement(Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
- "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)")
-
- doTestRequirement(left8.diff(right9, Seq("Id", "SeQ"), Seq("meta", "info")),
- "Some id columns do not exist: Id, SeQ missing among id, seq, value")
- doTestRequirement(Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")),
- "Some id columns do not exist: Id, SeQ missing among id, seq, value")
- doTestRequirement(Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")),
- "Some id columns do not exist: Id, SeQ missing among id, seq, value")
+ doTestRequirement(
+ left.diff(right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
+ "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)"
+ )
+ doTestRequirement(
+ Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
+ "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)"
+ )
+ doTestRequirement(
+ Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
+ "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)"
+ )
+
+ doTestRequirement(
+ left8.diff(right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
+ "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)"
+ )
+ doTestRequirement(
+ Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
+ "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)"
+ )
+ doTestRequirement(
+ Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")),
+ "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)"
+ )
+
+ doTestRequirement(
+ left8.diff(right9, Seq("Id", "SeQ"), Seq("meta", "info")),
+ "Some id columns do not exist: Id, SeQ missing among id, seq, value"
+ )
+ doTestRequirement(
+ Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")),
+ "Some id columns do not exist: Id, SeQ missing among id, seq, value"
+ )
+ doTestRequirement(
+ Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")),
+ "Some id columns do not exist: Id, SeQ missing among id, seq, value"
+ )
}
}
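
The DiffSuite hunks above only re-wrap the existing ignore-columns assertions; behaviour is unchanged. For readers who want the call shape in isolation, here is a minimal, self-contained sketch of the API those tests exercise. The Order case class, the sample rows and the local[*] session are illustrative assumptions, not the suite's Value8/Value9up fixtures; the diff/Diff.of calls themselves mirror the ones visible in the hunks (ignored columns take no part in the value comparison and, as shown above, may even differ in name between the two sides).

import org.apache.spark.sql.SparkSession
import uk.co.gresearch.spark.diff._ // Dataset.diff / Diff.of, as documented in the project README

case class Order(id: Int, seq: Int, value: String, meta: String)

object IgnoreColumnsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val left  = Seq(Order(1, 1, "a", "x"), Order(1, 2, "b", "y")).toDS()
    val right = Seq(Order(1, 1, "a", "z"), Order(1, 2, "c", "y")).toDS()

    // id columns ("id", "seq") identify rows; the ignored column ("meta") is still
    // carried through to the result but a change only in "meta" is not a change.
    left.diff(right, Seq("id", "seq"), Seq("meta")).show()
    // equivalently: Diff.of(left, right, Seq("id", "seq"), Seq("meta"))
  }
}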
diff --git a/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala b/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala
index 5c69813a..14664ad9 100644
--- a/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala
+++ b/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala
@@ -26,10 +26,12 @@ class Examples extends AnyFunSuite with SparkTestSession {
test("issue") {
import spark.implicits._
- val originalDF = Seq((1,"gaurav","jaipur",550,70000),(2,"sunil","noida",600,80000),(3,"rishi","ahmedabad",510,65000))
- .toDF("id","name","city","credit_score","credit_limit")
- val changedDF= Seq((1,"gaurav","jaipur",550,70000),(2,"sunil","noida",650,90000),(4,"Joshua","cochin",612,85000))
- .toDF("id","name","city","credit_score","credit_limit")
+ val originalDF =
+ Seq((1, "gaurav", "jaipur", 550, 70000), (2, "sunil", "noida", 600, 80000), (3, "rishi", "ahmedabad", 510, 65000))
+ .toDF("id", "name", "city", "credit_score", "credit_limit")
+ val changedDF =
+ Seq((1, "gaurav", "jaipur", 550, 70000), (2, "sunil", "noida", 650, 90000), (4, "Joshua", "cochin", 612, 85000))
+ .toDF("id", "name", "city", "credit_score", "credit_limit")
val options = DiffOptions.default.withChangeColumn("changes")
val diff = originalDF.diff(changedDF, options, "id")
diff.show(false)
@@ -56,7 +58,7 @@ class Examples extends AnyFunSuite with SparkTestSession {
{
Seq(DiffMode.ColumnByColumn, DiffMode.SideBySide, DiffMode.LeftSide, DiffMode.RightSide).foreach { mode =>
- Seq(false, true).foreach{ sparse =>
+ Seq(false, true).foreach { sparse =>
val options = DiffOptions.default.withDiffMode(mode)
left.diff(right, options, "id").orderBy("id").show(false)
}
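
The Examples changes are likewise formatting-only. As a quick reminder of the two option knobs the example toggles, here is a hedged sketch of the fluent DiffOptions calls on their own; the "changes" column name and the SideBySide choice are simply the values used in the example above, not defaults.

import uk.co.gresearch.spark.diff.{DiffMode, DiffOptions}

object DiffOptionsSketch {
  // withChangeColumn adds a column listing which non-id columns changed per row;
  // withDiffMode switches how left and right values are laid out in the result.
  val options: DiffOptions = DiffOptions.default
    .withChangeColumn("changes")
    .withDiffMode(DiffMode.SideBySide)
}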
diff --git a/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala b/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala
index b2fb22c8..4e04c36c 100644
--- a/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala
@@ -79,16 +79,22 @@ class GroupSuite extends AnyFunSpec {
def testUnconsumedKey(unconsumedKey: K, func: () => Iterator[(K, V)]): Unit = {
// we expect all tuples (k, Some(v)), except for k == unconsumedKey, where we expect (k, None)
// here we consume all groups (it.toList), which is tested elsewhere to work
- val expected = new GroupedIterator(func()).map {
- case (key, it) if key == unconsumedKey => it.toList; (key, Iterator(None))
- case (key, it) => (key, it.map(Some(_)))
- }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList
+ val expected = new GroupedIterator(func())
+ .map {
+ case (key, it) if key == unconsumedKey => it.toList; (key, Iterator(None))
+ case (key, it) => (key, it.map(Some(_)))
+ }
+ .flatMap { case (k, it) => it.map(v => (k, v)) }
+ .toList
// here we do not consume the group with key `unconsumedKey`
- val actual = new GroupedIterator(func()).map {
- case (key, _) if key == unconsumedKey => (key, Iterator(None))
- case (key, it) => (key, it.map(Some(_)))
- }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList
+ val actual = new GroupedIterator(func())
+ .map {
+ case (key, _) if key == unconsumedKey => (key, Iterator(None))
+ case (key, it) => (key, it.map(Some(_)))
+ }
+ .flatMap { case (k, it) => it.map(v => (k, v)) }
+ .toList
assert(actual === expected)
}
@@ -135,16 +141,22 @@ class GroupSuite extends AnyFunSpec {
// we expect all tuples (k, v), except for k == unconsumedKey,
// where we expect only the first tuple with k == partiallyConsumedKey
// here we consume all groups (it.toList), which is tested elsewhere to work
- val expected = new GroupedIterator(func()).map {
- case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.toList.head))
- case (key, it) => (key, it)
- }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList
+ val expected = new GroupedIterator(func())
+ .map {
+ case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.toList.head))
+ case (key, it) => (key, it)
+ }
+ .flatMap { case (k, it) => it.map(v => (k, v)) }
+ .toList
// here we only consume the first element of the group with key `unconsumedKey`
- val actual = new GroupedIterator(func()).map {
- case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.next()))
- case (key, it) => (key, it)
- }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList
+ val actual = new GroupedIterator(func())
+ .map {
+ case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.next()))
+ case (key, it) => (key, it)
+ }
+ .flatMap { case (k, it) => it.map(v => (k, v)) }
+ .toList
assert(actual === expected)
}
@@ -178,8 +190,8 @@ class GroupSuite extends AnyFunSpec {
// this consumes all group iterators
it("and fully consumed groups") {
val expected = func().toList
- val actual = new GroupedIterator(func()).flatMap {
- case (k, it) => it.map(v => (k, v))
+ val actual = new GroupedIterator(func()).flatMap { case (k, it) =>
+ it.map(v => (k, v))
}.toList
assert(actual === expected)
}
@@ -211,7 +223,7 @@ class GroupSuite extends AnyFunSpec {
}
describe("should iterate only over current key") {
- def test[K : Ordering, V](it: Seq[(K, V)], expectedValues: Seq[V]): Unit = {
+ def test[K: Ordering, V](it: Seq[(K, V)], expectedValues: Seq[V]): Unit = {
val git = new GroupIterator[K, V](it.iterator.buffered)
assert(git.toList === expectedValues)
}
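
The GroupSuite hunks re-indent the expected/actual pipelines without touching their logic. The contract they assert — a key-sorted iterator is exposed as one (key, values-iterator) pair per consecutive run of equal keys — can be illustrated with a small, eager stand-in. This sketch is not the project's GroupedIterator (which hands out lazy, partially consumable groups, exactly what the tests above probe); it only shows the shape of the output.

object GroupSketch {
  /** Group consecutive runs of equal keys from a key-sorted iterator (eager toy version). */
  def groupSorted[K, V](it: Iterator[(K, V)]): Iterator[(K, Iterator[V])] = {
    val in = it.buffered
    new Iterator[(K, Iterator[V])] {
      def hasNext: Boolean = in.hasNext
      def next(): (K, Iterator[V]) = {
        val key = in.head._1
        val values = scala.collection.mutable.ListBuffer.empty[V]
        while (in.hasNext && in.head._1 == key) values += in.next()._2 // drain this key's run
        (key, values.iterator)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    groupSorted(Iterator((1, "a"), (1, "b"), (2, "c")))
      .foreach { case (k, vs) => println(s"$k -> ${vs.mkString(",")}") }
    // prints: 1 -> a,b
    //         2 -> c
  }
}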
diff --git a/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala b/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala
index b8b6cbb4..67ce3a16 100644
--- a/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala
@@ -39,12 +39,14 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
val parallelisms = Seq(None, Some(1), Some(2), Some(8))
- def assertDf(actual: DataFrame,
- order: Seq[Column],
- expectedSchema: StructType,
- expectedRows: Seq[Row],
- expectedParallelism: Option[Int],
- postProcess: DataFrame => DataFrame = identity): Unit = {
+ def assertDf(
+ actual: DataFrame,
+ order: Seq[Column],
+ expectedSchema: StructType,
+ expectedRows: Seq[Row],
+ expectedParallelism: Option[Int],
+ postProcess: DataFrame => DataFrame = identity
+ ): Unit = {
assert(actual.schema === expectedSchema)
if (expectedParallelism.isDefined) {
@@ -56,7 +58,10 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
val replaced =
actual
.orderBy(order: _*)
- .withColumn("filename", regexp_replace(regexp_replace($"filename", ".*/test.parquet/", ""), ".*/nested.parquet", "nested.parquet"))
+ .withColumn(
+ "filename",
+ regexp_replace(regexp_replace($"filename", ".*/test.parquet/", ""), ".*/nested.parquet", "nested.parquet")
+ )
.when(actual.columns.contains("schema"))
.call(_.withColumn("schema", regexp_replace($"schema", "\n", "\\\\n")))
.call(postProcess)
@@ -81,20 +86,22 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
.either(_.parquetMetadata(parallelism.get, testFile))
.or(_.parquetMetadata(testFile)),
Seq($"filename"),
- StructType(Seq(
- StructField("filename", StringType, nullable = true),
- StructField("blocks", IntegerType, nullable = false),
- StructField("compressedBytes", LongType, nullable = false),
- StructField("uncompressedBytes", LongType, nullable = false),
- StructField("rows", LongType, nullable = false),
- StructField("columns", IntegerType, nullable = false),
- StructField("values", LongType, nullable = false),
- StructField("nulls", LongType, nullable = true),
- StructField("createdBy", StringType, nullable = true),
- StructField("schema", StringType, nullable = true),
- StructField("encryption", StringType, nullable = true),
- StructField("keyValues", MapType(StringType, StringType, valueContainsNull = true), nullable = true),
- )),
+ StructType(
+ Seq(
+ StructField("filename", StringType, nullable = true),
+ StructField("blocks", IntegerType, nullable = false),
+ StructField("compressedBytes", LongType, nullable = false),
+ StructField("uncompressedBytes", LongType, nullable = false),
+ StructField("rows", LongType, nullable = false),
+ StructField("columns", IntegerType, nullable = false),
+ StructField("values", LongType, nullable = false),
+ StructField("nulls", LongType, nullable = true),
+ StructField("createdBy", StringType, nullable = true),
+ StructField("schema", StringType, nullable = true),
+ StructField("encryption", StringType, nullable = true),
+ StructField("keyValues", MapType(StringType, StringType, valueContainsNull = true), nullable = true),
+ )
+ ),
Seq(
Row("file1.parquet", 1, 1268, 1652, 100, 2, 200, 0, createdBy, schema, UNENCRYPTED, keyValues),
Row("file2.parquet", 2, 2539, 3302, 200, 2, 400, 0, createdBy, schema, UNENCRYPTED, keyValues),
@@ -116,21 +123,23 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
.either(_.parquetSchema(parallelism.get, nestedFile))
.or(_.parquetSchema(nestedFile)),
Seq($"filename", $"columnPath"),
- StructType(Seq(
- StructField("filename", StringType, nullable = true),
- StructField("columnName", StringType, nullable = true),
- StructField("columnPath", ArrayType(StringType, containsNull = true), nullable = true),
- StructField("repetition", StringType, nullable = true),
- StructField("type", StringType, nullable = true),
- StructField("length", IntegerType, nullable = true),
- StructField("originalType", StringType, nullable = true),
- StructField("logicalType", StringType, nullable = true),
- StructField("isPrimitive", BooleanType, nullable = false),
- StructField("primitiveType", StringType, nullable = true),
- StructField("primitiveOrder", StringType, nullable = true),
- StructField("maxDefinitionLevel", IntegerType, nullable = false),
- StructField("maxRepetitionLevel", IntegerType, nullable = false),
- )),
+ StructType(
+ Seq(
+ StructField("filename", StringType, nullable = true),
+ StructField("columnName", StringType, nullable = true),
+ StructField("columnPath", ArrayType(StringType, containsNull = true), nullable = true),
+ StructField("repetition", StringType, nullable = true),
+ StructField("type", StringType, nullable = true),
+ StructField("length", IntegerType, nullable = true),
+ StructField("originalType", StringType, nullable = true),
+ StructField("logicalType", StringType, nullable = true),
+ StructField("isPrimitive", BooleanType, nullable = false),
+ StructField("primitiveType", StringType, nullable = true),
+ StructField("primitiveOrder", StringType, nullable = true),
+ StructField("maxDefinitionLevel", IntegerType, nullable = false),
+ StructField("maxRepetitionLevel", IntegerType, nullable = false),
+ )
+ ),
// format: off
Seq(
Row("nested.parquet", "a", Seq("a"), "REQUIRED", "INT64", 0, null, null, true, "INT64", "TYPE_DEFINED_ORDER", 0, 0),
@@ -153,17 +162,19 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
.either(_.parquetBlocks(parallelism.get, testFile))
.or(_.parquetBlocks(testFile)),
Seq($"filename", $"block"),
- StructType(Seq(
- StructField("filename", StringType, nullable = true),
- StructField("block", IntegerType, nullable = false),
- StructField("blockStart", LongType, nullable = false),
- StructField("compressedBytes", LongType, nullable = false),
- StructField("uncompressedBytes", LongType, nullable = false),
- StructField("rows", LongType, nullable = false),
- StructField("columns", IntegerType, nullable = false),
- StructField("values", LongType, nullable = false),
- StructField("nulls", LongType, nullable = true),
- )),
+ StructType(
+ Seq(
+ StructField("filename", StringType, nullable = true),
+ StructField("block", IntegerType, nullable = false),
+ StructField("blockStart", LongType, nullable = false),
+ StructField("compressedBytes", LongType, nullable = false),
+ StructField("uncompressedBytes", LongType, nullable = false),
+ StructField("rows", LongType, nullable = false),
+ StructField("columns", IntegerType, nullable = false),
+ StructField("values", LongType, nullable = false),
+ StructField("nulls", LongType, nullable = true),
+ )
+ ),
Seq(
Row("file1.parquet", 1, 4, 1268, 1652, 100, 2, 200, 0),
Row("file2.parquet", 1, 4, 1269, 1651, 100, 2, 200, 0),
@@ -182,21 +193,23 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
.either(_.parquetBlockColumns(parallelism.get, testFile))
.or(_.parquetBlockColumns(testFile)),
Seq($"filename", $"block", $"column"),
- StructType(Seq(
- StructField("filename", StringType, nullable = true),
- StructField("block", IntegerType, nullable = false),
- StructField("column", ArrayType(StringType), nullable = true),
- StructField("codec", StringType, nullable = true),
- StructField("type", StringType, nullable = true),
- StructField("encodings", ArrayType(StringType), nullable = true),
- StructField("minValue", StringType, nullable = true),
- StructField("maxValue", StringType, nullable = true),
- StructField("columnStart", LongType, nullable = false),
- StructField("compressedBytes", LongType, nullable = false),
- StructField("uncompressedBytes", LongType, nullable = false),
- StructField("values", LongType, nullable = false),
- StructField("nulls", LongType, nullable = true),
- )),
+ StructType(
+ Seq(
+ StructField("filename", StringType, nullable = true),
+ StructField("block", IntegerType, nullable = false),
+ StructField("column", ArrayType(StringType), nullable = true),
+ StructField("codec", StringType, nullable = true),
+ StructField("type", StringType, nullable = true),
+ StructField("encodings", ArrayType(StringType), nullable = true),
+ StructField("minValue", StringType, nullable = true),
+ StructField("maxValue", StringType, nullable = true),
+ StructField("columnStart", LongType, nullable = false),
+ StructField("compressedBytes", LongType, nullable = false),
+ StructField("uncompressedBytes", LongType, nullable = false),
+ StructField("values", LongType, nullable = false),
+ StructField("nulls", LongType, nullable = true),
+ )
+ ),
// format: off
Seq(
Row("file1.parquet", 1, "[id]", "SNAPPY", "required int64 id", "[BIT_PACKED, PLAIN]", "0", "99", 4, 437, 826, 100, 0),
@@ -208,9 +221,10 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
),
// format: on
parallelism,
- (df: DataFrame) => df
- .withColumn("column", $"column".cast(StringType))
- .withColumn("encodings", $"encodings".cast(StringType))
+ (df: DataFrame) =>
+ df
+ .withColumn("column", $"column".cast(StringType))
+ .withColumn("encodings", $"encodings".cast(StringType))
)
}
}
@@ -230,7 +244,14 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
.join(rows, Seq("partition"), "left")
.select($"partition", $"start", $"end", $"length", $"rows", $"actual_rows", $"filename")
- if (partitions.where($"rows" =!= $"actual_rows" || ($"rows" =!= 0 || $"actual_rows" =!= 0) && $"length" =!= partitionSize).head(1).nonEmpty) {
+ if (
+ partitions
+ .where(
+ $"rows" =!= $"actual_rows" || ($"rows" =!= 0 || $"actual_rows" =!= 0) && $"length" =!= partitionSize
+ )
+ .head(1)
+ .nonEmpty
+ ) {
partitions
.orderBy($"start")
.where($"rows" =!= 0 || $"actual_rows" =!= 0)
@@ -276,8 +297,11 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
),
).foreach { case (partitionSize, expectedRows) =>
parallelisms.foreach { parallelism =>
- test(s"read parquet partitions (${partitionSize.getOrElse("default")} bytes) (parallelism=${parallelism.map(_.toString).getOrElse("None")})") {
- withSQLConf(partitionSize.map(size => Seq("spark.sql.files.maxPartitionBytes" -> size.toString)).getOrElse(Seq.empty): _*) {
+ test(s"read parquet partitions (${partitionSize
+ .getOrElse("default")} bytes) (parallelism=${parallelism.map(_.toString).getOrElse("None")})") {
+ withSQLConf(
+ partitionSize.map(size => Seq("spark.sql.files.maxPartitionBytes" -> size.toString)).getOrElse(Seq.empty): _*
+ ) {
val expected = expectedRows.map {
case row if SparkMajorVersion > 3 || SparkMinorVersion >= 3 => row
case row => Row(unapplySeq(row).get.updated(11, null): _*)
@@ -296,21 +320,23 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion {
assert(Seq(0, 0) === partitions)
}
- val schema = StructType(Seq(
- StructField("partition", IntegerType, nullable = false),
- StructField("start", LongType, nullable = false),
- StructField("end", LongType, nullable = false),
- StructField("length", LongType, nullable = false),
- StructField("blocks", IntegerType, nullable = false),
- StructField("compressedBytes", LongType, nullable = false),
- StructField("uncompressedBytes", LongType, nullable = false),
- StructField("rows", LongType, nullable = false),
- StructField("columns", IntegerType, nullable = false),
- StructField("values", LongType, nullable = false),
- StructField("nulls", LongType, nullable = true),
- StructField("filename", StringType, nullable = true),
- StructField("fileLength", LongType, nullable = true),
- ))
+ val schema = StructType(
+ Seq(
+ StructField("partition", IntegerType, nullable = false),
+ StructField("start", LongType, nullable = false),
+ StructField("end", LongType, nullable = false),
+ StructField("length", LongType, nullable = false),
+ StructField("blocks", IntegerType, nullable = false),
+ StructField("compressedBytes", LongType, nullable = false),
+ StructField("uncompressedBytes", LongType, nullable = false),
+ StructField("rows", LongType, nullable = false),
+ StructField("columns", IntegerType, nullable = false),
+ StructField("values", LongType, nullable = false),
+ StructField("nulls", LongType, nullable = true),
+ StructField("filename", StringType, nullable = true),
+ StructField("fileLength", LongType, nullable = true),
+ )
+ )
assertDf(actual, Seq($"filename", $"start"), schema, expected, parallelism, df => df.drop("partition"))
actual.unpersist()