diff --git a/.github/actions/build/action.yml b/.github/actions/build/action.yml index ab596939..8cf9d7d8 100644 --- a/.github/actions/build/action.yml +++ b/.github/actions/build/action.yml @@ -45,9 +45,9 @@ runs: env: JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED run: | - mvn --batch-mode --update-snapshots clean compile test-compile - mvn --batch-mode package -DskipTests -Dmaven.test.skip=true - mvn --batch-mode install -DskipTests -Dmaven.test.skip=true -Dgpg.skip + mvn --batch-mode --update-snapshots -Dspotless.check.skip clean compile test-compile + mvn --batch-mode package -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true + mvn --batch-mode install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip shell: bash - name: Upload Binaries diff --git a/.github/actions/test-jvm/action.yml b/.github/actions/test-jvm/action.yml index 76e55dde..51f107f3 100644 --- a/.github/actions/test-jvm/action.yml +++ b/.github/actions/test-jvm/action.yml @@ -73,7 +73,7 @@ runs: - name: Scala and Java Tests env: JDK_JAVA_OPTIONS: --add-exports java.base/sun.nio.ch=ALL-UNNAMED --add-exports java.base/sun.util.calendar=ALL-UNNAMED - run: mvn --batch-mode test + run: mvn --batch-mode -Dspotless.check.skip test shell: bash - name: Diff App test diff --git a/.github/actions/test-python/action.yml b/.github/actions/test-python/action.yml index 7ca019a3..1f18472e 100644 --- a/.github/actions/test-python/action.yml +++ b/.github/actions/test-python/action.yml @@ -100,7 +100,7 @@ runs: shell: bash - name: Install Spark Extension - run: mvn --batch-mode install -DskipTests -Dmaven.test.skip=true -Dgpg.skip + run: mvn --batch-mode install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip shell: bash - name: Python Integration Tests diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 38a1cd56..30a48ca0 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -2,9 +2,45 @@ name: Check on: workflow_call: + jobs: + lint: + name: Scala lint + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Cache Maven packages + uses: actions/cache@v3 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-mvn-lint-${{ hashFiles('pom.xml') }} + + - name: Setup JDK ${{ inputs.java-compat-version }} + uses: actions/setup-java@v3 + with: + java-version: '11' + distribution: 'zulu' + + - name: Check + id: check + run: | + mvn --batch-mode spotless:check + shell: bash + + - name: Changes + if: failure() && steps.check.outcome == 'failure' + run: | + mvn --batch-mode spotless:apply + git diff + shell: bash + config: - name: Configure check + name: Configure compat runs-on: ubuntu-latest outputs: major-version: ${{ steps.versions.outputs.major-version }} @@ -32,8 +68,8 @@ jobs: echo "release-major-version=${release_version/.*/}" >> "$GITHUB_OUTPUT" shell: bash - check: - name: Check (Spark ${{ matrix.spark-compat-version }} Scala ${{ matrix.scala-compat-version }}) + compat: + name: Compat (Spark ${{ matrix.spark-compat-version }} Scala ${{ matrix.scala-compat-version }}) needs: config runs-on: ubuntu-latest if: needs.config.outputs.major-version == needs.config.outputs.release-major-version diff --git a/pom.xml b/pom.xml index 7fbb0aff..3066f2e7 100644 --- a/pom.xml +++ b/pom.xml @@ -235,7 +235,7 @@ <groupId>com.diffplug.spotless</groupId> <artifactId>spotless-maven-plugin</artifactId> - <version>2.41.0</version> + <version>2.30.0</version> @@ -244,6 +244,16 @@ + + <executions> + <execution> + <id>spotless-check</id> + <phase>compile</phase> + <goals>
+ check + + + diff --git a/src/main/scala/uk/co/gresearch/package.scala b/src/main/scala/uk/co/gresearch/package.scala index 188db09a..5053f2a5 100644 --- a/src/main/scala/uk/co/gresearch/package.scala +++ b/src/main/scala/uk/co/gresearch/package.scala @@ -20,28 +20,28 @@ package object gresearch { trait ConditionalCall[T] { def call(f: T => T): T - def either[R](f: T => R): ConditionalCallOr[T,R] + def either[R](f: T => R): ConditionalCallOr[T, R] } - trait ConditionalCallOr[T,R] { + trait ConditionalCallOr[T, R] { def or(f: T => R): R } case class TrueCall[T](t: T) extends ConditionalCall[T] { override def call(f: T => T): T = f(t) - override def either[R](f: T => R): ConditionalCallOr[T,R] = TrueCallOr[T,R](f(t)) + override def either[R](f: T => R): ConditionalCallOr[T, R] = TrueCallOr[T, R](f(t)) } case class FalseCall[T](t: T) extends ConditionalCall[T] { override def call(f: T => T): T = t - override def either[R](f: T => R): ConditionalCallOr[T,R] = FalseCallOr[T,R](t) + override def either[R](f: T => R): ConditionalCallOr[T, R] = FalseCallOr[T, R](t) } - case class TrueCallOr[T,R](r: R) extends ConditionalCallOr[T,R] { + case class TrueCallOr[T, R](r: R) extends ConditionalCallOr[T, R] { override def or(f: T => R): R = r } - case class FalseCallOr[T,R](t: T) extends ConditionalCallOr[T,R] { + case class FalseCallOr[T, R](t: T) extends ConditionalCallOr[T, R] { override def or(f: T => R): R = f(t) } @@ -71,16 +71,17 @@ package object gresearch { * * which either needs many temporary variables or duplicate code. * - * @param condition condition - * @return the function result + * @param condition + * condition + * @return + * the function result */ def on(condition: Boolean): ConditionalCall[T] = { if (condition) TrueCall[T](t) else FalseCall[T](t) } /** - * Allows to call a function on the decorated instance conditionally. - * This is an alias for the `on` function. + * Allows to call a function on the decorated instance conditionally. This is an alias for the `on` function. * * This allows fluent code like * @@ -103,8 +104,10 @@ package object gresearch { * * which either needs many temporary variables or duplicate code. * - * @param condition condition - * @return the function result + * @param condition + * condition + * @return + * the function result */ def when(condition: Boolean): ConditionalCall[T] = on(condition) @@ -131,8 +134,10 @@ package object gresearch { * * where the effective sequence of operations is not clear. 
* - * @param f function - * @return the function result + * @param f + * function + * @return + * the function result */ def call[R](f: T => R): R = f(t) } diff --git a/src/main/scala/uk/co/gresearch/spark/Backticks.scala b/src/main/scala/uk/co/gresearch/spark/Backticks.scala index efaaf19b..95cdd41e 100644 --- a/src/main/scala/uk/co/gresearch/spark/Backticks.scala +++ b/src/main/scala/uk/co/gresearch/spark/Backticks.scala @@ -34,12 +34,15 @@ object Backticks { * col(Backticks.column_name("a.column", "a.field")) // produces "`a.column`.`a.field`" * }}} * - * @param string a string - * @param strings more strings + * @param string + * a string + * @param strings + * more strings * @return */ @scala.annotation.varargs def column_name(string: String, strings: String*): String = (string +: strings) - .map(s => if (s.contains(".") && !s.startsWith("`") && !s.endsWith("`")) s"`$s`" else s).mkString(".") + .map(s => if (s.contains(".") && !s.startsWith("`") && !s.endsWith("`")) s"`$s`" else s) + .mkString(".") } diff --git a/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala b/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala index 60241e0d..9b831833 100644 --- a/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala +++ b/src/main/scala/uk/co/gresearch/spark/BuildVersion.scala @@ -37,7 +37,7 @@ trait BuildVersion { } lazy val VersionString: String = props.getProperty("project.version") - + lazy val BuildSparkMajorVersion: Int = props.getProperty("spark.major.version").toInt lazy val BuildSparkMinorVersion: Int = props.getProperty("spark.minor.version").toInt lazy val BuildSparkPatchVersion: Int = props.getProperty("spark.patch.version").split("-").head.toInt diff --git a/src/main/scala/uk/co/gresearch/spark/Histogram.scala b/src/main/scala/uk/co/gresearch/spark/Histogram.scala index c4ef8c49..c52cc679 100644 --- a/src/main/scala/uk/co/gresearch/spark/Histogram.scala +++ b/src/main/scala/uk/co/gresearch/spark/Histogram.scala @@ -23,20 +23,25 @@ import uk.co.gresearch.ExtendedAny import scala.collection.JavaConverters object Histogram { + /** - * Compute the histogram of a column when aggregated by aggregate columns. - * Thresholds are expected to be provided in ascending order. - * The result dataframe contains the aggregate and histogram columns only. - * For each threshold value in thresholds, there will be a column named s"≤threshold". - * There will also be a final column called s">last_threshold", that counts the remaining - * values that exceed the last threshold. + * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in + * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value + * in thresholds, there will be a column named s"≤threshold". There will also be a final column called + * s">last_threshold", that counts the remaining values that exceed the last threshold. * - * @param df dataset to compute histogram from - * @param thresholds sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn - * @param valueColumn histogram is computed for values of this column - * @param aggregateColumns histogram is computed against these columns - * @tparam T type of histogram thresholds - * @return dataframe with aggregate and histogram columns + * @param df + * dataset to compute histogram from + * @param thresholds + * sequence of thresholds in ascending order, must implement <= and > operators w.r.t. 
valueColumn + * @param valueColumn + * histogram is computed for values of this column + * @param aggregateColumns + * histogram is computed against these columns + * @tparam T + * type of histogram thresholds + * @return + * dataframe with aggregate and histogram columns */ def of[D, T](df: Dataset[D], thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = { if (thresholds.isEmpty) @@ -49,10 +54,9 @@ object Histogram { df.toDF() .withColumn(s"≤${thresholds.head}", when(valueColumn <= thresholds.head, 1).otherwise(0)) - .call( - bins.foldLeft(_) { case (df, bin) => - df.withColumn(s"≤${bin.last}", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0)) - }) + .call(bins.foldLeft(_) { case (df, bin) => + df.withColumn(s"≤${bin.last}", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0)) + }) .withColumn(s">${thresholds.last}", when(valueColumn > thresholds.last, 1).otherwise(0)) .groupBy(aggregateColumns: _*) .agg( @@ -63,22 +67,31 @@ object Histogram { } /** - * Compute the histogram of a column when aggregated by aggregate columns. - * Thresholds are expected to be provided in ascending order. - * The result dataframe contains the aggregate and histogram columns only. - * For each threshold value in thresholds, there will be a column named s"≤threshold". - * There will also be a final column called s">last_threshold", that counts the remaining - * values that exceed the last threshold. + * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in + * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value + * in thresholds, there will be a column named s"≤threshold". There will also be a final column called + * s">last_threshold", that counts the remaining values that exceed the last threshold. * - * @param df dataset to compute histogram from - * @param thresholds sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn - * @param valueColumn histogram is computed for values of this column - * @param aggregateColumns histogram is computed against these columns - * @tparam T type of histogram thresholds - * @return dataframe with aggregate and histogram columns + * @param df + * dataset to compute histogram from + * @param thresholds + * sequence of thresholds in ascending order, must implement <= and > operators w.r.t. 
valueColumn + * @param valueColumn + * histogram is computed for values of this column + * @param aggregateColumns + * histogram is computed against these columns + * @tparam T + * type of histogram thresholds + * @return + * dataframe with aggregate and histogram columns */ @scala.annotation.varargs - def of[D, T](df: Dataset[D], thresholds: java.util.List[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = + def of[D, T]( + df: Dataset[D], + thresholds: java.util.List[T], + valueColumn: Column, + aggregateColumns: Column* + ): DataFrame = of(df, JavaConverters.iterableAsScalaIterable(thresholds).toSeq, valueColumn, aggregateColumns: _*) } diff --git a/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala b/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala index e5015e65..5bc7d48c 100644 --- a/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala +++ b/src/main/scala/uk/co/gresearch/spark/RowNumbers.scala @@ -21,10 +21,12 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, functions} import org.apache.spark.sql.functions.{coalesce, col, lit, max, monotonically_increasing_id, spark_partition_id, sum} import org.apache.spark.storage.StorageLevel -case class RowNumbersFunc(rowNumberColumnName: String = "row_number", - storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, - unpersistHandle: UnpersistHandle = UnpersistHandle.Noop, - orderColumns: Seq[Column] = Seq.empty) { +case class RowNumbersFunc( + rowNumberColumnName: String = "row_number", + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + unpersistHandle: UnpersistHandle = UnpersistHandle.Noop, + orderColumns: Seq[Column] = Seq.empty +) { def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc = this.copy(rowNumberColumnName = rowNumberColumnName) @@ -39,7 +41,11 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number", this.copy(orderColumns = orderColumns) def of[D](df: Dataset[D]): DataFrame = { - if (storageLevel.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)) { + if ( + storageLevel.equals( + StorageLevel.NONE + ) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5) + ) { throw new IllegalArgumentException(s"Storage level $storageLevel not supported with Spark 3.5.0 and above.") } @@ -53,7 +59,9 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number", val partitionOffsetColumnName = prefix + "partition_offset" // if no order is given, we preserve existing order - val dfOrdered = if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id()) else df.orderBy(orderColumns: _*) + val dfOrdered = + if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id()) + else df.orderBy(orderColumns: _*) val order = if (orderColumns.isEmpty) Seq(col(monoIdColumnName)) else orderColumns // add partition ids and local row numbers @@ -66,17 +74,22 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number", .withColumn(localRowNumberColumnName, functions.row_number().over(localRowNumberWindow)) // compute row offset for the partitions - val cumRowNumbersWindow = Window.orderBy(partitionIdColumnName) + val cumRowNumbersWindow = Window + .orderBy(partitionIdColumnName) .rowsBetween(Window.unboundedPreceding, Window.currentRow) val partitionOffsets = dfWithLocalRowNumbers .groupBy(partitionIdColumnName) .agg(max(localRowNumberColumnName).alias(maxLocalRowNumberColumnName)) .withColumn(cumRowNumbersColumnName, 
sum(maxLocalRowNumberColumnName).over(cumRowNumbersWindow)) - .select(col(partitionIdColumnName) + 1 as partitionIdColumnName, col(cumRowNumbersColumnName).as(partitionOffsetColumnName)) + .select( + col(partitionIdColumnName) + 1 as partitionIdColumnName, + col(cumRowNumbersColumnName).as(partitionOffsetColumnName) + ) // compute global row number by adding local row number with partition offset val partitionOffsetColumn = coalesce(col(partitionOffsetColumnName), lit(0)) - dfWithLocalRowNumbers.join(partitionOffsets, Seq(partitionIdColumnName), "left") + dfWithLocalRowNumbers + .join(partitionOffsets, Seq(partitionIdColumnName), "left") .withColumn(rowNumberColumnName, col(localRowNumberColumnName) + partitionOffsetColumn) .drop(monoIdColumnName, partitionIdColumnName, localRowNumberColumnName, partitionOffsetColumnName) } diff --git a/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala b/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala index 863b4b6b..36f597d3 100644 --- a/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala +++ b/src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala @@ -51,7 +51,7 @@ case class SilentUnpersistHandle() extends UnpersistHandle { } } -case class NoopUnpersistHandle() extends UnpersistHandle{ +case class NoopUnpersistHandle() extends UnpersistHandle { override def setDataFrame(dataframe: DataFrame): DataFrame = dataframe override def apply(): Unit = {} override def apply(blocking: Boolean): Unit = {} diff --git a/src/main/scala/uk/co/gresearch/spark/diff/App.scala b/src/main/scala/uk/co/gresearch/spark/diff/App.scala index 0df86d24..22a6be79 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/App.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/App.scala @@ -23,31 +23,28 @@ import uk.co.gresearch._ object App { // define available options - case class Options(master: Option[String] = None, - appName: Option[String] = None, - hive: Boolean = false, - - leftPath: Option[String] = None, - rightPath: Option[String] = None, - outputPath: Option[String] = None, - - leftFormat: Option[String] = None, - rightFormat: Option[String] = None, - outputFormat: Option[String] = None, - - leftSchema: Option[String] = None, - rightSchema: Option[String] = None, - - leftOptions: Map[String, String] = Map.empty, - rightOptions: Map[String, String] = Map.empty, - outputOptions: Map[String, String] = Map.empty, - - ids: Seq[String] = Seq.empty, - ignore: Seq[String] = Seq.empty, - saveMode: SaveMode = SaveMode.ErrorIfExists, - filter: Set[String] = Set.empty, - statistics: Boolean = false, - diffOptions: DiffOptions = DiffOptions.default) + case class Options( + master: Option[String] = None, + appName: Option[String] = None, + hive: Boolean = false, + leftPath: Option[String] = None, + rightPath: Option[String] = None, + outputPath: Option[String] = None, + leftFormat: Option[String] = None, + rightFormat: Option[String] = None, + outputFormat: Option[String] = None, + leftSchema: Option[String] = None, + rightSchema: Option[String] = None, + leftOptions: Map[String, String] = Map.empty, + rightOptions: Map[String, String] = Map.empty, + outputOptions: Map[String, String] = Map.empty, + ids: Seq[String] = Seq.empty, + ignore: Seq[String] = Seq.empty, + saveMode: SaveMode = SaveMode.ErrorIfExists, + filter: Set[String] = Set.empty, + statistics: Boolean = false, + diffOptions: DiffOptions = DiffOptions.default + ) // read options from args val programName = s"spark-extension_${spark.BuildScalaCompatVersionString}-${spark.VersionString}.jar" 
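For reference, a minimal sketch of the `when`/`call`/`either`/`or` combinators from `uk.co.gresearch` that the reformatted `App.read` and `App.write` below now chain one call per line; the `readInput` helper and its parameters are illustrative only, not part of this patch:

    import org.apache.spark.sql.{DataFrame, SparkSession}
    import uk.co.gresearch._

    // illustrative helper mirroring the App.read pattern below
    def readInput(spark: SparkSession, path: String, format: Option[String]): DataFrame =
      spark.read
        // call(...) applies the function only when the condition holds, otherwise returns the reader unchanged
        .when(format.isDefined)
        .call(_.format(format.get))
        // either(...).or(...) selects one of the two branches based on the same condition
        .when(format.isDefined)
        .either(_.load(path))
        .or(_.table(path))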
@@ -105,11 +102,13 @@ object App { note("Input and output") opt[String]('f', "format") .valueName("") - .action((x, c) => c.copy( - leftFormat = c.leftFormat.orElse(Some(x)), - rightFormat = c.rightFormat.orElse(Some(x)), - outputFormat = c.outputFormat.orElse(Some(x)) - )) + .action((x, c) => + c.copy( + leftFormat = c.leftFormat.orElse(Some(x)), + rightFormat = c.rightFormat.orElse(Some(x)), + outputFormat = c.outputFormat.orElse(Some(x)) + ) + ) .text("input and output file format (csv, json, parquet, ...)") opt[String]("left-format") .valueName("") @@ -127,10 +126,12 @@ object App { note("") opt[String]('s', "schema") .valueName("") - .action((x, c) => c.copy( - leftSchema = c.leftSchema.orElse(Some(x)), - rightSchema = c.rightSchema.orElse(Some(x)) - )) + .action((x, c) => + c.copy( + leftSchema = c.leftSchema.orElse(Some(x)), + rightSchema = c.rightSchema.orElse(Some(x)) + ) + ) .text("input schema") opt[String]("left-schema") .valueName("") @@ -182,7 +183,9 @@ object App { .optional() .valueName("") .action((x, c) => c.copy(filter = c.filter + x)) - .text(s"Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)") + .text( + s"Filters for rows with these diff actions, with default diffing options use 'N', 'I', 'D', or 'C' (see 'Diffing options' section)" + ) opt[Unit]("statistics") .optional() .action((_, c) => c.copy(statistics = true)) @@ -245,45 +248,83 @@ object App { help("help").text("prints this usage text") } - def read(spark: SparkSession, format: Option[String], path: String, schema: Option[String], options: Map[String, String]): DataFrame = + def read( + spark: SparkSession, + format: Option[String], + path: String, + schema: Option[String], + options: Map[String, String] + ): DataFrame = spark.read - .when(format.isDefined).call(_.format(format.get)) + .when(format.isDefined) + .call(_.format(format.get)) .options(options) - .when(schema.isDefined).call(_.schema(schema.get)) - .when(format.isDefined).either(_.load(path)).or(_.table(path)) + .when(schema.isDefined) + .call(_.schema(schema.get)) + .when(format.isDefined) + .either(_.load(path)) + .or(_.table(path)) - def write(df: DataFrame, format: Option[String], path: String, options: Map[String, String], saveMode: SaveMode, filter: Set[String], saveStats: Boolean, diffOptions: DiffOptions): Unit = - df.when(filter.nonEmpty).call(_.where(col(diffOptions.diffColumn).isInCollection(filter))) - .when(saveStats).call(_.groupBy(diffOptions.diffColumn).count.orderBy(diffOptions.diffColumn)) + def write( + df: DataFrame, + format: Option[String], + path: String, + options: Map[String, String], + saveMode: SaveMode, + filter: Set[String], + saveStats: Boolean, + diffOptions: DiffOptions + ): Unit = + df.when(filter.nonEmpty) + .call(_.where(col(diffOptions.diffColumn).isInCollection(filter))) + .when(saveStats) + .call(_.groupBy(diffOptions.diffColumn).count.orderBy(diffOptions.diffColumn)) .write - .when(format.isDefined).call(_.format(format.get)) + .when(format.isDefined) + .call(_.format(format.get)) .options(options) .mode(saveMode) - .when(format.isDefined).either(_.save(path)).or(_.saveAsTable(path)) + .when(format.isDefined) + .either(_.save(path)) + .or(_.saveAsTable(path)) def main(args: Array[String]): Unit = { // parse options val options = parser.parse(args, Options()) match { case Some(options) => options - case None => sys.exit(1) + case None => sys.exit(1) } val unknownFilters = options.filter.filter(filter => 
!options.diffOptions.diffValues.contains(filter)) if (unknownFilters.nonEmpty) { - throw new RuntimeException(s"Filter ${unknownFilters.mkString("'", "', '", "'")} not allowed, " + - s"these are the configured diff values: ${options.diffOptions.diffValues.mkString("'", "', '", "'")}") + throw new RuntimeException( + s"Filter ${unknownFilters.mkString("'", "', '", "'")} not allowed, " + + s"these are the configured diff values: ${options.diffOptions.diffValues.mkString("'", "', '", "'")}" + ) } // create spark session - val spark = SparkSession.builder() + val spark = SparkSession + .builder() .appName(options.appName.get) - .when(options.hive).call(_.enableHiveSupport()) - .when(options.master.isDefined).call(_.master(options.master.get)) + .when(options.hive) + .call(_.enableHiveSupport()) + .when(options.master.isDefined) + .call(_.master(options.master.get)) .getOrCreate() // read and write val left = read(spark, options.leftFormat, options.leftPath.get, options.leftSchema, options.leftOptions) val right = read(spark, options.rightFormat, options.rightPath.get, options.rightSchema, options.rightOptions) val diff = left.diff(right, options.diffOptions, options.ids, options.ignore) - write(diff, options.outputFormat, options.outputPath.get, options.outputOptions, options.saveMode, options.filter, options.statistics, options.diffOptions) + write( + diff, + options.outputFormat, + options.outputPath.get, + options.outputOptions, + options.saveMode, + options.filter, + options.statistics, + options.diffOptions + ) } } diff --git a/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala b/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala index 9fa60f5b..9106de15 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/Diff.scala @@ -25,113 +25,156 @@ import scala.collection.JavaConverters /** * Differ class to diff two Datasets. See Differ.of(…) for details. 
- * @param options options for the diffing process + * @param options + * options for the diffing process */ class Differ(options: DiffOptions) { - private[diff] def checkSchema[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Unit = { - require(left.columns.length == left.columns.toSet.size && - right.columns.length == right.columns.toSet.size, + private[diff] def checkSchema[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Unit = { + require( + left.columns.length == left.columns.toSet.size && + right.columns.length == right.columns.toSet.size, "The datasets have duplicate columns.\n" + s"Left column names: ${left.columns.mkString(", ")}\n" + - s"Right column names: ${right.columns.mkString(", ")}") + s"Right column names: ${right.columns.mkString(", ")}" + ) val leftNonIgnored = left.columns.diffCaseSensitivity(ignoreColumns) val rightNonIgnored = right.columns.diffCaseSensitivity(ignoreColumns) val exceptIgnoredColumnsMsg = if (ignoreColumns.nonEmpty) " except ignored columns" else "" - require(leftNonIgnored.length == rightNonIgnored.length, + require( + leftNonIgnored.length == rightNonIgnored.length, "The number of columns doesn't match.\n" + s"Left column names$exceptIgnoredColumnsMsg (${leftNonIgnored.length}): ${leftNonIgnored.mkString(", ")}\n" + - s"Right column names$exceptIgnoredColumnsMsg (${rightNonIgnored.length}): ${rightNonIgnored.mkString(", ")}") + s"Right column names$exceptIgnoredColumnsMsg (${rightNonIgnored.length}): ${rightNonIgnored.mkString(", ")}" + ) require(leftNonIgnored.length > 0, s"The schema$exceptIgnoredColumnsMsg must not be empty") // column types must match but we ignore the nullability of columns - val leftFields = left.schema.fields.filter(f => !ignoreColumns.containsCaseSensitivity(f.name)).map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType) - val rightFields = right.schema.fields.filter(f => !ignoreColumns.containsCaseSensitivity(f.name)).map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType) + val leftFields = left.schema.fields + .filter(f => !ignoreColumns.containsCaseSensitivity(f.name)) + .map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType) + val rightFields = right.schema.fields + .filter(f => !ignoreColumns.containsCaseSensitivity(f.name)) + .map(f => handleConfiguredCaseSensitivity(f.name) -> f.dataType) val leftExtraSchema = leftFields.diff(rightFields) val rightExtraSchema = rightFields.diff(leftFields) - require(leftExtraSchema.isEmpty && rightExtraSchema.isEmpty, + require( + leftExtraSchema.isEmpty && rightExtraSchema.isEmpty, "The datasets do not have the same schema.\n" + s"Left extra columns: ${leftExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}\n" + - s"Right extra columns: ${rightExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}") + s"Right extra columns: ${rightExtraSchema.map(t => s"${t._1} (${t._2})").mkString(", ")}" + ) val columns = leftNonIgnored val pkColumns = if (idColumns.isEmpty) columns.toList else idColumns val nonPkColumns = columns.diffCaseSensitivity(pkColumns) val missingIdColumns = pkColumns.diffCaseSensitivity(columns) - require(missingIdColumns.isEmpty, - s"Some id columns do not exist: ${missingIdColumns.mkString(", ")} missing among ${columns.mkString(", ")}") + require( + missingIdColumns.isEmpty, + s"Some id columns do not exist: ${missingIdColumns.mkString(", ")} missing among ${columns.mkString(", ")}" + ) val missingIgnoreColumns = 
ignoreColumns.diffCaseSensitivity(left.columns).diffCaseSensitivity(right.columns) - require(missingIgnoreColumns.isEmpty, + require( + missingIgnoreColumns.isEmpty, s"Some ignore columns do not exist: ${missingIgnoreColumns.mkString(", ")} " + - s"missing among ${(leftNonIgnored ++ rightNonIgnored).distinct.sorted.mkString(", ")}") + s"missing among ${(leftNonIgnored ++ rightNonIgnored).distinct.sorted.mkString(", ")}" + ) - require(!pkColumns.containsCaseSensitivity(options.diffColumn), - s"The id columns must not contain the diff column name '${options.diffColumn}': ${pkColumns.mkString(", ")}") - require(options.changeColumn.forall(!pkColumns.containsCaseSensitivity(_)), - s"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString(", ")}") + require( + !pkColumns.containsCaseSensitivity(options.diffColumn), + s"The id columns must not contain the diff column name '${options.diffColumn}': ${pkColumns.mkString(", ")}" + ) + require( + options.changeColumn.forall(!pkColumns.containsCaseSensitivity(_)), + s"The id columns must not contain the change column name '${options.changeColumn.get}': ${pkColumns.mkString(", ")}" + ) val diffValueColumns = getDiffColumns(pkColumns, nonPkColumns, left, right, ignoreColumns).map(_._1).diff(pkColumns) if (Seq(DiffMode.LeftSide, DiffMode.RightSide).contains(options.diffMode)) { - require(!diffValueColumns.containsCaseSensitivity(options.diffColumn), + require( + !diffValueColumns.containsCaseSensitivity(options.diffColumn), s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " + s"non-id columns must not contain the diff column name '${options.diffColumn}': " + - s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}") + s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}" + ) - require(options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)), + require( + options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)), s"The ${if (options.diffMode == DiffMode.LeftSide) "left" else "right"} " + s"non-id columns must not contain the change column name '${options.changeColumn.get}': " + - s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}") + s"${(if (options.diffMode == DiffMode.LeftSide) left else right).columns.diffCaseSensitivity(idColumns).mkString(", ")}" + ) } else { - require(!diffValueColumns.containsCaseSensitivity(options.diffColumn), + require( + !diffValueColumns.containsCaseSensitivity(options.diffColumn), s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " + s"together with these non-id columns " + s"must not produce the diff column name '${options.diffColumn}': " + - s"${nonPkColumns.mkString(", ")}") + s"${nonPkColumns.mkString(", ")}" + ) - require(options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)), + require( + options.changeColumn.forall(!diffValueColumns.containsCaseSensitivity(_)), s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " + s"together with these non-id columns " + s"must not produce the change column name '${options.changeColumn.orNull}': " + - s"${nonPkColumns.mkString(", ")}") + s"${nonPkColumns.mkString(", ")}" + ) - require(diffValueColumns.forall(!pkColumns.containsCaseSensitivity(_)), + require( + 
diffValueColumns.forall(!pkColumns.containsCaseSensitivity(_)), s"The column prefixes '${options.leftColumnPrefix}' and '${options.rightColumnPrefix}', " + s"together with these non-id columns " + s"must not produce any id column name '${pkColumns.mkString("', '")}': " + - s"${nonPkColumns.mkString(", ")}") + s"${nonPkColumns.mkString(", ")}" + ) } } - private def getChangeColumn(existsColumnName: String, valueColumns: Seq[String], left: Dataset[_], right: Dataset[_]): Option[Column] = { + private def getChangeColumn( + existsColumnName: String, + valueColumns: Seq[String], + left: Dataset[_], + right: Dataset[_] + ): Option[Column] = { options.changeColumn .map(changeColumn => - when(left(existsColumnName).isNull || right(existsColumnName).isNull, lit(null)). - otherwise( + when(left(existsColumnName).isNull || right(existsColumnName).isNull, lit(null)) + .otherwise( Some(valueColumns.toSeq) .filter(_.nonEmpty) .map(columns => concat( - columns.map(c => - when(left(backticks(c)) <=> right(backticks(c)), array()).otherwise(array(lit(c))) - ): _* + columns + .map(c => when(left(backticks(c)) <=> right(backticks(c)), array()).otherwise(array(lit(c)))): _* ) - ).getOrElse( - array().cast(ArrayType(StringType, containsNull = false)) - ) - ). - as(changeColumn) + ) + .getOrElse( + array().cast(ArrayType(StringType, containsNull = false)) + ) + ) + .as(changeColumn) ) } - private[diff] def getDiffColumns[T, U](pkColumns: Seq[String], valueColumns: Seq[String], - left: Dataset[T], right: Dataset[U], - ignoreColumns: Seq[String]): Seq[(String, Column)] = { + private[diff] def getDiffColumns[T, U]( + pkColumns: Seq[String], + valueColumns: Seq[String], + left: Dataset[T], + right: Dataset[U], + ignoreColumns: Seq[String] + ): Seq[(String, Column)] = { val idColumns = pkColumns.map(c => c -> coalesce(left(backticks(c)), right(backticks(c))).as(c)) val leftValueColumns = left.columns.filterIsInCaseSensitivity(valueColumns) @@ -145,8 +188,24 @@ class Differ(options: DiffOptions) { val (leftValues, rightValues) = if (options.sparseMode) { ( - leftNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c))) else left(backticks(c))))).toMap, - rightNonPkColumns.map(c => (handleConfiguredCaseSensitivity(c), c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c))) else right(backticks(c))))).toMap + leftNonPkColumns + .map(c => + ( + handleConfiguredCaseSensitivity(c), + c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), left(backticks(c))) + else left(backticks(c))) + ) + ) + .toMap, + rightNonPkColumns + .map(c => + ( + handleConfiguredCaseSensitivity(c), + c -> (if (options.sparseMode) when(not(left(backticks(c)) <=> right(backticks(c))), right(backticks(c))) + else right(backticks(c))) + ) + ) + .toMap ) } else { ( @@ -189,15 +248,22 @@ class Differ(options: DiffOptions) { case DiffMode.LeftSide | DiffMode.RightSide => // in left-side / right-side mode, we do not prefix columns ( - if (options.diffMode == DiffMode.LeftSide) valueColumns.map(alias(None, leftValues)) else valueColumns.map(alias(None, rightValues)) - ) ++ ( - if (options.diffMode == DiffMode.LeftSide) leftIgnoredColumns.map(alias(None, leftValues)) else rightIgnoredColumns.map(alias(None, rightValues)) - ) + if (options.diffMode == DiffMode.LeftSide) valueColumns.map(alias(None, leftValues)) + else valueColumns.map(alias(None, rightValues)) + ) ++ ( + if 
(options.diffMode == DiffMode.LeftSide) leftIgnoredColumns.map(alias(None, leftValues)) + else rightIgnoredColumns.map(alias(None, rightValues)) + ) } idColumns ++ nonIdColumns } - private def doDiff[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String] = Seq.empty): DataFrame = { + private def doDiff[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: Seq[String], + ignoreColumns: Seq[String] = Seq.empty + ): DataFrame = { checkSchema(left, right, idColumns, ignoreColumns) val columns = left.columns.diffCaseSensitivity(ignoreColumns).toList @@ -209,18 +275,21 @@ class Differ(options: DiffOptions) { val existsColumnName = distinctPrefixFor(left.columns) + "exists" val leftWithExists = left.withColumn(existsColumnName, lit(1)) val rightWithExists = right.withColumn(existsColumnName, lit(1)) - val joinCondition = pkColumns.map(c => leftWithExists(backticks(c)) <=> rightWithExists(backticks(c))).reduce(_ && _) - val unChanged = valueVolumnsWithComparator.map { case (c, cmp) => - cmp.equiv(leftWithExists(backticks(c)), rightWithExists(backticks(c))) - }.reduceOption(_ && _) + val joinCondition = + pkColumns.map(c => leftWithExists(backticks(c)) <=> rightWithExists(backticks(c))).reduce(_ && _) + val unChanged = valueVolumnsWithComparator + .map { case (c, cmp) => + cmp.equiv(leftWithExists(backticks(c)), rightWithExists(backticks(c))) + } + .reduceOption(_ && _) val changeCondition = not(unChanged.getOrElse(lit(true))) val diffActionColumn = - when(leftWithExists(existsColumnName).isNull, lit(options.insertDiffValue)). - when(rightWithExists(existsColumnName).isNull, lit(options.deleteDiffValue)). - when(changeCondition, lit(options.changeDiffValue)). - otherwise(lit(options.nochangeDiffValue)). - as(options.diffColumn) + when(leftWithExists(existsColumnName).isNull, lit(options.insertDiffValue)) + .when(rightWithExists(existsColumnName).isNull, lit(options.deleteDiffValue)) + .when(changeCondition, lit(options.changeDiffValue)) + .otherwise(lit(options.nochangeDiffValue)) + .as(options.diffColumn) val diffColumns = getDiffColumns(pkColumns, valueColumns, left, right, ignoreColumns).map(_._2) val changeColumn = getChangeColumn(existsColumnName, valueColumns, leftWithExists, rightWithExists) @@ -228,27 +297,26 @@ class Differ(options: DiffOptions) { .map(Seq(_)) .getOrElse(Seq.empty[Column]) - leftWithExists.join(rightWithExists, joinCondition, "fullouter") + leftWithExists + .join(rightWithExists, joinCondition, "fullouter") .select((diffActionColumn +: changeColumn) ++ diffColumns: _*) } /** - * Returns a new DataFrame that contains the differences between two Datasets of - * the same type `T`. Both Datasets must contain the same set of column names and data types. - * The order of columns in the two Datasets is not relevant as columns are compared based on the - * name, not the the position. + * Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must + * contain the same set of column names and data types. The order of columns in the two Datasets is not relevant as + * columns are compared based on the name, not the the position. * - * Optional `id` columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between two Datasets, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. 
And rows of the left Dataset, - * that do not exist in the right Dataset are marked as `"D"`elete. + * Optional `id` columns are used to uniquely identify rows to compare. If values in any non-id column are differing + * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset, + * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of + * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete. * - * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows - * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all + * changes will exists as respective `"D"`elete and `"I"`nsert. * - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -281,33 +349,31 @@ class Differ(options: DiffOptions) { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are + * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. */ @scala.annotation.varargs def diff[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame = doDiff(left, right, idColumns) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types, - * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as - * columns are compared based on the name, not the the position. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both + * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The + * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the the + * position. * - * Optional id columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between two Datasets, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset, - * that do not exist in the right Dataset are marked as `"D"`elete. + * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing + * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset, + * that do not exist in the left Dataset (w.r.t. 
the values in the id columns) are marked as `"I"`nsert. And rows of + * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete. * - * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows - * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all + * changes will exists as respective `"D"`elete and `"I"`nsert. * * Values in optional ignore columns are not compared but included in the output DataFrame. * - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -340,32 +406,30 @@ class Differ(options: DiffOptions) { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are + * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. */ def diff[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame = doDiff(left, right, idColumns, ignoreColumns) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types, - * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as - * columns are compared based on the name, not the the position. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both + * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The + * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the the + * position. * - * Optional id columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between two Datasets, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset, - * that do not exist in the right Dataset are marked as `"D"`elete. + * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing + * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset, + * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of + * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete. * - * If no id columns are given, all columns are considered id columns. 
Then, no `"C"`hange rows - * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all + * changes will exists as respective `"D"`elete and `"I"`nsert. * * Values in optional ignore columns are not compared but included in the output DataFrame. * - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -398,12 +462,21 @@ class Differ(options: DiffOptions) { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are + * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. */ - def diff[T, U](left: Dataset[T], right: Dataset[U], idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): DataFrame = { - diff(left, right, JavaConverters.iterableAsScalaIterable(idColumns).toSeq, JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq) + def diff[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: java.util.List[String], + ignoreColumns: java.util.List[String] + ): DataFrame = { + diff( + left, + right, + JavaConverters.iterableAsScalaIterable(idColumns).toSeq, + JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq + ) } /** @@ -415,21 +488,22 @@ class Differ(options: DiffOptions) { */ // no @scala.annotation.varargs here as implicit arguments are explicit in Java // this signature is redundant to the other diffAs method in Java - def diffAs[T, U, V](left: Dataset[T], right: Dataset[T], idColumns: String*) - (implicit diffEncoder: Encoder[V]): Dataset[V] = { + def diffAs[T, U, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit + diffEncoder: Encoder[V] + ): Dataset[V] = { diffAs(left, right, diffEncoder, idColumns: _*) } /** - * Returns a new Dataset that contains the differences between two Datasets of - * similar types `T` and `U`. + * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`. * * See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`. * * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. */ - def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]) - (implicit diffEncoder: Encoder[V]): Dataset[V] = { + def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit + diffEncoder: Encoder[V] + ): Dataset[V] = { diffAs(left, right, diffEncoder, idColumns, ignoreColumns) } @@ -441,50 +515,67 @@ class Differ(options: DiffOptions) { * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. 
*/ @scala.annotation.varargs - def diffAs[T, V](left: Dataset[T], right: Dataset[T], - diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = { + def diffAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = { diffAs(left, right, diffEncoder, idColumns, Seq.empty) } /** - * Returns a new Dataset that contains the differences between two Datasets of - * similar types `T` and `U`. + * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`. * * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`. * * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. */ - def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], - diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V] = { - val nonIdColumns = if (idColumns.isEmpty) Seq.empty else left.columns.diffCaseSensitivity(idColumns).diffCaseSensitivity(ignoreColumns).toSeq + def diffAs[T, U, V]( + left: Dataset[T], + right: Dataset[U], + diffEncoder: Encoder[V], + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Dataset[V] = { + val nonIdColumns = + if (idColumns.isEmpty) Seq.empty + else left.columns.diffCaseSensitivity(idColumns).diffCaseSensitivity(ignoreColumns).toSeq val encColumns = diffEncoder.schema.fields.map(_.name) - val diffColumns = Seq(options.diffColumn) ++ getDiffColumns(idColumns, nonIdColumns, left, right, ignoreColumns).map(_._1) + val diffColumns = + Seq(options.diffColumn) ++ getDiffColumns(idColumns, nonIdColumns, left, right, ignoreColumns).map(_._1) val extraColumns = encColumns.diffCaseSensitivity(diffColumns) - require(extraColumns.isEmpty, + require( + extraColumns.isEmpty, s"Diff encoder's columns must be part of the diff result schema, " + - s"these columns are unexpected: ${extraColumns.mkString(", ")}") + s"these columns are unexpected: ${extraColumns.mkString(", ")}" + ) diff(left, right, idColumns, ignoreColumns).as[V](diffEncoder) } /** - * Returns a new Dataset that contains the differences between two Datasets of - * similar types `T` and `U`. + * Returns a new Dataset that contains the differences between two Datasets of similar types `T` and `U`. * * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`. * * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. */ - def diffAs[T, U, V](left: Dataset[T], right: Dataset[U], diffEncoder: Encoder[V], - idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[V] = { - diffAs(left, right, diffEncoder, - JavaConverters.iterableAsScalaIterable(idColumns).toSeq, JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq) + def diffAs[T, U, V]( + left: Dataset[T], + right: Dataset[U], + diffEncoder: Encoder[V], + idColumns: java.util.List[String], + ignoreColumns: java.util.List[String] + ): Dataset[V] = { + diffAs( + left, + right, + diffEncoder, + JavaConverters.iterableAsScalaIterable(idColumns).toSeq, + JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq + ) } /** - * Returns a new Dataset that contains the differences between two Dataset of - * the same type `T` as tuples of type `(String, T, T)`. + * Returns a new Dataset that contains the differences between two Dataset of the same type `T` as tuples of type + * `(String, T, T)`. * * See `diff(Dataset[T], Dataset[T], String*)`. 
*/ @@ -495,27 +586,39 @@ class Differ(options: DiffOptions) { } /** - * Returns a new Dataset that contains the differences between two Dataset of - * similar types `T` and `U` as tuples of type `(String, T, U)`. + * Returns a new Dataset that contains the differences between two Dataset of similar types `T` and `U` as tuples of + * type `(String, T, U)`. * * See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`. */ - def diffWith[T, U](left: Dataset[T], right: Dataset[U], - idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] = { + def diffWith[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Dataset[(String, T, U)] = { val df = diff(left, right, idColumns, ignoreColumns) diffWith(df, idColumns: _*)(left.encoder, right.encoder) } /** - * Returns a new Dataset that contains the differences between two Dataset of - * similar types `T` and `U` as tuples of type `(String, T, U)`. + * Returns a new Dataset that contains the differences between two Dataset of similar types `T` and `U` as tuples of + * type `(String, T, U)`. * * See `diff(Dataset[T], Dataset[U], Seq[String], Seq[String])`. */ - def diffWith[T, U](left: Dataset[T], right: Dataset[U], - idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[(String, T, U)] = { - diffWith(left, right, - JavaConverters.iterableAsScalaIterable(idColumns).toSeq, JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq) + def diffWith[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: java.util.List[String], + ignoreColumns: java.util.List[String] + ): Dataset[(String, T, U)] = { + diffWith( + left, + right, + JavaConverters.iterableAsScalaIterable(idColumns).toSeq, + JavaConverters.iterableAsScalaIterable(ignoreColumns).toSeq + ) } private def columnsOfSide(df: DataFrame, idColumns: Seq[String], sidePrefix: String): Seq[Column] = { @@ -525,7 +628,7 @@ class Differ(options: DiffOptions) { .map(c => if (idColumns.contains(c)) col(c) else col(c).as(c.replace(prefix, ""))) } - private def diffWith[T : Encoder, U : Encoder](diff: DataFrame, idColumns: String*): Dataset[(String, T, U)] = { + private def diffWith[T: Encoder, U: Encoder](diff: DataFrame, idColumns: String*): Dataset[(String, T, U)] = { val leftColumns = columnsOfSide(diff, idColumns, options.leftColumnPrefix) val rightColumns = columnsOfSide(diff, idColumns, options.rightColumnPrefix) @@ -540,7 +643,9 @@ class Differ(options: DiffOptions) { val plan = diff.select(diffColumn, leftStruct, rightStruct).queryExecution.logical val encoder: Encoder[(String, T, U)] = Encoders.tuple( - Encoders.STRING, implicitly[Encoder[T]], implicitly[Encoder[U]] + Encoders.STRING, + implicitly[Encoder[T]], + implicitly[Encoder[U]] ) new Dataset(diff.sparkSession, plan, encoder) @@ -555,22 +660,20 @@ object Diff { val default = new Differ(DiffOptions.default) /** - * Returns a new DataFrame that contains the differences between two Datasets - * of the same type `T`. Both Datasets must contain the same set of column names and data types. - * The order of columns in the two Datasets is not relevant as columns are compared based on the - * name, not the the position. + * Returns a new DataFrame that contains the differences between two Datasets of the same type `T`. Both Datasets must + * contain the same set of column names and data types. The order of columns in the two Datasets is not relevant as + * columns are compared based on the name, not the the position. 
* - * Optional id columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between two Datasets, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset, - * that do not exist in the right Dataset are marked as `"D"`elete. + * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing + * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset, + * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of + * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete. * - * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows - * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all + * changes will exists as respective `"D"`elete and `"I"`nsert. * - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -603,33 +706,31 @@ object Diff { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are + * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. */ @scala.annotation.varargs def of[T](left: Dataset[T], right: Dataset[T], idColumns: String*): DataFrame = default.diff(left, right, idColumns: _*) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types, - * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as - * columns are compared based on the name, not the the position. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both + * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The + * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the the + * position. * - * Optional id columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between two Datasets, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. 
And rows of the left Dataset, - * that do not exist in the right Dataset are marked as `"D"`elete. + * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing + * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset, + * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of + * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete. * - * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows - * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all + * changes will exists as respective `"D"`elete and `"I"`nsert. * * Values in optional ignore columns are not compared but included in the output DataFrame. * - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -662,32 +763,35 @@ object Diff { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are + * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. */ - def of[T, U](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String] = Seq.empty): DataFrame = + def of[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: Seq[String], + ignoreColumns: Seq[String] = Seq.empty + ): DataFrame = default.diff(left, right, idColumns, ignoreColumns) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types, - * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as - * columns are compared based on the name, not the the position. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both + * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The + * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the the + * position. * - * Optional id columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between two Datasets, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the right Dataset, that do not exist in the left Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of the left Dataset, - * that do not exist in the right Dataset are marked as `"D"`elete. + * Optional id columns are used to uniquely identify rows to compare. 
If values in any non-id column are differing + * between two Datasets, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of the right Dataset, + * that do not exist in the left Dataset (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of + * the left Dataset, that do not exist in the right Dataset are marked as `"D"`elete. * - * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows - * will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given, all columns are considered id columns. Then, no `"C"`hange rows will appear, as all + * changes will exists as respective `"D"`elete and `"I"`nsert. * * Values in optional ignore columns are not compared but included in the output DataFrame. * - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -720,16 +824,19 @@ object Diff { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset are + * id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. */ - def of[T, U](left: Dataset[T], right: Dataset[U], idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): DataFrame = + def of[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: java.util.List[String], + ignoreColumns: java.util.List[String] + ): DataFrame = default.diff(left, right, idColumns, ignoreColumns) /** - * Returns a new Dataset that contains the differences between two Datasets of - * the same type `T`. + * Returns a new Dataset that contains the differences between two Datasets of the same type `T`. * * See `of(Dataset[T], Dataset[T], String*)`. * @@ -737,62 +844,72 @@ object Diff { */ // no @scala.annotation.varargs here as implicit arguments are explicit in Java // this signature is redundant to the other ofAs method in Java - def ofAs[T, V](left: Dataset[T], right: Dataset[T], idColumns: String*) - (implicit diffEncoder: Encoder[V]): Dataset[V] = + def ofAs[T, V](left: Dataset[T], right: Dataset[T], idColumns: String*)(implicit + diffEncoder: Encoder[V] + ): Dataset[V] = default.diffAs(left, right, idColumns: _*) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. * * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`. * * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. 
*/ - def ofAs[T, U, V](left: Dataset[T], right: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String] = Seq.empty) - (implicit diffEncoder: Encoder[V]): Dataset[V] = + def ofAs[T, U, V]( + left: Dataset[T], + right: Dataset[U], + idColumns: Seq[String], + ignoreColumns: Seq[String] = Seq.empty + )(implicit diffEncoder: Encoder[V]): Dataset[V] = default.diffAs(left, right, idColumns, ignoreColumns) /** - * Returns a new Dataset that contains the differences between two Datasets of - * the same type `T`. + * Returns a new Dataset that contains the differences between two Datasets of the same type `T`. * * See `of(Dataset[T], Dataset[T], String*)`. * * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. */ @scala.annotation.varargs - def ofAs[T, V](left: Dataset[T], right: Dataset[T], - diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = + def ofAs[T, V](left: Dataset[T], right: Dataset[T], diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = default.diffAs(left, right, diffEncoder, idColumns: _*) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. * * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`. * * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. */ - def ofAs[T, U, V](left: Dataset[T], right: Dataset[U], - diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V] = + def ofAs[T, U, V]( + left: Dataset[T], + right: Dataset[U], + diffEncoder: Encoder[V], + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Dataset[V] = default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. * * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`. * * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. */ - def ofAs[T, U, V](left: Dataset[T], right: Dataset[U], diffEncoder: Encoder[V], - idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[V] = + def ofAs[T, U, V]( + left: Dataset[T], + right: Dataset[U], + diffEncoder: Encoder[V], + idColumns: java.util.List[String], + ignoreColumns: java.util.List[String] + ): Dataset[V] = default.diffAs(left, right, diffEncoder, idColumns, ignoreColumns) /** - * Returns a new Dataset that contains the differences between two Dataset of - * the same type `T` as tuples of type `(String, T, T)`. + * Returns a new Dataset that contains the differences between two Dataset of the same type `T` as tuples of type + * `(String, T, T)`. * * See `of(Dataset[T], Dataset[T], String*)`. */ @@ -801,22 +918,30 @@ object Diff { default.diffWith(left, right, idColumns: _*) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U` as tuples of type `(String, T, U)`. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U` as tuples + * of type `(String, T, U)`. * * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`. 
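 A hedged sketch of the two `ofAs` call styles formatted in this hunk; `Value` and `DiffRow` are made-up types, and the `left_value`/`right_value` field names assume the default column prefixes of the diff output:
 {{{
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import uk.co.gresearch.spark.diff.Diff

 case class Value(id: Int, value: String)
 case class DiffRow(diff: String, id: Int, left_value: String, right_value: String)

 val spark = SparkSession.builder().master("local[*]").getOrCreate()
 import spark.implicits._

 val left  = Seq(Value(1, "one"), Value(2, "two")).toDS()
 val right = Seq(Value(1, "one"), Value(2, "Two")).toDS()

 // Explicit encoder, mirroring the Java-friendly signature.
 val explicit: Dataset[DiffRow] = Diff.ofAs(left, right, Encoders.product[DiffRow], "id")

 // Implicit encoder, resolved from spark.implicits._ via the expected result type.
 val inferred: Dataset[DiffRow] = Diff.ofAs(left, right, "id")
 }}}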
*/ - def ofWith[T, U](left: Dataset[T], right: Dataset[U], - idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] = + def ofWith[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Dataset[(String, T, U)] = default.diffWith(left, right, idColumns, ignoreColumns) /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U` as tuples of type `(String, T, U)`. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U` as tuples + * of type `(String, T, U)`. * * See `of(Dataset[T], Dataset[U], Seq[String], Seq[String])`. */ - def ofWith[T, U](left: Dataset[T], right: Dataset[U], - idColumns: java.util.List[String], ignoreColumns: java.util.List[String]): Dataset[(String, T, U)] = + def ofWith[T, U]( + left: Dataset[T], + right: Dataset[U], + idColumns: java.util.List[String], + ignoreColumns: java.util.List[String] + ): Dataset[(String, T, U)] = default.diffWith(left, right, idColumns, ignoreColumns) } diff --git a/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala b/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala index 1c9e2f1d..170d2240 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/DiffComparators.scala @@ -23,6 +23,7 @@ import uk.co.gresearch.spark.diff.comparator._ import java.time.Duration object DiffComparators { + /** * The default comparator used in [[DiffOptions.default.defaultComparator]]. */ @@ -34,17 +35,17 @@ object DiffComparators { def nullSafeEqual(): DiffComparator = NullSafeEqualDiffComparator /** - * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. - * The implicit [[Encoder]] of type [[T]] determines the input data type of the comparator. - * Only columns of that type can be compared. + * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. The implicit [[Encoder]] of + * type [[T]] determines the input data type of the comparator. Only columns of that type can be compared. */ - def equiv[T : Encoder](equiv: math.Equiv[T]): EquivDiffComparator[T] = EquivDiffComparator(equiv) + def equiv[T: Encoder](equiv: math.Equiv[T]): EquivDiffComparator[T] = EquivDiffComparator(equiv) /** - * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. - * Only columns of the given data type `inputType` can be compared. + * Return a comparator that uses the given [[math.Equiv]] to compare values of type [[T]]. Only columns of the given + * data type `inputType` can be compared. */ - def equiv[T](equiv: math.Equiv[T], inputType: DataType): EquivDiffComparator[T] = EquivDiffComparator(equiv, inputType) + def equiv[T](equiv: math.Equiv[T], inputType: DataType): EquivDiffComparator[T] = + EquivDiffComparator(equiv, inputType) /** * Return a comparator that uses the given [[math.Equiv]] to compare values of any type. @@ -52,17 +53,15 @@ object DiffComparators { def equiv(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivDiffComparator(equiv) /** - * This comparator considers values equal when they are less than `epsilon` apart. - * It can be configured to use `epsilon` as an absolute (`.asAbsolute()`) threshold, - * or as relative (`.asRelative()`) to the larger value. 
Further, the threshold itself can be - * considered equal (`.asInclusive()`) or not equal (`.asExclusive()`): + * This comparator considers values equal when they are less than `epsilon` apart. It can be configured to use + * `epsilon` as an absolute (`.asAbsolute()`) threshold, or as relative (`.asRelative()`) to the larger value. + * Further, the threshold itself can be considered equal (`.asInclusive()`) or not equal (`.asExclusive()`): * - *
-   *   - `DiffComparator.epsilon(epsilon).asAbsolute().asInclusive()`: `abs(left - right) ≤ epsilon`
-   *   - `DiffComparator.epsilon(epsilon).asAbsolute().asExclusive()`: `abs(left - right) < epsilon`
-   *   - `DiffComparator.epsilon(epsilon).asRelative().asInclusive()`: `abs(left - right) ≤ epsilon * max(abs(left), abs(right))`
-   *   - `DiffComparator.epsilon(epsilon).asRelative().asExclusive()`: `abs(left - right) < epsilon * max(abs(left), abs(right))`
+   *   - `DiffComparator.epsilon(epsilon).asAbsolute().asInclusive()`: `abs(left - right) ≤ epsilon`
+   *   - `DiffComparator.epsilon(epsilon).asAbsolute().asExclusive()`: `abs(left - right) < epsilon`
+   *   - `DiffComparator.epsilon(epsilon).asRelative().asInclusive()`: `abs(left - right) ≤ epsilon * max(abs(left),
+   *     abs(right))`
+   *   - `DiffComparator.epsilon(epsilon).asRelative().asExclusive()`: `abs(left - right) < epsilon *
+   *     max(abs(left), abs(right))`
* * Requires compared column types to implement `-`, `*`, `<`, `==`, and `abs`. */ @@ -71,8 +70,9 @@ object DiffComparators { /** * A comparator for string values. * - * With `whitespaceAgnostic` set `true`, differences in white spaces are ignored. This ignores leading and trailing whitespaces as well. - * With `whitespaceAgnostic` set `false`, this is equal to the default string comparison (see [[default()]]). + * With `whitespaceAgnostic` set `true`, differences in white spaces are ignored. This ignores leading and trailing + * whitespaces as well. With `whitespaceAgnostic` set `false`, this is equal to the default string comparison (see + * [[default()]]). */ def string(whitespaceAgnostic: Boolean = true): StringDiffComparator = if (whitespaceAgnostic) { @@ -85,11 +85,9 @@ object DiffComparators { * This comparator considers two `DateType` or `TimestampType` values equal when they are at most `duration` apart. * Duration is an instance of `java.time.Duration`. * - * The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal (`.asExclusive()`): - *
-   *   - `DiffComparator.duration(duration).asInclusive()`: `left - right ≤ duration`
-   *   - `DiffComparator.duration(duration).asExclusive()`: `left - right < duration`
+   * The comparator can be configured to consider `duration` as equal (`.asInclusive()`) or not equal
+   * (`.asExclusive()`):
+   *   - `DiffComparator.duration(duration).asInclusive()`: `left - right ≤ duration`
+   *   - `DiffComparator.duration(duration).asExclusive()`: `left - right < duration`
    • */ def duration(duration: Duration): DurationDiffComparator = DurationDiffComparator(duration) @@ -103,7 +101,9 @@ object DiffComparators { /** * This comparator compares two `Map[K,V]` values. They are equal when they match in all their keys and values. * - * @param keyOrderSensitive comparator compares key order if true + * @param keyOrderSensitive + * comparator compares key order if true */ - def map[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): DiffComparator = MapDiffComparator[K, V](keyOrderSensitive) + def map[K: Encoder, V: Encoder](keyOrderSensitive: Boolean): DiffComparator = + MapDiffComparator[K, V](keyOrderSensitive) } diff --git a/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala b/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala index 7f434038..0bd300f7 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/DiffOptions.scala @@ -20,7 +20,12 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.types.{DataType, StructField} import uk.co.gresearch.spark.diff import uk.co.gresearch.spark.diff.DiffMode.{Default, DiffMode} -import uk.co.gresearch.spark.diff.comparator.{DefaultDiffComparator, DiffComparator, EquivDiffComparator, TypedDiffComparator} +import uk.co.gresearch.spark.diff.comparator.{ + DefaultDiffComparator, + DiffComparator, + EquivDiffComparator, + TypedDiffComparator +} import scala.annotation.varargs import scala.collection.Map @@ -34,23 +39,20 @@ object DiffMode extends Enumeration { /** * The diff mode determines the output columns of the diffing transformation. * - * - ColumnByColumn: The diff contains value columns from the left and right dataset, - * arranged column by column: - * diff,( changes,) id-1, id-2, …, left-value-1, right-value-1, left-value-2, right-value-2, … + * - ColumnByColumn: The diff contains value columns from the left and right dataset, arranged column by column: + * diff,( changes,) id-1, id-2, …, left-value-1, right-value-1, left-value-2, right-value-2, … * - * - SideBySide: The diff contains value columns from the left and right dataset, - * arranged side by side: - * diff,( changes,) id-1, id-2, …, left-value-1, left-value-2, …, right-value-1, right-value-2, … - * - LeftSide / RightSide: The diff contains value columns from the left / right dataset only. + * - SideBySide: The diff contains value columns from the left and right dataset, arranged side by side: diff,( + * changes,) id-1, id-2, …, left-value-1, left-value-2, …, right-value-1, right-value-2, … + * - LeftSide / RightSide: The diff contains value columns from the left / right dataset only. */ val ColumnByColumn, SideBySide, LeftSide, RightSide = Value /** - * The diff mode determines the output columns of the diffing transformation. - * The default diff mode is ColumnByColumn. + * The diff mode determines the output columns of the diffing transformation. The default diff mode is ColumnByColumn. * - * Default is not a enum value here (hence the def) so that we do not have to include it in every - * match clause. We will see the respective enum value that Default points to instead. + * Default is not a enum value here (hence the def) so that we do not have to include it in every match clause. We + * will see the respective enum value that Default points to instead. */ def Default: diff.DiffMode.Value = ColumnByColumn @@ -72,45 +74,62 @@ object DiffMode extends Enumeration { /** * Configuration class for diffing Datasets. 
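 To make the comparator factories in this file concrete, a small illustrative sketch; the thresholds and the custom equivalence are arbitrary, and `epsilon(...)` taking a `Double` plus the `asAbsolute`/`asRelative`/`asInclusive`/`asExclusive` modifiers are assumed from the Scaladoc above:
 {{{
 import java.time.Duration
 import org.apache.spark.sql.types.DoubleType
 import uk.co.gresearch.spark.diff.DiffComparators

 // Relative tolerance of 0.1 %, the threshold itself counts as equal.
 val relative = DiffComparators.epsilon(0.001).asRelative().asInclusive()

 // Absolute tolerance, the threshold itself counts as unequal: abs(left - right) < 0.5
 val absolute = DiffComparators.epsilon(0.5).asAbsolute().asExclusive()

 // Timestamps at most one minute apart are considered equal.
 val withinMinute = DiffComparators.duration(Duration.ofMinutes(1)).asInclusive()

 // Whitespace-agnostic string comparison (leading, trailing and repeated whitespace ignored).
 val looseStrings = DiffComparators.string(whitespaceAgnostic = true)

 // Custom equivalence for DoubleType columns: compare absolute values only.
 val ignoreSign = DiffComparators.equiv(
   math.Equiv.fromFunction[Double]((x, y) => math.abs(x) == math.abs(y)),
   DoubleType
 )
 }}}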
* - * @param diffColumn name of the diff column - * @param leftColumnPrefix prefix of columns from the left Dataset - * @param rightColumnPrefix prefix of columns from the right Dataset - * @param insertDiffValue value in diff column for inserted rows - * @param changeDiffValue value in diff column for changed rows - * @param deleteDiffValue value in diff column for deleted rows - * @param nochangeDiffValue value in diff column for un-changed rows - * @param changeColumn name of change column - * @param diffMode diff output format - * @param sparseMode un-changed values are null on both sides - * @param defaultComparator default custom comparator - * @param dataTypeComparators custom comparator for some data type - * @param columnNameComparators custom comparator for some column name + * @param diffColumn + * name of the diff column + * @param leftColumnPrefix + * prefix of columns from the left Dataset + * @param rightColumnPrefix + * prefix of columns from the right Dataset + * @param insertDiffValue + * value in diff column for inserted rows + * @param changeDiffValue + * value in diff column for changed rows + * @param deleteDiffValue + * value in diff column for deleted rows + * @param nochangeDiffValue + * value in diff column for un-changed rows + * @param changeColumn + * name of change column + * @param diffMode + * diff output format + * @param sparseMode + * un-changed values are null on both sides + * @param defaultComparator + * default custom comparator + * @param dataTypeComparators + * custom comparator for some data type + * @param columnNameComparators + * custom comparator for some column name */ -case class DiffOptions(diffColumn: String, - leftColumnPrefix: String, - rightColumnPrefix: String, - insertDiffValue: String, - changeDiffValue: String, - deleteDiffValue: String, - nochangeDiffValue: String, - changeColumn: Option[String] = None, - diffMode: DiffMode = Default, - sparseMode: Boolean = false, - defaultComparator: DiffComparator = DefaultDiffComparator, - dataTypeComparators: Map[DataType, DiffComparator] = Map.empty, - columnNameComparators: Map[String, DiffComparator] = Map.empty) { +case class DiffOptions( + diffColumn: String, + leftColumnPrefix: String, + rightColumnPrefix: String, + insertDiffValue: String, + changeDiffValue: String, + deleteDiffValue: String, + nochangeDiffValue: String, + changeColumn: Option[String] = None, + diffMode: DiffMode = Default, + sparseMode: Boolean = false, + defaultComparator: DiffComparator = DefaultDiffComparator, + dataTypeComparators: Map[DataType, DiffComparator] = Map.empty, + columnNameComparators: Map[String, DiffComparator] = Map.empty +) { // Constructor for Java to construct default options def this() = this("diff", "left", "right", "I", "C", "D", "N") - def this(diffColumn: String, - leftColumnPrefix: String, - rightColumnPrefix: String, - insertDiffValue: String, - changeDiffValue: String, - deleteDiffValue: String, - nochangeDiffValue: String, - changeColumn: Option[String], - diffMode: DiffMode, - sparseMode: Boolean) = { + def this( + diffColumn: String, + leftColumnPrefix: String, + rightColumnPrefix: String, + insertDiffValue: String, + changeDiffValue: String, + deleteDiffValue: String, + nochangeDiffValue: String, + changeColumn: Option[String], + diffMode: DiffMode, + sparseMode: Boolean + ) = { this( diffColumn, leftColumnPrefix, @@ -124,188 +143,215 @@ case class DiffOptions(diffColumn: String, sparseMode, DefaultDiffComparator, Map.empty, - Map.empty) + Map.empty + ) } 
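 For readers skimming the reflowed parameter list, a hedged sketch of how these options are typically composed; the column name "score" and the chosen values are illustrative, and the epsilon comparator is the assumed factory from the earlier sketch:
 {{{
 import uk.co.gresearch.spark.diff.{DiffComparators, DiffMode, DiffOptions, Differ}

 // Every wither returns a new immutable DiffOptions instance, so calls can be chained.
 val options: DiffOptions = DiffOptions.default
   .withDiffColumn("d")
   .withChangeColumn("changes")
   .withDiffMode(DiffMode.SideBySide)
   .withSparseMode(true)
   .withComparator(DiffComparators.epsilon(1e-6).asAbsolute(), "score")

 // The Differ then applies these options to every diff it computes.
 val differ = new Differ(options)
 }}}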
require(leftColumnPrefix.nonEmpty, "Left column prefix must not be empty") require(rightColumnPrefix.nonEmpty, "Right column prefix must not be empty") - require(handleConfiguredCaseSensitivity(leftColumnPrefix) != handleConfiguredCaseSensitivity(rightColumnPrefix), - s"Left and right column prefix must be distinct: $leftColumnPrefix") + require( + handleConfiguredCaseSensitivity(leftColumnPrefix) != handleConfiguredCaseSensitivity(rightColumnPrefix), + s"Left and right column prefix must be distinct: $leftColumnPrefix" + ) val diffValues = Seq(insertDiffValue, changeDiffValue, deleteDiffValue, nochangeDiffValue) - require(diffValues.distinct.length == diffValues.length, - s"Diff values must be distinct: $diffValues") + require(diffValues.distinct.length == diffValues.length, s"Diff values must be distinct: $diffValues") - require(!changeColumn.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(diffColumn)), - s"Change column name must be different to diff column: $diffColumn") + require( + !changeColumn.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(diffColumn)), + s"Change column name must be different to diff column: $diffColumn" + ) /** - * Fluent method to change the diff column name. - * Returns a new immutable DiffOptions instance with the new diff column name. - * @param diffColumn new diff column name - * @return new immutable DiffOptions instance + * Fluent method to change the diff column name. Returns a new immutable DiffOptions instance with the new diff column + * name. + * @param diffColumn + * new diff column name + * @return + * new immutable DiffOptions instance */ def withDiffColumn(diffColumn: String): DiffOptions = { this.copy(diffColumn = diffColumn) } /** - * Fluent method to change the prefix of columns from the left Dataset. - * Returns a new immutable DiffOptions instance with the new column prefix. - * @param leftColumnPrefix new column prefix - * @return new immutable DiffOptions instance + * Fluent method to change the prefix of columns from the left Dataset. Returns a new immutable DiffOptions instance + * with the new column prefix. + * @param leftColumnPrefix + * new column prefix + * @return + * new immutable DiffOptions instance */ def withLeftColumnPrefix(leftColumnPrefix: String): DiffOptions = { this.copy(leftColumnPrefix = leftColumnPrefix) } /** - * Fluent method to change the prefix of columns from the right Dataset. - * Returns a new immutable DiffOptions instance with the new column prefix. - * @param rightColumnPrefix new column prefix - * @return new immutable DiffOptions instance + * Fluent method to change the prefix of columns from the right Dataset. Returns a new immutable DiffOptions instance + * with the new column prefix. + * @param rightColumnPrefix + * new column prefix + * @return + * new immutable DiffOptions instance */ def withRightColumnPrefix(rightColumnPrefix: String): DiffOptions = { this.copy(rightColumnPrefix = rightColumnPrefix) } /** - * Fluent method to change the value of inserted rows in the diff column. - * Returns a new immutable DiffOptions instance with the new diff value. - * @param insertDiffValue new diff value - * @return new immutable DiffOptions instance + * Fluent method to change the value of inserted rows in the diff column. Returns a new immutable DiffOptions instance + * with the new diff value. 
+ * @param insertDiffValue + * new diff value + * @return + * new immutable DiffOptions instance */ def withInsertDiffValue(insertDiffValue: String): DiffOptions = { this.copy(insertDiffValue = insertDiffValue) } /** - * Fluent method to change the value of changed rows in the diff column. - * Returns a new immutable DiffOptions instance with the new diff value. - * @param changeDiffValue new diff value - * @return new immutable DiffOptions instance + * Fluent method to change the value of changed rows in the diff column. Returns a new immutable DiffOptions instance + * with the new diff value. + * @param changeDiffValue + * new diff value + * @return + * new immutable DiffOptions instance */ def withChangeDiffValue(changeDiffValue: String): DiffOptions = { this.copy(changeDiffValue = changeDiffValue) } /** - * Fluent method to change the value of deleted rows in the diff column. - * Returns a new immutable DiffOptions instance with the new diff value. - * @param deleteDiffValue new diff value - * @return new immutable DiffOptions instance + * Fluent method to change the value of deleted rows in the diff column. Returns a new immutable DiffOptions instance + * with the new diff value. + * @param deleteDiffValue + * new diff value + * @return + * new immutable DiffOptions instance */ def withDeleteDiffValue(deleteDiffValue: String): DiffOptions = { this.copy(deleteDiffValue = deleteDiffValue) } /** - * Fluent method to change the value of un-changed rows in the diff column. - * Returns a new immutable DiffOptions instance with the new diff value. - * @param nochangeDiffValue new diff value - * @return new immutable DiffOptions instance + * Fluent method to change the value of un-changed rows in the diff column. Returns a new immutable DiffOptions + * instance with the new diff value. + * @param nochangeDiffValue + * new diff value + * @return + * new immutable DiffOptions instance */ def withNochangeDiffValue(nochangeDiffValue: String): DiffOptions = { this.copy(nochangeDiffValue = nochangeDiffValue) } /** - * Fluent method to change the change column name. - * Returns a new immutable DiffOptions instance with the new change column name. - * @param changeColumn new change column name - * @return new immutable DiffOptions instance + * Fluent method to change the change column name. Returns a new immutable DiffOptions instance with the new change + * column name. + * @param changeColumn + * new change column name + * @return + * new immutable DiffOptions instance */ def withChangeColumn(changeColumn: String): DiffOptions = { this.copy(changeColumn = Some(changeColumn)) } /** - * Fluent method to remove change column. - * Returns a new immutable DiffOptions instance without a change column. - * @return new immutable DiffOptions instance + * Fluent method to remove change column. Returns a new immutable DiffOptions instance without a change column. + * @return + * new immutable DiffOptions instance */ def withoutChangeColumn(): DiffOptions = { this.copy(changeColumn = None) } /** - * Fluent method to change the diff mode. - * Returns a new immutable DiffOptions instance with the new diff mode. - * @return new immutable DiffOptions instance + * Fluent method to change the diff mode. Returns a new immutable DiffOptions instance with the new diff mode. + * @return + * new immutable DiffOptions instance */ def withDiffMode(diffMode: DiffMode): DiffOptions = { this.copy(diffMode = diffMode) } /** - * Fluent method to change the sparse mode. 
- * Returns a new immutable DiffOptions instance with the new sparse mode. - * @return new immutable DiffOptions instance + * Fluent method to change the sparse mode. Returns a new immutable DiffOptions instance with the new sparse mode. + * @return + * new immutable DiffOptions instance */ def withSparseMode(sparseMode: Boolean): DiffOptions = { this.copy(sparseMode = sparseMode) } /** - * Fluent method to add a default comparator. - * Returns a new immutable DiffOptions instance with the new default comparator. - * @return new immutable DiffOptions instance + * Fluent method to add a default comparator. Returns a new immutable DiffOptions instance with the new default + * comparator. + * @return + * new immutable DiffOptions instance */ def withDefaultComparator(diffComparator: DiffComparator): DiffOptions = { this.copy(defaultComparator = diffComparator) } /** - * Fluent method to add a typed equivalent operator as a default comparator. - * The encoder defines the input type of the comparator. - * Returns a new immutable DiffOptions instance with the new default comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators - * can only be implemented via the `DiffComparator` interface. - * @return new immutable DiffOptions instance + * Fluent method to add a typed equivalent operator as a default comparator. The encoder defines the input type of the + * comparator. Returns a new immutable DiffOptions instance with the new default comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. + * @return + * new immutable DiffOptions instance */ - def withDefaultComparator[T : Encoder](equiv: math.Equiv[T]): DiffOptions = { + def withDefaultComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions = { withDefaultComparator(EquivDiffComparator(equiv)) } /** - * Fluent method to add a typed equivalent operator as a default comparator. - * Returns a new immutable DiffOptions instance with the new default comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators - * can only be implemented via the `DiffComparator` interface. - * @return new immutable DiffOptions instance + * Fluent method to add a typed equivalent operator as a default comparator. Returns a new immutable DiffOptions + * instance with the new default comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. + * @return + * new immutable DiffOptions instance */ def withDefaultComparator[T](equiv: math.Equiv[T], inputDataType: DataType): DiffOptions = { withDefaultComparator(EquivDiffComparator(equiv, inputDataType)) } /** - * Fluent method to add an equivalent operator as a default comparator. - * Returns a new immutable DiffOptions instance with the new default comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators - * can only be implemented via the `DiffComparator` interface. - * @return new immutable DiffOptions instance + * Fluent method to add an equivalent operator as a default comparator. Returns a new immutable DiffOptions instance + * with the new default comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. 
+ * @return + * new immutable DiffOptions instance */ def withDefaultComparator(equiv: math.Equiv[Any]): DiffOptions = { withDefaultComparator(EquivDiffComparator(equiv)) } /** - * Fluent method to add a comparator for its input data type. - * Returns a new immutable DiffOptions instance with the new comparator. - * @return new immutable DiffOptions instance + * Fluent method to add a comparator for its input data type. Returns a new immutable DiffOptions instance with the + * new comparator. + * @return + * new immutable DiffOptions instance */ def withComparator(diffComparator: TypedDiffComparator): DiffOptions = { if (dataTypeComparators.contains(diffComparator.inputType)) { - throw new IllegalArgumentException( - s"A comparator for data type ${diffComparator.inputType} exists already.") + throw new IllegalArgumentException(s"A comparator for data type ${diffComparator.inputType} exists already.") } this.copy(dataTypeComparators = dataTypeComparators ++ Map(diffComparator.inputType -> diffComparator)) } /** - * Fluent method to add a comparator for one or more data types. - * Returns a new immutable DiffOptions instance with the new comparator. - * @return new immutable DiffOptions instance + * Fluent method to add a comparator for one or more data types. Returns a new immutable DiffOptions instance with the + * new comparator. + * @return + * new immutable DiffOptions instance */ @varargs def withComparator(diffComparator: DiffComparator, dataType: DataType, dataTypes: DataType*): DiffOptions = { @@ -313,8 +359,10 @@ case class DiffOptions(diffColumn: String, diffComparator match { case typed: TypedDiffComparator if allDataTypes.exists(_ != typed.inputType) => - throw new IllegalArgumentException(s"Comparator with input type ${typed.inputType.simpleString} " + - s"cannot be used for data type ${allDataTypes.filter(_ != typed.inputType).map(_.simpleString).sorted.mkString(", ")}") + throw new IllegalArgumentException( + s"Comparator with input type ${typed.inputType.simpleString} " + + s"cannot be used for data type ${allDataTypes.filter(_ != typed.inputType).map(_.simpleString).sorted.mkString(", ")}" + ) case _ => } @@ -322,15 +370,17 @@ case class DiffOptions(diffColumn: String, if (existingDataTypes.nonEmpty) { throw new IllegalArgumentException( s"A comparator for data type${if (existingDataTypes.length > 1) "s" else ""} " + - s"${existingDataTypes.map(_.simpleString).sorted.mkString(", ")} exists already.") + s"${existingDataTypes.map(_.simpleString).sorted.mkString(", ")} exists already." + ) } this.copy(dataTypeComparators = dataTypeComparators ++ allDataTypes.map(dt => dt -> diffComparator)) } /** - * Fluent method to add a comparator for one or more column names. - * Returns a new immutable DiffOptions instance with the new comparator. - * @return new immutable DiffOptions instance + * Fluent method to add a comparator for one or more column names. Returns a new immutable DiffOptions instance with + * the new comparator. + * @return + * new immutable DiffOptions instance */ @varargs def withComparator(diffComparator: DiffComparator, columnName: String, columnNames: String*): DiffOptions = { @@ -339,79 +389,96 @@ case class DiffOptions(diffColumn: String, if (existingColumnNames.nonEmpty) { throw new IllegalArgumentException( s"A comparator for column name${if (existingColumnNames.length > 1) "s" else ""} " + - s"${existingColumnNames.sorted.mkString(", ")} exists already.") + s"${existingColumnNames.sorted.mkString(", ")} exists already." 
+ ) } this.copy(columnNameComparators = columnNameComparators ++ allColumnNames.map(name => name -> diffComparator)) } /** - * Fluent method to add a typed equivalent operator as a comparator for its input data type. - * The encoder defines the input type of the comparator. - * Returns a new immutable DiffOptions instance with the new comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators can only - * be implemented via the `DiffComparator` interface. - * @return new immutable DiffOptions instance + * Fluent method to add a typed equivalent operator as a comparator for its input data type. The encoder defines the + * input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. + * @return + * new immutable DiffOptions instance */ - def withComparator[T : Encoder](equiv: math.Equiv[T]): DiffOptions = + def withComparator[T: Encoder](equiv: math.Equiv[T]): DiffOptions = withComparator(EquivDiffComparator(equiv)) /** - * Fluent method to add a typed equivalent operator as a comparator for one or more column names. - * The encoder defines the input type of the comparator. - * Returns a new immutable DiffOptions instance with the new comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators can only - * be implemented via the `DiffComparator` interface. - * @return new immutable DiffOptions instance + * Fluent method to add a typed equivalent operator as a comparator for one or more column names. The encoder defines + * the input type of the comparator. Returns a new immutable DiffOptions instance with the new comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. + * @return + * new immutable DiffOptions instance */ - def withComparator[T : Encoder](equiv: math.Equiv[T], columnName: String, columnNames: String*): DiffOptions = + def withComparator[T: Encoder](equiv: math.Equiv[T], columnName: String, columnNames: String*): DiffOptions = withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*) /** - * Fluent method to add an equivalent operator as a comparator for one or more column names. - * Returns a new immutable DiffOptions instance with the new comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators - * can only be implemented via the `DiffComparator` interface. - * @note Java-specific method - * @return new immutable DiffOptions instance + * Fluent method to add an equivalent operator as a comparator for one or more column names. Returns a new immutable + * DiffOptions instance with the new comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. 
+ * @note + * Java-specific method + * @return + * new immutable DiffOptions instance */ @varargs - def withComparator[T](equiv: math.Equiv[T], encoder: Encoder[T], columnName: String, columnNames: String*): DiffOptions = + def withComparator[T]( + equiv: math.Equiv[T], + encoder: Encoder[T], + columnName: String, + columnNames: String* + ): DiffOptions = withComparator(EquivDiffComparator(equiv)(encoder), columnName, columnNames: _*) /** - * Fluent method to add an equivalent operator as a comparator for one or more data types. - * Returns a new immutable DiffOptions instance with the new comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators - * can only be implemented via the `DiffComparator` interface. - * @return new immutable DiffOptions instance + * Fluent method to add an equivalent operator as a comparator for one or more data types. Returns a new immutable + * DiffOptions instance with the new comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. + * @return + * new immutable DiffOptions instance */ // There is probably no use case of calling this with multiple datatype while T not being Any // But this is the only way to define withComparator[T](equiv: math.Equiv[T], dataType: DataType) // without being ambiguous with withComparator(equiv: math.Equiv[Any], dataType: DataType, dataTypes: DataType*) @varargs def withComparator[T](equiv: math.Equiv[T], dataType: DataType, dataTypes: DataType*): DiffOptions = - (dataType +: dataTypes).foldLeft(this)( (options, dataType) => + (dataType +: dataTypes).foldLeft(this)((options, dataType) => options.withComparator(EquivDiffComparator(equiv, dataType)) ) /** - * Fluent method to add an equivalent operator as a comparator for one or more column names. - * Returns a new immutable DiffOptions instance with the new comparator. - * @note The `math.Equiv` will not be given any null values. Null-aware comparators - * can only be implemented via the `DiffComparator` interface. - * @return new immutable DiffOptions instance + * Fluent method to add an equivalent operator as a comparator for one or more column names. Returns a new immutable + * DiffOptions instance with the new comparator. + * @note + * The `math.Equiv` will not be given any null values. Null-aware comparators can only be implemented via the + * `DiffComparator` interface. + * @return + * new immutable DiffOptions instance */ @varargs def withComparator(equiv: math.Equiv[Any], columnName: String, columnNames: String*): DiffOptions = withComparator(EquivDiffComparator(equiv), columnName, columnNames: _*) private[diff] def comparatorFor(column: StructField): DiffComparator = - columnNameComparators.get(column.name) + columnNameComparators + .get(column.name) .orElse(dataTypeComparators.get(column.dataType)) .getOrElse(defaultComparator) } object DiffOptions { + /** * Default diffing options. 
*/ diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala index fed6b2e3..4ebfc79e 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/DurationDiffComparator.scala @@ -25,25 +25,30 @@ import uk.co.gresearch.spark.diff.comparator.DurationDiffComparator.isNotSupport import java.time.Duration /** - * Compares two timestamps and considers them equal when they are less than - * (or equal to when inclusive = true) a given duration apart. + * Compares two timestamps and considers them equal when they are less than (or equal to when inclusive = true) a given + * duration apart. * - * @param duration equality threshold - * @param inclusive duration is considered equal when true + * @param duration + * equality threshold + * @param inclusive + * duration is considered equal when true */ case class DurationDiffComparator(duration: Duration, inclusive: Boolean = true) extends DiffComparator { if (isNotSupportedBySpark) { - throw new UnsupportedOperationException(s"java.time.Duration is not supported by Spark ${spark.SparkCompatVersionString}") + throw new UnsupportedOperationException( + s"java.time.Duration is not supported by Spark ${spark.SparkCompatVersionString}" + ) } override def equiv(left: Column, right: Column): Column = { - val inDuration = if (inclusive) - (diff: Column) => diff <= duration - else - (diff: Column) => diff < duration + val inDuration = + if (inclusive) + (diff: Column) => diff <= duration + else + (diff: Column) => diff < duration left.isNull && right.isNull || - left.isNotNull && right.isNotNull && inDuration(abs(left - right)) + left.isNotNull && right.isNotNull && inDuration(abs(left - right)) } def asInclusive(): DurationDiffComparator = if (inclusive) this else copy(inclusive = true) @@ -52,5 +57,5 @@ case class DurationDiffComparator(duration: Duration, inclusive: Boolean = true) object DurationDiffComparator extends SparkVersion { val isSupportedBySpark: Boolean = SparkMajorVersion == 3 && SparkMinorVersion >= 3 || SparkMajorVersion > 3 - val isNotSupportedBySpark: Boolean = ! 
isSupportedBySpark + val isNotSupportedBySpark: Boolean = !isSupportedBySpark } diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala index a9d7658a..b95196e1 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EpsilonDiffComparator.scala @@ -20,17 +20,19 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.functions.{abs, greatest} case class EpsilonDiffComparator(epsilon: Double, relative: Boolean = true, inclusive: Boolean = true) - extends DiffComparator { + extends DiffComparator { override def equiv(left: Column, right: Column): Column = { - val threshold = if (relative) - greatest(abs(left), abs(right)) * epsilon - else - epsilon + val threshold = + if (relative) + greatest(abs(left), abs(right)) * epsilon + else + epsilon - val inEpsilon = if (inclusive) - (diff: Column) => diff <= threshold - else - (diff: Column) => diff < threshold + val inEpsilon = + if (inclusive) + (diff: Column) => diff <= threshold + else + (diff: Column) => diff < threshold left.isNull && right.isNull || left.isNotNull && right.isNotNull && inEpsilon(abs(left - right)) } diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala index 4d26d393..6a9dc99b 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/EquivDiffComparator.scala @@ -38,21 +38,23 @@ private trait ExpressionEquivDiffComparator[T] extends EquivDiffComparator[T] { trait TypedEquivDiffComparator[T] extends EquivDiffComparator[T] with TypedDiffComparator private[comparator] trait TypedEquivDiffComparatorWithInput[T] - extends ExpressionEquivDiffComparator[T] with TypedEquivDiffComparator[T] { + extends ExpressionEquivDiffComparator[T] + with TypedEquivDiffComparator[T] { def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType) } private[comparator] case class InputTypedEquivDiffComparator[T](equiv: math.Equiv[T], inputType: DataType) - extends TypedEquivDiffComparatorWithInput[T] - + extends TypedEquivDiffComparatorWithInput[T] object EquivDiffComparator { - def apply[T : Encoder](equiv: math.Equiv[T]): TypedEquivDiffComparator[T] = EncoderEquivDiffComparator(equiv) - def apply[T](equiv: math.Equiv[T], inputType: DataType): TypedEquivDiffComparator[T] = InputTypedEquivDiffComparator(equiv, inputType) + def apply[T: Encoder](equiv: math.Equiv[T]): TypedEquivDiffComparator[T] = EncoderEquivDiffComparator(equiv) + def apply[T](equiv: math.Equiv[T], inputType: DataType): TypedEquivDiffComparator[T] = + InputTypedEquivDiffComparator(equiv, inputType) def apply(equiv: math.Equiv[Any]): EquivDiffComparator[Any] = EquivAnyDiffComparator(equiv) - private case class EncoderEquivDiffComparator[T : Encoder](equiv: math.Equiv[T]) - extends ExpressionEquivDiffComparator[T] with TypedEquivDiffComparator[T] { + private case class EncoderEquivDiffComparator[T: Encoder](equiv: math.Equiv[T]) + extends ExpressionEquivDiffComparator[T] + with TypedEquivDiffComparator[T] { override def inputType: DataType = encoderFor[T].schema.fields(0).dataType def equiv(left: Expression, right: Expression): Equiv[T] = Equiv(left, right, equiv, inputType) } @@ -85,9 +87,12 @@ private trait 
EquivExpression[T] extends BinaryExpression with BinaryLikeWithNew val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equivRef = ctx.addReferenceObj("equiv", equiv, math.Equiv.getClass.getName.stripSuffix("$")) - ev.copy(code = eval1.code + eval2.code + code""" + ev.copy( + code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && !${eval2.isNull} && $equivRef.equiv(${eval1.value}, ${eval2.value}));""", isNull = FalseLiteral) + (!${eval1.isNull} && !${eval2.isNull} && $equivRef.equiv(${eval1.value}, ${eval2.value}));""", + isNull = FalseLiteral + ) } } @@ -100,13 +105,12 @@ private trait EquivOperator[T] extends BinaryOperator with EquivExpression[T] { } private case class Equiv[T](left: Expression, right: Expression, equiv: math.Equiv[T], equivInputType: DataType) - extends EquivOperator[T] { + extends EquivOperator[T] { override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Equiv[T] = - copy(left=newLeft, right=newRight) + copy(left = newLeft, right = newRight) } -private case class EquivAny(left: Expression, right: Expression, equiv: math.Equiv[Any]) - extends EquivExpression[Any] { +private case class EquivAny(left: Expression, right: Expression, equiv: math.Equiv[Any]) extends EquivExpression[Any] { override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): EquivAny = copy(left = newLeft, right = newRight) diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala index d2ca79fc..71c010d8 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/MapDiffComparator.scala @@ -27,7 +27,8 @@ case class MapDiffComparator[K, V](private val comparator: EquivDiffComparator[U override def equiv(left: Column, right: Column): Column = comparator.equiv(left, right) } -private case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean) extends math.Equiv[UnsafeMapData] { +private case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: DataType, keyOrderSensitive: Boolean) + extends math.Equiv[UnsafeMapData] { override def equiv(left: UnsafeMapData, right: UnsafeMapData): Boolean = { val leftKeys: Array[K] = left.keyArray().toArray(keyType) @@ -42,15 +43,20 @@ private case class MapDiffEquiv[K: ClassTag, V](keyType: DataType, valueType: Da // can only be evaluated when right has same keys as left lazy val valuesAreEqual = leftKeysIndices .map { case (key, index) => index -> rightKeysIndices(key) } - .map { case (leftIndex, rightIndex) => (leftIndex, rightIndex, leftValues.isNullAt(leftIndex), rightValues.isNullAt(rightIndex)) } + .map { case (leftIndex, rightIndex) => + (leftIndex, rightIndex, leftValues.isNullAt(leftIndex), rightValues.isNullAt(rightIndex)) + } .map { case (leftIndex, rightIndex, leftIsNull, rightIsNull) => leftIsNull && rightIsNull || - !leftIsNull && !rightIsNull && leftValues.get(leftIndex, valueType).equals(rightValues.get(rightIndex, valueType)) + !leftIsNull && !rightIsNull && leftValues + .get(leftIndex, valueType) + .equals(rightValues.get(rightIndex, valueType)) } left.numElements() == right.numElements() && - (keyOrderSensitive && leftKeys.sameElements(rightKeys) || !keyOrderSensitive && leftKeys.toSet.diff(rightKeys.toSet).isEmpty) && - 
valuesAreEqual.forall(identity) + (keyOrderSensitive && leftKeys + .sameElements(rightKeys) || !keyOrderSensitive && leftKeys.toSet.diff(rightKeys.toSet).isEmpty) && + valuesAreEqual.forall(identity) } } diff --git a/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala b/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala index ab243ec7..8b107ce6 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/comparator/WhitespaceDiffComparator.scala @@ -20,7 +20,10 @@ import org.apache.spark.unsafe.types.UTF8String case object WhitespaceDiffComparator extends TypedEquivDiffComparatorWithInput[UTF8String] with StringDiffComparator { override val equiv: scala.Equiv[UTF8String] = (x: UTF8String, y: UTF8String) => - x.trimAll().toString.replaceAll("\\s+"," ").equals( - y.trimAll().toString.replaceAll("\\s+"," ") - ) + x.trimAll() + .toString + .replaceAll("\\s+", " ") + .equals( + y.trimAll().toString.replaceAll("\\s+", " ") + ) } diff --git a/src/main/scala/uk/co/gresearch/spark/diff/package.scala b/src/main/scala/uk/co/gresearch/spark/diff/package.scala index 53f47e72..e13802db 100644 --- a/src/main/scala/uk/co/gresearch/spark/diff/package.scala +++ b/src/main/scala/uk/co/gresearch/spark/diff/package.scala @@ -26,22 +26,21 @@ package object diff { implicit class DatasetDiff[T](ds: Dataset[T]) { /** - * Returns a new DataFrame that contains the differences between this and the other Dataset of - * the same type `T`. Both Datasets must contain the same set of column names and data types. - * The order of columns in the two Datasets is not important as one column is compared to the - * column with the same name of the other Dataset, not the column with the same position. + * Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`. + * Both Datasets must contain the same set of column names and data types. The order of columns in the two Datasets + * is not important as one column is compared to the column with the same name of the other Dataset, not the column + * with the same position. * - * Optional id columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between this and the other Dataset, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the other Dataset, that do not exist in this Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of this Dataset, that - * do not exist in the other Dataset are marked as `"D"`elete. + * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing + * between this and the other Dataset, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of + * the other Dataset, that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as + * `"I"`nsert. And rows of this Dataset, that do not exist in the other Dataset are marked as `"D"`elete. * - * If no id columns are given (empty sequence), all columns are considered id columns. Then, - * no `"C"`hange rows will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given (empty sequence), all columns are considered id columns. Then, no `"C"`hange rows will + * appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. 
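 A minimal end-to-end sketch of the implicit `diff` extension documented here, with toy data in the spirit of the Scaladoc example (session setup is illustrative):
 {{{
 import org.apache.spark.sql.SparkSession
 import uk.co.gresearch.spark.diff._ // brings the DatasetDiff implicit class into scope

 val spark = SparkSession.builder().master("local[*]").getOrCreate()
 import spark.implicits._

 val left  = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value")
 val right = Seq((1, "one"), (2, "Two"), (4, "four")).toDF("id", "value")

 // The first column "diff" holds N/C/I/D, followed by the id column and the per-side value columns.
 left.diff(right, "id").orderBy("id").show()
 }}}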
* - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -74,12 +73,11 @@ package object diff { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset + * are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. * - * The id column names are take literally, i.e. "a.field" is interpreted as "`a.field`, which is a - * column name containing a dot. This is not interpreted as a column "a" with a field "field" (struct). + * The id column names are take literally, i.e. "a.field" is interpreted as "`a.field`, which is a column name + * containing a dot. This is not interpreted as a column "a" with a field "field" (struct). */ // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java def diff(other: Dataset[T], idColumns: String*): DataFrame = { @@ -87,24 +85,23 @@ package object diff { } /** - * Returns a new DataFrame that contains the differences between two Datasets of - * similar types `T` and `U`. Both Datasets must contain the same set of column names and data types, - * except for the columns in `ignoreColumns`. The order of columns in the two Datasets is not relevant as - * columns are compared based on the name, not the the position. + * Returns a new DataFrame that contains the differences between two Datasets of similar types `T` and `U`. Both + * Datasets must contain the same set of column names and data types, except for the columns in `ignoreColumns`. The + * order of columns in the two Datasets is not relevant as columns are compared based on the name, not the the + * position. * - * Optional id columns are used to uniquely identify rows to compare. If values in any non-id - * column are differing between this and the other Dataset, then that row is marked as `"C"`hange - * and `"N"`o-change otherwise. Rows of the other Dataset, that do not exist in this Dataset - * (w.r.t. the values in the id columns) are marked as `"I"`nsert. And rows of this Dataset, that - * do not exist in the other Dataset are marked as `"D"`elete. + * Optional id columns are used to uniquely identify rows to compare. If values in any non-id column are differing + * between this and the other Dataset, then that row is marked as `"C"`hange and `"N"`o-change otherwise. Rows of + * the other Dataset, that do not exist in this Dataset (w.r.t. the values in the id columns) are marked as + * `"I"`nsert. And rows of this Dataset, that do not exist in the other Dataset are marked as `"D"`elete. * - * If no id columns are given (empty sequence), all columns are considered id columns. Then, - * no `"C"`hange rows will appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. + * If no id columns are given (empty sequence), all columns are considered id columns. 
Then, no `"C"`hange rows will + * appear, as all changes will exists as respective `"D"`elete and `"I"`nsert. * * Values in optional ignore columns are not compared but included in the output DataFrame. * - * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, - * `"I"` or `"D"` strings. The id columns follow, then the non-id columns (all remaining columns). + * The returned DataFrame has the `diff` column as the first column. This holds the `"N"`, `"C"`, `"I"` or `"D"` + * strings. The id columns follow, then the non-id columns (all remaining columns). * * {{{ * val df1 = Seq((1, "one"), (2, "two"), (3, "three")).toDF("id", "value") @@ -137,20 +134,18 @@ package object diff { * * }}} * - * The id columns are in order as given to the method. If no id columns are given then all - * columns of this Dataset are id columns and appear in the same order. The remaining non-id - * columns are in the order of this Dataset. + * The id columns are in order as given to the method. If no id columns are given then all columns of this Dataset + * are id columns and appear in the same order. The remaining non-id columns are in the order of this Dataset. * - * The id column names are take literally, i.e. "a.field" is interpreted as "`a.field`, which is a - * column name containing a dot. This is not interpreted as a column "a" with a field "field" (struct). + * The id column names are take literally, i.e. "a.field" is interpreted as "`a.field`, which is a column name + * containing a dot. This is not interpreted as a column "a" with a field "field" (struct). */ def diff[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame = { Diff.of(this.ds, other, idColumns, ignoreColumns) } /** - * Returns a new DataFrame that contains the differences - * between this and the other Dataset of the same type `T`. + * Returns a new DataFrame that contains the differences between this and the other Dataset of the same type `T`. * * See `diff(Dataset[T], String*)`. * @@ -162,76 +157,80 @@ package object diff { } /** - * Returns a new DataFrame that contains the differences - * between this and the other Dataset of similar types `T` and `U`. + * Returns a new DataFrame that contains the differences between this and the other Dataset of similar types `T` and + * `U`. * * See `diff(Dataset[U], Seq[String], Seq[String])`. * * The schema of the returned DataFrame can be configured by the given `DiffOptions`. */ - def diff[U](other: Dataset[U], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String]): DataFrame = { + def diff[U]( + other: Dataset[U], + options: DiffOptions, + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): DataFrame = { new Differ(options).diff(this.ds, other, idColumns, ignoreColumns) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of the same type `T`. + * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`. * * See `diff(Dataset[T], String*)`. * * This requires an additional implicit `Encoder[U]` for the return type `Dataset[U]`. 
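A rough usage sketch of `diffAs` with an implicitly supplied encoder, assuming a SparkSession named `spark`; the `Value` and `ValueDiff` case classes are illustrative, and the `left_value`/`right_value` field names assume the default `left`/`right` column prefixes of `DiffOptions`:

{{{
  import org.apache.spark.sql.Dataset
  import spark.implicits._
  import uk.co.gresearch.spark.diff._

  case class Value(id: Int, value: String)
  // field names are assumed to match the diff output schema under default options
  case class ValueDiff(diff: String, id: Int, left_value: String, right_value: String)

  val left  = Seq(Value(1, "one"), Value(2, "two"), Value(3, "three")).toDS()
  val right = Seq(Value(1, "one"), Value(2, "Two"), Value(4, "four")).toDS()

  // Encoder[ValueDiff] is picked up implicitly from spark.implicits
  val diffed: Dataset[ValueDiff] = left.diffAs[ValueDiff](right, "id")
}}}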
*/ // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java - def diffAs[V](other: Dataset[T], idColumns: String*) - (implicit diffEncoder: Encoder[V]): Dataset[V] = { + def diffAs[V](other: Dataset[T], idColumns: String*)(implicit diffEncoder: Encoder[V]): Dataset[V] = { Diff.ofAs(this.ds, other, idColumns: _*) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of similar types `T` and `U`. + * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and + * `U`. * * See `diff(Dataset[U], Seq[String], Seq[String])`. * * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. */ - def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]) - (implicit diffEncoder: Encoder[V]): Dataset[V] = { + def diffAs[U, V](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String])(implicit + diffEncoder: Encoder[V] + ): Dataset[V] = { Diff.ofAs(this.ds, other, idColumns, ignoreColumns) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of the same type `T`. + * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`. * * See `diff(Dataset[T], String*)`. * - * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. - * The schema of the returned Dataset can be configured by the given `DiffOptions`. + * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned + * Dataset can be configured by the given `DiffOptions`. */ // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java - def diffAs[V](other: Dataset[T], options: DiffOptions, idColumns: String*) - (implicit diffEncoder: Encoder[V]): Dataset[V] = { + def diffAs[V](other: Dataset[T], options: DiffOptions, idColumns: String*)(implicit + diffEncoder: Encoder[V] + ): Dataset[V] = { new Differ(options).diffAs(this.ds, other, idColumns: _*) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of similar types `T` and `U`. + * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and + * `U`. * * See `diff(Dataset[U], Seq[String], Seq[String])`. * - * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. - * The schema of the returned Dataset can be configured by the given `DiffOptions`. + * This requires an additional implicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned + * Dataset can be configured by the given `DiffOptions`. */ - def diffAs[U, V](other: Dataset[T], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String]) - (implicit diffEncoder: Encoder[V]): Dataset[V] = { + def diffAs[U, V](other: Dataset[T], options: DiffOptions, idColumns: Seq[String], ignoreColumns: Seq[String])( + implicit diffEncoder: Encoder[V] + ): Dataset[V] = { new Differ(options).diffAs(this.ds, other, idColumns, ignoreColumns) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of the same type `T`. + * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`. * * See `diff(Dataset[T], String*)`. 
* @@ -243,117 +242,114 @@ package object diff { } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of similar types `T` and `U`. + * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and + * `U`. * * See `diff(Dataset[U], Seq[String], Seq[String])`. * * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. */ - def diffAs[U, V](other: Dataset[U], diffEncoder: Encoder[V], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[V] = { + def diffAs[U, V]( + other: Dataset[U], + diffEncoder: Encoder[V], + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Dataset[V] = { Diff.ofAs(this.ds, other, diffEncoder, idColumns, ignoreColumns) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of the same type `T`. + * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T`. * * See `diff(Dataset[T], String*)`. * - * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. - * The schema of the returned Dataset can be configured by the given `DiffOptions`. + * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned + * Dataset can be configured by the given `DiffOptions`. */ // no @scala.annotation.varargs here as this implicit class is not nicely accessible from Java - def diffAs[V](other: Dataset[T], - options: DiffOptions, - diffEncoder: Encoder[V], - idColumns: String*): Dataset[V] = { + def diffAs[V](other: Dataset[T], options: DiffOptions, diffEncoder: Encoder[V], idColumns: String*): Dataset[V] = { new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns: _*) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of similar types `T` and `U`. + * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and + * `U`. * * See `diff(Dataset[U], Seq[String], Seq[String])`. * - * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. - * The schema of the returned Dataset can be configured by the given `DiffOptions`. + * This requires an additional explicit `Encoder[V]` for the return type `Dataset[V]`. The schema of the returned + * Dataset can be configured by the given `DiffOptions`. */ - def diffAs[U, V](other: Dataset[U], - options: DiffOptions, - diffEncoder: Encoder[V], - idColumns: Seq[String], - ignoreColumns: Seq[String]): Dataset[V] = { + def diffAs[U, V]( + other: Dataset[U], + options: DiffOptions, + diffEncoder: Encoder[V], + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Dataset[V] = { new Differ(options).diffAs(this.ds, other, diffEncoder, idColumns, ignoreColumns) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of the same type `T` - * as tuples of type `(String, T, T)`. + * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as + * tuples of type `(String, T, T)`. * * See `diff(Dataset[T], Seq[String])`. 
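For the tuple-returning `diffWith` variant described above, a minimal sketch (assuming a SparkSession named `spark`; the `Value` rows are illustrative only):

{{{
  import org.apache.spark.sql.Dataset
  import spark.implicits._
  import uk.co.gresearch.spark.diff._

  case class Value(id: Int, value: String)
  val left  = Seq(Value(1, "one"), Value(2, "two"), Value(3, "three")).toDS()
  val right = Seq(Value(1, "one"), Value(2, "Two"), Value(4, "four")).toDS()

  // one row per difference: the diff action plus the full left and right rows;
  // the side without a matching row is null
  val diffed: Dataset[(String, Value, Value)] = left.diffWith(right, "id")
  diffed.show(false)
}}}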
*/ - def diffWith(other: Dataset[T], - idColumns: String*): Dataset[(String, T, T)] = + def diffWith(other: Dataset[T], idColumns: String*): Dataset[(String, T, T)] = Diff.default.diffWith(this.ds, other, idColumns: _*) /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of similar types `T` and `U` - * as tuples of type `(String, T, U)`. + * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and + * `U` as tuples of type `(String, T, U)`. * * See `diff(Dataset[U], Seq[String], Seq[String])`. */ - def diffWith[U](other: Dataset[U], - idColumns: Seq[String], - ignoreColumns: Seq[String]): Dataset[(String, T, U)] = + def diffWith[U](other: Dataset[U], idColumns: Seq[String], ignoreColumns: Seq[String]): Dataset[(String, T, U)] = Diff.default.diffWith(this.ds, other, idColumns, ignoreColumns) /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of the same type `T` - * as tuples of type `(String, T, T)`. + * Returns a new Dataset that contains the differences between this and the other Dataset of the same type `T` as + * tuples of type `(String, T, T)`. * * See `diff(Dataset[T], String*)`. * * The schema of the returned Dataset can be configured by the given `DiffOptions`. */ - def diffWith(other: Dataset[T], - options: DiffOptions, - idColumns: String*): Dataset[(String, T, T)] = { + def diffWith(other: Dataset[T], options: DiffOptions, idColumns: String*): Dataset[(String, T, T)] = { new Differ(options).diffWith(this.ds, other, idColumns: _*) } /** - * Returns a new Dataset that contains the differences - * between this and the other Dataset of similar types `T` and `U`. - * as tuples of type `(String, T, T)`. + * Returns a new Dataset that contains the differences between this and the other Dataset of similar types `T` and + * `U`. as tuples of type `(String, T, T)`. * * See `diff(Dataset[U], Seq[String], Seq[String])`. * * The schema of the returned Dataset can be configured by the given `DiffOptions`. */ - def diffWith[U](other: Dataset[U], - options: DiffOptions, - idColumns: Seq[String], - ignoreColumns: Seq[String]): Dataset[(String, T, U)] = { + def diffWith[U]( + other: Dataset[U], + options: DiffOptions, + idColumns: Seq[String], + ignoreColumns: Seq[String] + ): Dataset[(String, T, U)] = { new Differ(options).diffWith(this.ds, other, idColumns, ignoreColumns) } } /** - * Produces a column name that considers configured case-sensitivity of column names. - * When case sensitivity is deactivated, it lower-cases the given column name and no-ops otherwise. + * Produces a column name that considers configured case-sensitivity of column names. When case sensitivity is + * deactivated, it lower-cases the given column name and no-ops otherwise. 
* - * @param columnName column name - * @return case sensitive or insensitive column name + * @param columnName + * column name + * @return + * case sensitive or insensitive column name */ private[diff] def handleConfiguredCaseSensitivity(columnName: String): String = if (SQLConf.get.caseSensitiveAnalysis) columnName else columnName.toLowerCase(Locale.ROOT) - implicit class CaseInsensitiveSeq(seq: Seq[String]) { def containsCaseSensitivity(string: String): Boolean = seq.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string)) @@ -372,7 +368,8 @@ package object diff { implicit class CaseInsensitiveArray(array: Array[String]) { def containsCaseSensitivity(string: String): Boolean = array.map(handleConfiguredCaseSensitivity).contains(handleConfiguredCaseSensitivity(string)) - def filterIsInCaseSensitivity(other: Iterable[String]): Array[String] = array.toSeq.filterIsInCaseSensitivity(other).toArray + def filterIsInCaseSensitivity(other: Iterable[String]): Array[String] = + array.toSeq.filterIsInCaseSensitivity(other).toArray def diffCaseSensitivity(other: Iterable[String]): Array[String] = array.toSeq.diffCaseSensitivity(other).toArray } diff --git a/src/main/scala/uk/co/gresearch/spark/group/package.scala b/src/main/scala/uk/co/gresearch/spark/group/package.scala index 5607ee1d..41063b1a 100644 --- a/src/main/scala/uk/co/gresearch/spark/group/package.scala +++ b/src/main/scala/uk/co/gresearch/spark/group/package.scala @@ -23,54 +23,50 @@ import uk.co.gresearch.ExtendedAny package object group { /** - * This is a Dataset of key-value tuples, that provide a flatMap function over the individual groups, - * while providing a sorted iterator over group values. + * This is a Dataset of key-value tuples, that provide a flatMap function over the individual groups, while providing + * a sorted iterator over group values. * - * The key-value Dataset given the constructor has to be partitioned by the key - * and sorted within partitions by the key and value. + * The key-value Dataset given the constructor has to be partitioned by the key and sorted within partitions by the + * key and value. * - * @param ds the properly partitioned and sorted dataset - * @tparam K type of the keys with ordering and encoder - * @tparam V type of the values with encoder + * @param ds + * the properly partitioned and sorted dataset + * @tparam K + * type of the keys with ordering and encoder + * @tparam V + * type of the values with encoder */ - case class SortedGroupByDataset[K: Ordering : Encoder, V: Encoder] private (ds: Dataset[(K, V)]) { + case class SortedGroupByDataset[K: Ordering: Encoder, V: Encoder] private (ds: Dataset[(K, V)]) { + /** - * (Scala-specific) - * Applies the given function to each group of data. For each unique group, the function will - * be passed the group key and a sorted iterator that contains all of the elements in the group. - * The function can return an iterator containing elements of an arbitrary type which will be - * returned as a new [[Dataset]]. + * (Scala-specific) Applies the given function to each group of data. For each unique group, the function will be + * passed the group key and a sorted iterator that contains all of the elements in the group. The function can + * return an iterator containing elements of an arbitrary type which will be returned as a new [[Dataset]]. * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. 
If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. + * This function does not support partial aggregation, and as a result requires shuffling all the data in the + * [[Dataset]]. If an application intends to perform an aggregation over each key, it is best to use the reduce + * function or an `org.apache.spark.sql.expressions#Aggregator`. * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. + * Internally, the implementation will spill to disk if any given group is too large to fit into memory. However, + * users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`) + * unless they are sure that this is possible given the memory constraints of their cluster. */ def flatMapSortedGroups[W: Encoder](func: (K, Iterator[V]) => TraversableOnce[W]): Dataset[W] = ds.mapPartitions(new GroupedIterator(_).flatMap(v => func(v._1, v._2))) /** - * (Scala-specific) - * Applies the given function to each group of data. For each unique group, the function s will - * be passed the group key to create a state instance, while the function func will be passed - * that state instance and group values in sequence according to the sort order in the groups. - * The function func can return an iterator containing elements of an arbitrary type which will - * be returned as a new [[Dataset]]. + * (Scala-specific) Applies the given function to each group of data. For each unique group, the function s will be + * passed the group key to create a state instance, while the function func will be passed that state instance and + * group values in sequence according to the sort order in the groups. The function func can return an iterator + * containing elements of an arbitrary type which will be returned as a new [[Dataset]]. * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. + * This function does not support partial aggregation, and as a result requires shuffling all the data in the + * [[Dataset]]. If an application intends to perform an aggregation over each key, it is best to use the reduce + * function or an `org.apache.spark.sql.expressions#Aggregator`. * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. + * Internally, the implementation will spill to disk if any given group is too large to fit into memory. However, + * users must take care to avoid materializing the whole iterator for a group (for example, by calling `toList`) + * unless they are sure that this is possible given the memory constraints of their cluster. 
*/ def flatMapSortedGroups[S, W: Encoder](s: K => S)(func: (S, V) => TraversableOnce[W]): Dataset[W] = { ds.mapPartitions(new GroupedIterator(_).flatMap { case (k, it) => @@ -81,10 +77,12 @@ package object group { } object SortedGroupByDataset { - def apply[K: Ordering : Encoder, V](ds: Dataset[V], - groupColumns: Seq[Column], - orderColumns: Seq[Column], - partitions: Option[Int]): SortedGroupByDataset[K, V] = { + def apply[K: Ordering: Encoder, V]( + ds: Dataset[V], + groupColumns: Seq[Column], + orderColumns: Seq[Column], + partitions: Option[Int] + ): SortedGroupByDataset[K, V] = { // make ds encoder implicitly available implicit val valueEncoder: Encoder[V] = ds.encoder @@ -115,11 +113,13 @@ package object group { SortedGroupByDataset(grouped) } - def apply[K: Ordering : Encoder, V, O: Encoder](ds: Dataset[V], - key: V => K, - order: V => O, - partitions: Option[Int], - reverse: Boolean): SortedGroupByDataset[K, V] = { + def apply[K: Ordering: Encoder, V, O: Encoder]( + ds: Dataset[V], + key: V => K, + order: V => O, + partitions: Option[Int], + reverse: Boolean + ): SortedGroupByDataset[K, V] = { // prepare encoder needed for this exercise val keyEncoder: Encoder[K] = implicitly[Encoder[K]] implicit val valueEncoder: Encoder[V] = ds.encoder diff --git a/src/main/scala/uk/co/gresearch/spark/package.scala b/src/main/scala/uk/co/gresearch/spark/package.scala index aa644886..a1396f9d 100644 --- a/src/main/scala/uk/co/gresearch/spark/package.scala +++ b/src/main/scala/uk/co/gresearch/spark/package.scala @@ -32,8 +32,10 @@ package object spark extends Logging with SparkVersion with BuildVersion { /** * Provides a prefix that makes any string distinct w.r.t. the given strings. - * @param existing strings - * @return distinct prefix + * @param existing + * strings + * @return + * distinct prefix */ private[spark] def distinctPrefixFor(existing: Seq[String]): String = { "_" * (existing.map(_.takeWhile(_ == '_').length).reduceOption(_ max _).getOrElse(0) + 1) @@ -41,8 +43,10 @@ package object spark extends Logging with SparkVersion with BuildVersion { /** * Create a temporary directory in a location (driver temp dir) that will be deleted on Spark application shutdown. 
- * @param prefix prefix string of temporary directory name - * @return absolute path of temporary directory + * @param prefix + * prefix string of temporary directory name + * @return + * absolute path of temporary directory */ def createTemporaryDir(prefix: String): String = { // SparkFiles.getRootDirectory() will be deleted on spark application shutdown @@ -51,12 +55,15 @@ package object spark extends Logging with SparkVersion with BuildVersion { // https://issues.apache.org/jira/browse/SPARK-40588 private[spark] def writePartitionedByRequiresCaching[T](ds: Dataset[T]): Boolean = { - ds.sparkSession.conf.get( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, - SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString - ).equalsIgnoreCase("true") && Some(ds.sparkSession.version).exists(ver => - Set("3.0.", "3.1.", "3.2.0", "3.2.1" ,"3.2.2", "3.3.0", "3.3.1").exists(pat => - if (pat.endsWith(".")) { ver.startsWith(pat) } else { ver.equals(pat) || ver.startsWith(pat + "-") } + ds.sparkSession.conf + .get( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, + SQLConf.ADAPTIVE_EXECUTION_ENABLED.defaultValue.getOrElse(true).toString + ) + .equalsIgnoreCase("true") && Some(ds.sparkSession.version).exists(ver => + Set("3.0.", "3.1.", "3.2.0", "3.2.1", "3.2.2", "3.3.0", "3.3.1").exists(pat => + if (pat.endsWith(".")) { ver.startsWith(pat) } + else { ver.equals(pat) || ver.startsWith(pat + "-") } ) ) } @@ -80,8 +87,10 @@ package object spark extends Logging with SparkVersion with BuildVersion { * col(backticks("a.column", "a.field")) // produces "`a.column`.`a.field`" * }}} * - * @param string a string - * @param strings more strings + * @param string + * a string + * @param strings + * more strings */ @scala.annotation.varargs def backticks(string: String, strings: String*): String = @@ -97,307 +106,308 @@ package object spark extends Logging with SparkVersion with BuildVersion { private val unixEpochDotNetTicks: Long = 621355968000000000L /** - * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be - * convertible to a number (e.g. string, int, long). The Spark timestamp type does not support - * nanoseconds, so the the last digit of the timestamp (1/10 of a microsecond) is lost. + * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number + * (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the the last digit of the + * timestamp (1/10 of a microsecond) is lost. * * Example: * {{{ * df.select($"ticks", dotNetTicksToTimestamp($"ticks").as("timestamp")).show(false) * }}} * - * +------------------+--------------------------+ - * |ticks |timestamp | - * +------------------+--------------------------+ - * |638155413748959318|2023-03-27 21:16:14.895931| - * +------------------+--------------------------+ + * | ticks | timestamp | + * |:-------------------|:---------------------------| + * | 638155413748959318 | 2023-03-27 21:16:14.895931 | * - * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to - * preserve the full precision of the tick timestamp. + * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full + * precision of the tick timestamp. 
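The note above can be checked by hand. A small arithmetic sketch, using the `unixEpochDotNetTicks` constant defined in this file and assuming the usual .Net resolution of 10,000,000 ticks per second (the `dotNetTicksPerSecond` value itself is defined outside this hunk):

{{{
  // .Net ticks count 100ns intervals since 0001-01-01T00:00:00
  val unixEpochDotNetTicks = BigDecimal(621355968000000000L)
  val dotNetTicksPerSecond = BigDecimal(10000000)   // assumed: 1 tick = 100ns

  val ticks = BigDecimal(638155413748959318L)
  val unixSeconds = (ticks - unixEpochDotNetTicks) / dotNetTicksPerSecond
  // unixSeconds == 1679944574.8959318, matching the 1679944574.895931800 of the
  // dotNetTicksToUnixEpoch example; the cast to TimestampType keeps only microseconds,
  // which is why dotNetTicksToTimestamp shows ...14.895931 and drops the trailing 8
}}}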
* * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param tickColumn column with a tick value - * @return result timestamp column + * @param tickColumn + * column with a tick value + * @return + * result timestamp column */ def dotNetTicksToTimestamp(tickColumn: Column): Column = dotNetTicksToUnixEpoch(tickColumn).cast(TimestampType) /** - * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be - * convertible to a number (e.g. string, int, long). The Spark timestamp type does not support - * nanoseconds, so the the last digit of the timestamp (1/10 of a microsecond) is lost. + * Convert a .Net `DateTime.Ticks` timestamp to a Spark timestamp. The input column must be convertible to a number + * (e.g. string, int, long). The Spark timestamp type does not support nanoseconds, so the the last digit of the + * timestamp (1/10 of a microsecond) is lost. * * {{{ * df.select($"ticks", dotNetTicksToTimestamp("ticks").as("timestamp")).show(false) * }}} * - * +------------------+--------------------------+ - * |ticks |timestamp | - * +------------------+--------------------------+ - * |638155413748959318|2023-03-27 21:16:14.895931| - * +------------------+--------------------------+ + * | ticks | timestamp | + * |:-------------------|:---------------------------| + * | 638155413748959318 | 2023-03-27 21:16:14.895931 | * - * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to - * preserve the full precision of the tick timestamp. + * Note: the example timestamp lacks the 8/10 of a microsecond. Use `dotNetTicksToUnixEpoch` to preserve the full + * precision of the tick timestamp. * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param tickColumnName name of a column with a tick value - * @return result timestamp column + * @param tickColumnName + * name of a column with a tick value + * @return + * result timestamp column */ def dotNetTicksToTimestamp(tickColumnName: String): Column = dotNetTicksToTimestamp(col(tickColumnName)) /** - * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch decimal. The input column must be - * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp - * is preserved (1/10 of a microsecond). + * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch decimal. The input column must be convertible to a number + * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond). * * Example: * {{{ * df.select($"ticks", dotNetTicksToUnixEpoch($"ticks").as("timestamp")).show(false) * }}} * - * +------------------+--------------------+ - * |ticks |timestamp | - * +------------------+--------------------+ - * |638155413748959318|1679944574.895931800| - * +------------------+--------------------+ + * | ticks | timestamp | + * |:-------------------|:---------------------| + * | 638155413748959318 | 1679944574.895931800 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param tickColumn column with a tick value - * @return result unix epoch seconds column as decimal + * @param tickColumn + * column with a tick value + * @return + * result unix epoch seconds column as decimal */ def dotNetTicksToUnixEpoch(tickColumn: Column): Column = (tickColumn.cast(DecimalType(19, 0)) - unixEpochDotNetTicks) / dotNetTicksPerSecond /** - * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch seconds. 
The input column must be - * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp - * is preserved (1/10 of a microsecond). + * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch seconds. The input column must be convertible to a number + * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond). * * Example: * {{{ * df.select($"ticks", dotNetTicksToUnixEpoch("ticks").as("timestamp")).show(false) * }}} * - * +------------------+--------------------+ - * |ticks |timestamp | - * +------------------+--------------------+ - * |638155413748959318|1679944574.895931800| - * +------------------+--------------------+ + * | ticks | timestamp | + * |:-------------------|:---------------------| + * | 638155413748959318 | 1679944574.895931800 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param tickColumnName name of column with a tick value - * @return result unix epoch seconds column as decimal + * @param tickColumnName + * name of column with a tick value + * @return + * result unix epoch seconds column as decimal */ def dotNetTicksToUnixEpoch(tickColumnName: String): Column = dotNetTicksToUnixEpoch(col(tickColumnName)) /** - * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch seconds. The input column must be - * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp - * is preserved (1/10 of a microsecond). + * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch seconds. The input column must be convertible to a number + * (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond). * * Example: * {{{ * df.select($"ticks", dotNetTicksToUnixEpochNanos($"ticks").as("timestamp")).show(false) * }}} * - * +------------------+-------------------+ - * |ticks |timestamp | - * +------------------+-------------------+ - * |638155413748959318|1679944574895931800| - * +------------------+-------------------+ + * | ticks | timestamp | + * |:-------------------|:--------------------| + * | 638155413748959318 | 1679944574895931800 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param tickColumn column with a tick value - * @return result unix epoch nanoseconds column as long + * @param tickColumn + * column with a tick value + * @return + * result unix epoch nanoseconds column as long */ def dotNetTicksToUnixEpochNanos(tickColumn: Column): Column = { - when(tickColumn <= 713589688368547758L, (tickColumn.cast(LongType) - unixEpochDotNetTicks) * nanoSecondsPerDotNetTick) + when( + tickColumn <= 713589688368547758L, + (tickColumn.cast(LongType) - unixEpochDotNetTicks) * nanoSecondsPerDotNetTick + ) } /** - * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch nanoseconds. The input column must be - * convertible to a number (e.g. string, int, long). The full precision of the tick timestamp - * is preserved (1/10 of a microsecond). + * Convert a .Net `DateTime.Ticks` timestamp to a Unix epoch nanoseconds. The input column must be convertible to a + * number (e.g. string, int, long). The full precision of the tick timestamp is preserved (1/10 of a microsecond). 
* * Example: * {{{ * df.select($"ticks", dotNetTicksToUnixEpochNanos("ticks").as("timestamp")).show(false) * }}} * - * +------------------+-------------------+ - * |ticks |timestamp | - * +------------------+-------------------+ - * |638155413748959318|1679944574895931800| - * +------------------+-------------------+ + * | ticks | timestamp | + * |:-------------------|:--------------------| + * | 638155413748959318 | 1679944574895931800 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param tickColumnName name of column with a tick value - * @return result unix epoch nanoseconds column as long + * @param tickColumnName + * name of column with a tick value + * @return + * result unix epoch nanoseconds column as long */ def dotNetTicksToUnixEpochNanos(tickColumnName: String): Column = dotNetTicksToUnixEpochNanos(col(tickColumnName)) /** - * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. - * The input column must be of TimestampType. + * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType. * * Example: * {{{ * df.select($"timestamp", timestampToDotNetTicks($"timestamp").as("ticks")).show(false) * }}} * - * +--------------------------+------------------+ - * |timestamp |ticks | - * +--------------------------+------------------+ - * |2023-03-27 21:16:14.895931|638155413748959310| - * +--------------------------+------------------+ + * | timestamp | ticks | + * |:---------------------------|:-------------------| + * | 2023-03-27 21:16:14.895931 | 638155413748959310 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param timestampColumn column with a timestamp value - * @return result tick value column + * @param timestampColumn + * column with a timestamp value + * @return + * result tick value column */ def timestampToDotNetTicks(timestampColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(new Column(UnixMicros.unixMicros(timestampColumn.expr)) * 10) /** - * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. - * The input column must be of TimestampType. + * Convert a Spark timestamp to a .Net `DateTime.Ticks` timestamp. The input column must be of TimestampType. * * Example: * {{{ * df.select($"timestamp", timestampToDotNetTicks("timestamp").as("ticks")).show(false) * }}} * - * +--------------------------+------------------+ - * |timestamp |ticks | - * +--------------------------+------------------+ - * |2023-03-27 21:16:14.895931|638155413748959310| - * +--------------------------+------------------+ + * | timestamp | ticks | + * |:---------------------------|:-------------------| + * | 2023-03-27 21:16:14.895931 | 638155413748959310 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param timestampColumnName name of column with a timestamp value - * @return result tick value column + * @param timestampColumnName + * name of column with a timestamp value + * @return + * result tick value column */ def timestampToDotNetTicks(timestampColumnName: String): Column = timestampToDotNetTicks(col(timestampColumnName)) /** - * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. - * The input column must represent a numerical unix epoch timestamp, e.g. long, double, string or decimal. - * The input must not be of TimestampType, as that may be interpreted incorrectly. - * Use `timestampToDotNetTicks` for TimestampType columns instead. + * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. 
The input column must represent a numerical + * unix epoch timestamp, e.g. long, double, string or decimal. The input must not be of TimestampType, as that may be + * interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead. * * Example: * {{{ * df.select($"unix", unixEpochToDotNetTicks($"unix").as("ticks")).show(false) * }}} * - * +-----------------------------+------------------+ - * |unix |ticks | - * +-----------------------------+------------------+ - * |1679944574.895931234000000000|638155413748959312| - * +-----------------------------+------------------+ + * | unix | ticks | + * |:------------------------------|:-------------------| + * | 1679944574.895931234000000000 | 638155413748959312 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param unixTimeColumn column with a unix epoch timestamp value - * @return result tick value column + * @param unixTimeColumn + * column with a unix epoch timestamp value + * @return + * result tick value column */ - def unixEpochToDotNetTicks(unixTimeColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(unixTimeColumn.cast(DecimalType(19, 7)) * 10000000) + def unixEpochToDotNetTicks(unixTimeColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks( + unixTimeColumn.cast(DecimalType(19, 7)) * 10000000 + ) /** - * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. - * The input column must represent a numerical unix epoch timestamp, e.g. long, double, string or decimal. - * The input must not be of TimestampType, as that may be interpreted incorrectly. - * Use `timestampToDotNetTicks` for TimestampType columns instead. + * Convert a Unix epoch timestamp to a .Net `DateTime.Ticks` timestamp. The input column must represent a numerical + * unix epoch timestamp, e.g. long, double, string or decimal. The input must not be of TimestampType, as that may be + * interpreted incorrectly. Use `timestampToDotNetTicks` for TimestampType columns instead. * * Example: * {{{ * df.select($"unix", unixEpochToDotNetTicks("unix").as("ticks")).show(false) * }}} * - * +-----------------------------+------------------+ - * |unix |ticks | - * +-----------------------------+------------------+ - * |1679944574.895931234000000000|638155413748959312| - * +-----------------------------+------------------+ + * | unix | ticks | + * |:------------------------------|:-------------------| + * | 1679944574.895931234000000000 | 638155413748959312 | * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param unixTimeColumnName name of column with a unix epoch timestamp value - * @return result tick value column + * @param unixTimeColumnName + * name of column with a unix epoch timestamp value + * @return + * result tick value column */ def unixEpochToDotNetTicks(unixTimeColumnName: String): Column = unixEpochToDotNetTicks(col(unixTimeColumnName)) /** - * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. - * The .Net ticks timestamp does not support the two lowest nanosecond digits, - * so only a 1/10 of a microsecond is the smallest resolution. - * The input column must represent a numerical unix epoch nanoseconds timestamp, - * e.g. long, double, string or decimal. + * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not + * support the two lowest nanosecond digits, so only a 1/10 of a microsecond is the smallest resolution. 
The input + * column must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal. * * Example: * {{{ * df.select($"unix_nanos", unixEpochNanosToDotNetTicks($"unix_nanos").as("ticks")).show(false) * }}} * - * +-------------------+------------------+ - * |unix_nanos |ticks | - * +-------------------+------------------+ - * |1679944574895931234|638155413748959312| - * +-------------------+------------------+ + * | unix_nanos | ticks | + * |:--------------------|:-------------------| + * | 1679944574895931234 | 638155413748959312 | * * Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks. * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param unixNanosColumn column with a unix epoch timestamp value - * @return result tick value column + * @param unixNanosColumn + * column with a unix epoch timestamp value + * @return + * result tick value column */ - def unixEpochNanosToDotNetTicks(unixNanosColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks(unixNanosColumn.cast(DecimalType(21, 0)) / nanoSecondsPerDotNetTick) + def unixEpochNanosToDotNetTicks(unixNanosColumn: Column): Column = unixEpochTenthMicrosToDotNetTicks( + unixNanosColumn.cast(DecimalType(21, 0)) / nanoSecondsPerDotNetTick + ) /** - * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. - * The .Net ticks timestamp does not support the two lowest nanosecond digits, - * so only a 1/10 of a microsecond is the smallest resolution. - * The input column must represent a numerical unix epoch nanoseconds timestamp, - * e.g. long, double, string or decimal. + * Convert a Unix epoch nanosecond timestamp to a .Net `DateTime.Ticks` timestamp. The .Net ticks timestamp does not + * support the two lowest nanosecond digits, so only a 1/10 of a microsecond is the smallest resolution. The input + * column must represent a numerical unix epoch nanoseconds timestamp, e.g. long, double, string or decimal. * * Example: * {{{ * df.select($"unix_nanos", unixEpochNanosToDotNetTicks($"unix_nanos").as("ticks")).show(false) * }}} * - * +-------------------+------------------+ - * |unix_nanos |ticks | - * +-------------------+------------------+ - * |1679944574895931234|638155413748959312| - * +-------------------+------------------+ + * | unix_nanos | ticks | + * |:--------------------|:-------------------| + * | 1679944574895931234 | 638155413748959312 | * * Note: the example timestamp lacks the two lower nanosecond digits as this precision is not supported by .Net ticks. * * https://learn.microsoft.com/de-de/dotnet/api/system.datetime.ticks * - * @param unixNanosColumnName name of column with a unix epoch timestamp value - * @return result tick value column + * @param unixNanosColumnName + * name of column with a unix epoch timestamp value + * @return + * result tick value column */ - def unixEpochNanosToDotNetTicks(unixNanosColumnName: String): Column = unixEpochNanosToDotNetTicks(col(unixNanosColumnName)) + def unixEpochNanosToDotNetTicks(unixNanosColumnName: String): Column = unixEpochNanosToDotNetTicks( + col(unixNanosColumnName) + ) - private def unixEpochTenthMicrosToDotNetTicks(unixNanosColumn: Column): Column = unixNanosColumn.cast(LongType) + unixEpochDotNetTicks + private def unixEpochTenthMicrosToDotNetTicks(unixNanosColumn: Column): Column = + unixNanosColumn.cast(LongType) + unixEpochDotNetTicks /** * Set the job description and return the earlier description. 
Only set the description if it is not set. * - * @param description job description - * @param ifNotSet job description is only set if no description is set yet - * @param context spark context + * @param description + * job description + * @param ifNotSet + * job description is only set if no description is set yet + * @param context + * spark context * @return */ def setJobDescription(description: String, ifNotSet: Boolean = false)(implicit context: SparkContext): String = { @@ -409,8 +419,8 @@ package object spark extends Logging with SparkVersion with BuildVersion { } /** - * Adds a job description to all Spark jobs started within the given function. - * The current Job description is restored after exit of the function. + * Adds a job description to all Spark jobs started within the given function. The current Job description is restored + * after exit of the function. * * Usage example: * @@ -427,16 +437,22 @@ package object spark extends Logging with SparkVersion with BuildVersion { * * With `ifNotSet == true`, the description is only set if no job description is set yet. * - * Any modification to the job description during execution of the function is reverted, - * even if `ifNotSet == true`. - * - * @param description job description - * @param ifNotSet job description is only set if no description is set yet - * @param func code to execute while job description is set - * @param session spark session - * @tparam T return type of func + * Any modification to the job description during execution of the function is reverted, even if `ifNotSet == true`. + * + * @param description + * job description + * @param ifNotSet + * job description is only set if no description is set yet + * @param func + * code to execute while job description is set + * @param session + * spark session + * @tparam T + * return type of func */ - def withJobDescription[T](description: String, ifNotSet: Boolean = false)(func: => T)(implicit session: SparkSession): T = { + def withJobDescription[T](description: String, ifNotSet: Boolean = false)( + func: => T + )(implicit session: SparkSession): T = { val earlierDescription = setJobDescription(description, ifNotSet)(session.sparkContext) try { func @@ -448,9 +464,12 @@ package object spark extends Logging with SparkVersion with BuildVersion { /** * Append the job description and return the earlier description. * - * @param extraDescription job description - * @param separator separator to join exiting and extra description with - * @param context spark context + * @param extraDescription + * job description + * @param separator + * separator to join exiting and extra description with + * @param context + * spark context * @return */ def appendJobDescription(extraDescription: String, separator: String, context: SparkContext): String = { @@ -461,9 +480,9 @@ package object spark extends Logging with SparkVersion with BuildVersion { } /** - * Appends a job description to all Spark jobs started within the given function. - * The current Job description is extended by the separator and the extra description - * on entering the function, and restored after exit of the function. + * Appends a job description to all Spark jobs started within the given function. The current Job description is + * extended by the separator and the extra description on entering the function, and restored after exit of the + * function. 
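A small sketch of how the two helpers described here can be nested; the SparkSession named `spark`, the descriptions and the path are illustrative only:

{{{
  import org.apache.spark.sql.SparkSession
  import uk.co.gresearch.spark._

  implicit val session: SparkSession = spark

  val count = withJobDescription("nightly load") {
    // inner jobs show up as "nightly load - read events" (default separator " - ")
    appendJobDescription("read events") {
      spark.read.parquet("/tmp/events").count()
    }
  }
  // on exit of each block the previous job description is restored
}}}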
* * Usage example: * @@ -482,13 +501,20 @@ package object spark extends Logging with SparkVersion with BuildVersion { * * Any modification to the job description during execution of the function is reverted. * - * @param extraDescription job description to be appended - * @param separator separator used when appending description - * @param func code to execute while job description is set - * @param session spark session - * @tparam T return type of func + * @param extraDescription + * job description to be appended + * @param separator + * separator used when appending description + * @param func + * code to execute while job description is set + * @param session + * spark session + * @tparam T + * return type of func */ - def appendJobDescription[T](extraDescription: String, separator: String = " - ")(func: => T)(implicit session: SparkSession): T = { + def appendJobDescription[T](extraDescription: String, separator: String = " - ")( + func: => T + )(implicit session: SparkSession): T = { val earlierDescription = appendJobDescription(extraDescription, separator, session.sparkContext) try { func @@ -500,37 +526,61 @@ package object spark extends Logging with SparkVersion with BuildVersion { /** * Class to extend a Spark Dataset. * - * @param ds dataset - * @tparam V inner type of dataset + * @param ds + * dataset + * @tparam V + * inner type of dataset */ - @deprecated("Constructor with encoder is deprecated, the encoder argument is ignored, ds.encoder is used instead.", since = "2.9.0") + @deprecated( + "Constructor with encoder is deprecated, the encoder argument is ignored, ds.encoder is used instead.", + since = "2.9.0" + ) class ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]) { private val eds = ExtendedDatasetV2[V](ds) def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = eds.histogram(thresholds, valueColumn, aggregateColumns: _*) - def writePartitionedBy(partitionColumns: Seq[Column], - moreFileColumns: Seq[Column] = Seq.empty, - moreFileOrder: Seq[Column] = Seq.empty, - partitions: Option[Int] = None, - writtenProjection: Option[Seq[Column]] = None, - unpersistHandle: Option[UnpersistHandle] = None): DataFrameWriter[Row] = - eds.writePartitionedBy(partitionColumns, moreFileColumns, moreFileOrder, partitions, writtenProjection, unpersistHandle) + def writePartitionedBy( + partitionColumns: Seq[Column], + moreFileColumns: Seq[Column] = Seq.empty, + moreFileOrder: Seq[Column] = Seq.empty, + partitions: Option[Int] = None, + writtenProjection: Option[Seq[Column]] = None, + unpersistHandle: Option[UnpersistHandle] = None + ): DataFrameWriter[Row] = + eds.writePartitionedBy( + partitionColumns, + moreFileColumns, + moreFileOrder, + partitions, + writtenProjection, + unpersistHandle + ) - def groupBySorted[K: Ordering : Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = + def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = eds.groupBySorted(cols: _*)(order: _*) - def groupBySorted[K: Ordering : Encoder](partitions: Int)(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = + def groupBySorted[K: Ordering: Encoder](partitions: Int)(cols: Column*)( + order: Column* + ): SortedGroupByDataset[K, V] = eds.groupBySorted(partitions)(cols: _*)(order: _*) - def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O): SortedGroupByDataset[K, V] = + def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, 
partitions: Int)( + order: V => O + ): SortedGroupByDataset[K, V] = eds.groupByKeySorted(key, Some(partitions))(order) - def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O, reverse: Boolean): SortedGroupByDataset[K, V] = + def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)( + order: V => O, + reverse: Boolean + ): SortedGroupByDataset[K, V] = eds.groupByKeySorted(key, Some(partitions))(order, reverse) - def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] = + def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)( + order: V => O, + reverse: Boolean = false + ): SortedGroupByDataset[K, V] = eds.groupByKeySorted(key, partitions)(order, reverse) def withRowNumbers(order: Column*): DataFrame = @@ -545,72 +595,74 @@ package object spark extends Logging with SparkVersion with BuildVersion { def withRowNumbers(unpersistHandle: UnpersistHandle, order: Column*): DataFrame = eds.withRowNumbers(unpersistHandle, order: _*) - def withRowNumbers(rowNumberColumnName: String, - storageLevel: StorageLevel, - order: Column*): DataFrame = + def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame = eds.withRowNumbers(rowNumberColumnName, storageLevel, order: _*) - def withRowNumbers(rowNumberColumnName: String, - unpersistHandle: UnpersistHandle, - order: Column*): DataFrame = + def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame = eds.withRowNumbers(rowNumberColumnName, unpersistHandle, order: _*) - def withRowNumbers(storageLevel: StorageLevel, - unpersistHandle: UnpersistHandle, - order: Column*): DataFrame = + def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame = eds.withRowNumbers(storageLevel, unpersistHandle, order: _*) - def withRowNumbers(rowNumberColumnName: String, - storageLevel: StorageLevel, - unpersistHandle: UnpersistHandle, - order: Column*): DataFrame = + def withRowNumbers( + rowNumberColumnName: String, + storageLevel: StorageLevel, + unpersistHandle: UnpersistHandle, + order: Column* + ): DataFrame = eds.withRowNumbers(rowNumberColumnName, storageLevel, unpersistHandle, order: _*) } /** * Class to extend a Spark Dataset. * - * @param ds dataset - * @tparam V inner type of dataset + * @param ds + * dataset + * @tparam V + * inner type of dataset */ def ExtendedDataset[V](ds: Dataset[V], encoder: Encoder[V]): ExtendedDataset[V] = new ExtendedDataset(ds, encoder) /** * Implicit class to extend a Spark Dataset. * - * @param ds dataset - * @tparam V inner type of dataset + * @param ds + * dataset + * @tparam V + * inner type of dataset */ implicit class ExtendedDatasetV2[V](ds: Dataset[V]) { private implicit val encoder: Encoder[V] = ds.encoder /** - * Compute the histogram of a column when aggregated by aggregate columns. - * Thresholds are expected to be provided in ascending order. - * The result dataframe contains the aggregate and histogram columns only. - * For each threshold value in thresholds, there will be a column named s"≤threshold". - * There will also be a final column called s">last_threshold", that counts the remaining - * values that exceed the last threshold. - * - * @param thresholds sequence of thresholds, must implement <= and > operators w.r.t. 
valueColumn - * @param valueColumn histogram is computed for values of this column - * @param aggregateColumns histogram is computed against these columns - * @tparam T type of histogram thresholds - * @return dataframe with aggregate and histogram columns + * Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in + * ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value + * in thresholds, there will be a column named s"≤threshold". There will also be a final column called + * s">last_threshold", that counts the remaining values that exceed the last threshold. + * + * @param thresholds + * sequence of thresholds, must implement <= and > operators w.r.t. valueColumn + * @param valueColumn + * histogram is computed for values of this column + * @param aggregateColumns + * histogram is computed against these columns + * @tparam T + * type of histogram thresholds + * @return + * dataframe with aggregate and histogram columns */ def histogram[T: Ordering](thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = Histogram.of(ds, thresholds, valueColumn, aggregateColumns: _*) /** - * Writes the Dataset / DataFrame via DataFrameWriter.partitionBy. In addition to partitionBy, - * this method sorts the data to improve partition file size. Small partitions will contain few - * files, large partitions contain more files. Partition ids are contained in a single partition - * file per `partitionBy` partition only. Rows within the partition files are also sorted, - * if partitionOrder is defined. + * Writes the Dataset / DataFrame via DataFrameWriter.partitionBy. In addition to partitionBy, this method sorts the + * data to improve partition file size. Small partitions will contain few files, large partitions contain more + * files. Partition ids are contained in a single partition file per `partitionBy` partition only. Rows within the + * partition files are also sorted, if partitionOrder is defined. * - * Note: With Spark 3.0, 3.1, 3.2 before 3.2.3, 3.3 before 3.3.2, and AQE enabled, an intermediate DataFrame is being - * cached in order to guarantee sorted output files. See https://issues.apache.org/jira/browse/SPARK-40588. - * That cached DataFrame can be unpersisted via an optional [[UnpersistHandle]] provided to this method. + * Note: With Spark 3.0, 3.1, 3.2 before 3.2.3, 3.3 before 3.3.2, and AQE enabled, an intermediate DataFrame is + * being cached in order to guarantee sorted output files. See https://issues.apache.org/jira/browse/SPARK-40588. + * That cached DataFrame can be unpersisted via an optional [[UnpersistHandle]] provided to this method. 
* * Calling: * {{{ @@ -638,20 +690,29 @@ package object spark extends Logging with SparkVersion with BuildVersion { * cached.unpersist * }}} * - * @param partitionColumns columns used for partitioning - * @param moreFileColumns columns where individual values are written to a single file - * @param moreFileOrder additional columns to sort partition files - * @param partitions optional number of partition files - * @param writtenProjection additional transformation to be applied before calling write - * @param unpersistHandle handle to unpersist internally created DataFrame after writing - * @return configured DataFrameWriter + * @param partitionColumns + * columns used for partitioning + * @param moreFileColumns + * columns where individual values are written to a single file + * @param moreFileOrder + * additional columns to sort partition files + * @param partitions + * optional number of partition files + * @param writtenProjection + * additional transformation to be applied before calling write + * @param unpersistHandle + * handle to unpersist internally created DataFrame after writing + * @return + * configured DataFrameWriter */ - def writePartitionedBy(partitionColumns: Seq[Column], - moreFileColumns: Seq[Column] = Seq.empty, - moreFileOrder: Seq[Column] = Seq.empty, - partitions: Option[Int] = None, - writtenProjection: Option[Seq[Column]] = None, - unpersistHandle: Option[UnpersistHandle] = None): DataFrameWriter[Row] = { + def writePartitionedBy( + partitionColumns: Seq[Column], + moreFileColumns: Seq[Column] = Seq.empty, + moreFileOrder: Seq[Column] = Seq.empty, + partitions: Option[Int] = None, + writtenProjection: Option[Seq[Column]] = None, + unpersistHandle: Option[UnpersistHandle] = None + ): DataFrameWriter[Row] = { if (partitionColumns.isEmpty) throw new IllegalArgumentException(s"partition columns must not be empty") @@ -661,15 +722,19 @@ package object spark extends Logging with SparkVersion with BuildVersion { val requiresCaching = writePartitionedByRequiresCaching(ds) (requiresCaching, unpersistHandle.isDefined) match { case (true, false) => - warning("Partitioned-writing with AQE enabled and Spark 3.0, 3.1, 3.2 below 3.2.3, " + - "and 3.3 below 3.3.2 requires caching an intermediate DataFrame, " + - "which calling code has to unpersist once writing is done. " + - "Please provide an UnpersistHandle to DataFrame.writePartitionedBy, or UnpersistHandle.Noop. " + - "See https://issues.apache.org/jira/browse/SPARK-40588") + warning( + "Partitioned-writing with AQE enabled and Spark 3.0, 3.1, 3.2 below 3.2.3, " + + "and 3.3 below 3.3.2 requires caching an intermediate DataFrame, " + + "which calling code has to unpersist once writing is done. " + + "Please provide an UnpersistHandle to DataFrame.writePartitionedBy, or UnpersistHandle.Noop. " + + "See https://issues.apache.org/jira/browse/SPARK-40588" + ) case (false, true) if !unpersistHandle.get.isInstanceOf[NoopUnpersistHandle] => - info("UnpersistHandle provided to DataFrame.writePartitionedBy is not needed as " + - "partitioned-writing with AQE disabled or Spark 3.2.3, 3.3.2 or 3.4 and above " + - "does not require caching intermediate DataFrame.") + info( + "UnpersistHandle provided to DataFrame.writePartitionedBy is not needed as " + + "partitioned-writing with AQE disabled or Spark 3.2.3, 3.3.2 or 3.4 and above " + + "does not require caching intermediate DataFrame." 
+ ) unpersistHandle.get.setDataFrame(ds.sparkSession.emptyDataFrame) case _ => } @@ -680,24 +745,28 @@ package object spark extends Logging with SparkVersion with BuildVersion { val sortColumns = partitionColumnNames ++ moreFileColumns ++ moreFileOrder ds.toDF .call(ds => partitionColumnsMap.foldLeft(ds) { case (ds, (name, col)) => ds.withColumn(name, col) }) - .when(partitions.isEmpty).call(_.repartitionByRange(rangeColumns: _*)) - .when(partitions.isDefined).call(_.repartitionByRange(partitions.get, rangeColumns: _*)) + .when(partitions.isEmpty) + .call(_.repartitionByRange(rangeColumns: _*)) + .when(partitions.isDefined) + .call(_.repartitionByRange(partitions.get, rangeColumns: _*)) .sortWithinPartitions(sortColumns: _*) - .when(writtenProjection.isDefined).call(_.select(writtenProjection.get: _*)) - .when(requiresCaching && unpersistHandle.isDefined).call(unpersistHandle.get.setDataFrame(_)) + .when(writtenProjection.isDefined) + .call(_.select(writtenProjection.get: _*)) + .when(requiresCaching && unpersistHandle.isDefined) + .call(unpersistHandle.get.setDataFrame(_)) .write .partitionBy(partitionColumnsMap.keys.toSeq: _*) } /** - * (Scala-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns. + * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns. * - * @see `org.apache.spark.sql.Dataset.groupByKey(T => K)` + * @see + * `org.apache.spark.sql.Dataset.groupByKey(T => K)` * - * @note Calling this method should be preferred to `groupByKey(T => K)` because the - * Catalyst query planner cannot exploit existing partitioning and ordering of - * this Dataset with that function. + * @note + * Calling this method should be preferred to `groupByKey(T => K)` because the Catalyst query planner cannot + * exploit existing partitioning and ordering of this Dataset with that function. * * {{{ * ds.groupByKey[Int]($"age").flatMapGroups(...) @@ -708,14 +777,14 @@ package object spark extends Logging with SparkVersion with BuildVersion { ds.groupBy(column +: columns: _*).as[K, V] /** - * (Scala-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns. + * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns. * - * @see `org.apache.spark.sql.Dataset.groupByKey(T => K)` + * @see + * `org.apache.spark.sql.Dataset.groupByKey(T => K)` * - * @note Calling this method should be preferred to `groupByKey(T => K)` because the - * Catalyst query planner cannot exploit existing partitioning and ordering of - * this Dataset with that function. + * @note + * Calling this method should be preferred to `groupByKey(T => K)` because the Catalyst query planner cannot + * exploit existing partitioning and ordering of this Dataset with that function. * * {{{ * ds.groupByKey[Int]($"age").flatMapGroups(...) @@ -726,9 +795,8 @@ package object spark extends Logging with SparkVersion with BuildVersion { ds.groupBy(column, columns: _*).as[K, V] /** - * Groups the Dataset and sorts the groups using the specified columns, so we can run - * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available - * functions. + * Groups the Dataset and sorts the groups using the specified columns, so we can run further process the sorted + * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions. 
* * {{{ * // Enumerate elements in the sorted group @@ -736,17 +804,18 @@ package object spark extends Logging with SparkVersion with BuildVersion { * .flatMapSortedGroups((key, it) => it.zipWithIndex) * }}} * - * @param cols grouping columns - * @param order sort columns + * @param cols + * grouping columns + * @param order + * sort columns */ - def groupBySorted[K: Ordering : Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = { + def groupBySorted[K: Ordering: Encoder](cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = { SortedGroupByDataset(ds, cols, order, None) } /** - * Groups the Dataset and sorts the groups using the specified columns, so we can run - * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available - * functions. + * Groups the Dataset and sorts the groups using the specified columns, so we can run further process the sorted + * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions. * * {{{ * // Enumerate elements in the sorted group @@ -754,18 +823,22 @@ package object spark extends Logging with SparkVersion with BuildVersion { * .flatMapSortedGroups((key, it) => it.zipWithIndex) * }}} * - * @param partitions number of partitions - * @param cols grouping columns - * @param order sort columns + * @param partitions + * number of partitions + * @param cols + * grouping columns + * @param order + * sort columns */ - def groupBySorted[K: Ordering : Encoder](partitions: Int)(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = { + def groupBySorted[K: Ordering: Encoder]( + partitions: Int + )(cols: Column*)(order: Column*): SortedGroupByDataset[K, V] = { SortedGroupByDataset(ds, cols, order, Some(partitions)) } /** - * Groups the Dataset and sorts the groups using the specified columns, so we can run - * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available - * functions. + * Groups the Dataset and sorts the groups using the specified columns, so we can run further process the sorted + * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions. * * {{{ * // Enumerate elements in the sorted group @@ -773,17 +846,21 @@ package object spark extends Logging with SparkVersion with BuildVersion { * .flatMapSortedGroups((key, it) => it.zipWithIndex) * }}} * - * @param partitions number of partitions - * @param key grouping key - * @param order sort key + * @param partitions + * number of partitions + * @param key + * grouping key + * @param order + * sort key */ - def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O): SortedGroupByDataset[K, V] = + def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)( + order: V => O + ): SortedGroupByDataset[K, V] = groupByKeySorted(key, Some(partitions))(order) /** - * Groups the Dataset and sorts the groups using the specified columns, so we can run - * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available - * functions. + * Groups the Dataset and sorts the groups using the specified columns, so we can run further process the sorted + * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions. 
* * {{{ * // Enumerate elements in the sorted group @@ -791,18 +868,24 @@ package object spark extends Logging with SparkVersion with BuildVersion { * .flatMapSortedGroups((key, it) => it.zipWithIndex) * }}} * - * @param partitions number of partitions - * @param key grouping key - * @param order sort key - * @param reverse sort reverse order + * @param partitions + * number of partitions + * @param key + * grouping key + * @param order + * sort key + * @param reverse + * sort reverse order */ - def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Int)(order: V => O, reverse: Boolean): SortedGroupByDataset[K, V] = + def groupByKeySorted[K: Ordering: Encoder, O: Encoder](key: V => K, partitions: Int)( + order: V => O, + reverse: Boolean + ): SortedGroupByDataset[K, V] = groupByKeySorted(key, Some(partitions))(order, reverse) /** - * Groups the Dataset and sorts the groups using the specified columns, so we can run - * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available - * functions. + * Groups the Dataset and sorts the groups using the specified columns, so we can run further process the sorted + * groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available functions. * * {{{ * // Enumerate elements in the sorted group @@ -810,12 +893,19 @@ package object spark extends Logging with SparkVersion with BuildVersion { * .flatMapSortedGroups((key, it) => it.zipWithIndex) * }}} * - * @param partitions optional number of partitions - * @param key grouping key - * @param order sort key - * @param reverse sort reverse order + * @param partitions + * optional number of partitions + * @param key + * grouping key + * @param order + * sort key + * @param reverse + * sort reverse order */ - def groupByKeySorted[K: Ordering : Encoder, O: Encoder](key: V => K, partitions: Option[Int] = None)(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] = { + def groupByKeySorted[K: Ordering: Encoder, O: Encoder]( + key: V => K, + partitions: Option[Int] = None + )(order: V => O, reverse: Boolean = false): SortedGroupByDataset[K, V] = { SortedGroupByDataset(ds, key, order, partitions, reverse) } @@ -856,34 +946,36 @@ package object spark extends Logging with SparkVersion with BuildVersion { * * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details. */ - def withRowNumbers(rowNumberColumnName: String, - storageLevel: StorageLevel, - order: Column*): DataFrame = - RowNumbers.withRowNumberColumnName(rowNumberColumnName).withStorageLevel(storageLevel).withOrderColumns(order).of(ds) + def withRowNumbers(rowNumberColumnName: String, storageLevel: StorageLevel, order: Column*): DataFrame = + RowNumbers + .withRowNumberColumnName(rowNumberColumnName) + .withStorageLevel(storageLevel) + .withOrderColumns(order) + .of(ds) /** * Adds a global continuous row number starting at 1. * * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details. 
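// Editor's note, not part of the patch: a hedged sketch of the key-based groupByKeySorted
// overloads reformatted above, mirroring the usage in GroupBySuite. The case class and values
// are illustrative; a SparkSession `spark` with its implicits in scope is assumed.
import uk.co.gresearch.spark._
import spark.implicits._

case class Event(id: Int, seq: Int, value: Double)
val events = Seq(Event(1, 2, 1.2), Event(1, 1, 1.1), Event(2, 1, 2.1)).toDS()

// group by `id`, sort each group by `seq`, then enumerate the rows of every sorted group
events
  .groupByKeySorted(e => e.id, partitions = Some(10))(e => e.seq)
  .flatMapSortedGroups((id, it) => it.zipWithIndex.map { case (e, idx) => (id, idx, e.value) })
  .show()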
*/ - def withRowNumbers(rowNumberColumnName: String, - unpersistHandle: UnpersistHandle, - order: Column*): DataFrame = - RowNumbers.withRowNumberColumnName(rowNumberColumnName).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds) + def withRowNumbers(rowNumberColumnName: String, unpersistHandle: UnpersistHandle, order: Column*): DataFrame = + RowNumbers + .withRowNumberColumnName(rowNumberColumnName) + .withUnpersistHandle(unpersistHandle) + .withOrderColumns(order) + .of(ds) /** * Adds a global continuous row number starting at 1. * * See [[withRowNumbers(String,StorageLevel,UnpersistHandle,Column...)]] for details. */ - def withRowNumbers(storageLevel: StorageLevel, - unpersistHandle: UnpersistHandle, - order: Column*): DataFrame = + def withRowNumbers(storageLevel: StorageLevel, unpersistHandle: UnpersistHandle, order: Column*): DataFrame = RowNumbers.withStorageLevel(storageLevel).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds) /** - * Adds a global continuous row number starting at 1, after sorting rows by the given columns. - * When no columns are given, the existing order is used. + * Adds a global continuous row number starting at 1, after sorting rows by the given columns. When no columns are + * given, the existing order is used. * * Hence, the following examples are equivalent: * {{{ @@ -894,8 +986,8 @@ package object spark extends Logging with SparkVersion with BuildVersion { * The column name of the column with the row numbers can be set via the `rowNumberColumnName` argument. * * To avoid some known issues optimizing the query plan, this function has to internally call - * `Dataset.persist(StorageLevel)` on an intermediate DataFrame. The storage level of that cached - * DataFrame can be set via `storageLevel`, where the default is `StorageLevel.MEMORY_AND_DISK`. + * `Dataset.persist(StorageLevel)` on an intermediate DataFrame. The storage level of that cached DataFrame can be + * set via `storageLevel`, where the default is `StorageLevel.MEMORY_AND_DISK`. 
* * That cached intermediate DataFrame can be un-persisted / un-cached as follows: * {{{ @@ -906,23 +998,36 @@ package object spark extends Logging with SparkVersion with BuildVersion { * unpersist() * }}} * - * @param rowNumberColumnName name of the row number column - * @param storageLevel storage level of the cached intermediate DataFrame - * @param unpersistHandle handle to un-persist intermediate DataFrame - * @param order columns to order dataframe before assigning row numbers - * @return dataframe with row numbers + * @param rowNumberColumnName + * name of the row number column + * @param storageLevel + * storage level of the cached intermediate DataFrame + * @param unpersistHandle + * handle to un-persist intermediate DataFrame + * @param order + * columns to order dataframe before assigning row numbers + * @return + * dataframe with row numbers */ - def withRowNumbers(rowNumberColumnName: String, - storageLevel: StorageLevel, - unpersistHandle: UnpersistHandle, - order: Column*): DataFrame = - RowNumbers.withRowNumberColumnName(rowNumberColumnName).withStorageLevel(storageLevel).withUnpersistHandle(unpersistHandle).withOrderColumns(order).of(ds) + def withRowNumbers( + rowNumberColumnName: String, + storageLevel: StorageLevel, + unpersistHandle: UnpersistHandle, + order: Column* + ): DataFrame = + RowNumbers + .withRowNumberColumnName(rowNumberColumnName) + .withStorageLevel(storageLevel) + .withUnpersistHandle(unpersistHandle) + .withOrderColumns(order) + .of(ds) } /** * Class to extend a Spark Dataframe. * - * @param df dataframe + * @param df + * dataframe */ @deprecated("Implicit class ExtendedDataframe is deprecated, please recompile your source code.", since = "2.9.0") class ExtendedDataframe(df: DataFrame) extends ExtendedDataset[Row](df, df.encoder) @@ -930,14 +1035,16 @@ package object spark extends Logging with SparkVersion with BuildVersion { /** * Class to extend a Spark Dataframe. * - * @param df dataframe + * @param df + * dataframe */ def ExtendedDataframe(df: DataFrame): ExtendedDataframe = new ExtendedDataframe(df) /** * Implicit class to extend a Spark Dataframe, which is a Dataset[Row]. * - * @param df dataframe + * @param df + * dataframe */ implicit class ExtendedDataframeV2(df: DataFrame) extends ExtendedDatasetV2[Row](df) diff --git a/src/main/scala/uk/co/gresearch/spark/parquet/package.scala b/src/main/scala/uk/co/gresearch/spark/parquet/package.scala index 8291d7d1..d1f91380 100644 --- a/src/main/scala/uk/co/gresearch/spark/parquet/package.scala +++ b/src/main/scala/uk/co/gresearch/spark/parquet/package.scala @@ -34,31 +34,35 @@ package object parquet { /** * Implicit class to extend a Spark DataFrameReader. * - * @param reader data frame reader + * @param reader + * data frame reader */ implicit class ExtendedDataFrameReader(reader: DataFrameReader) { + /** * Read the metadata of Parquet files into a Dataframe. * - * The returned DataFrame has as many partitions as there are Parquet files, - * at most `spark.sparkContext.defaultParallelism` partitions. + * The returned DataFrame has as many partitions as there are Parquet files, at most + * `spark.sparkContext.defaultParallelism` partitions. 
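// Editor's note, not part of the patch: a sketch of the four-argument withRowNumbers overload
// reformatted above, assuming a SparkSession `spark` with its implicits in scope; the column
// name and handle are illustrative.
import org.apache.spark.storage.StorageLevel
import uk.co.gresearch.spark._
import spark.implicits._

val unpersist = UnpersistHandle()
val numbered = spark
  .range(1, 1001)
  .withRowNumbers("row_number", StorageLevel.MEMORY_AND_DISK, unpersist, $"id".desc)

numbered.show()
unpersist() // release the internally cached intermediate DataFrame once the result is consumed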
* * This provides the following per-file information: - * - filename (string): The file name - * - blocks (int): Number of blocks / RowGroups in the Parquet file - * - compressedBytes (long): Number of compressed bytes of all blocks - * - uncompressedBytes (long): Number of uncompressed bytes of all blocks - * - rows (long): Number of rows in the file - * - columns (int): Number of columns in the file - * - values (long): Number of values in the file - * - nulls (long): Number of null values in the file - * - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file - * - schema (string): The schema - * - encryption (string): The encryption - * - keyValues (string-to-string map): Key-value data of the file + * - filename (string): The file name + * - blocks (int): Number of blocks / RowGroups in the Parquet file + * - compressedBytes (long): Number of compressed bytes of all blocks + * - uncompressedBytes (long): Number of uncompressed bytes of all blocks + * - rows (long): Number of rows in the file + * - columns (int): Number of columns in the file + * - values (long): Number of values in the file + * - nulls (long): Number of null values in the file + * - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file + * - schema (string): The schema + * - encryption (string): The encryption + * - keyValues (string-to-string map): Key-value data of the file * - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet metadata + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet metadata */ @scala.annotation.varargs def parquetMetadata(paths: String*): DataFrame = parquetMetadata(None, paths) @@ -69,22 +73,25 @@ package object parquet { * The returned DataFrame has as many partitions as specified via `parallelism`. * * This provides the following per-file information: - * - filename (string): The file name - * - blocks (int): Number of blocks / RowGroups in the Parquet file - * - compressedBytes (long): Number of compressed bytes of all blocks - * - uncompressedBytes (long): Number of uncompressed bytes of all blocks - * - rows (long): Number of rows in the file - * - columns (int): Number of columns in the file - * - values (long): Number of values in the file - * - nulls (long): Number of null values in the file - * - createdBy (string): The createdBy string of the Parquet file, e.g. library used to write the file - * - schema (string): The schema - * - encryption (string): The encryption - * - keyValues (string-to-string map): Key-value data of the file + * - filename (string): The file name + * - blocks (int): Number of blocks / RowGroups in the Parquet file + * - compressedBytes (long): Number of compressed bytes of all blocks + * - uncompressedBytes (long): Number of uncompressed bytes of all blocks + * - rows (long): Number of rows in the file + * - columns (int): Number of columns in the file + * - values (long): Number of values in the file + * - nulls (long): Number of null values in the file + * - createdBy (string): The createdBy string of the Parquet file, e.g. 
library used to write the file + * - schema (string): The schema + * - encryption (string): The encryption + * - keyValues (string-to-string map): Key-value data of the file * - * @param parallelism number of partitions of returned DataFrame - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet metadata + * @param parallelism + * number of partitions of returned DataFrame + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet metadata */ @scala.annotation.varargs def parquetMetadata(parallelism: Int, paths: String*): DataFrame = parquetMetadata(Some(parallelism), paths) @@ -94,64 +101,70 @@ package object parquet { import files.sparkSession.implicits._ - files.flatMap { case (_, file) => - readFooters(file).map { footer => - ( - footer.getFile.toString, - footer.getParquetMetadata.getBlocks.size(), - footer.getParquetMetadata.getBlocks.asScala.map(_.getCompressedSize).sum, - footer.getParquetMetadata.getBlocks.asScala.map(_.getTotalByteSize).sum, - footer.getParquetMetadata.getBlocks.asScala.map(_.getRowCount).sum, - footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.size(), - footer.getParquetMetadata.getBlocks.asScala.map(_.getColumns.map(_.getValueCount).sum).sum, - // when all columns have statistics, count the null values - Option(footer.getParquetMetadata.getBlocks.asScala.flatMap(_.getColumns.map(c => Option(c.getStatistics)))) - .filter(_.forall(_.isDefined)) - .map(_.map(_.get.getNumNulls).sum), - footer.getParquetMetadata.getFileMetaData.getCreatedBy, - footer.getParquetMetadata.getFileMetaData.getSchema.toString, - FileMetaDataUtil.getEncryptionType(footer.getParquetMetadata.getFileMetaData), - footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData.asScala, - ) + files + .flatMap { case (_, file) => + readFooters(file).map { footer => + ( + footer.getFile.toString, + footer.getParquetMetadata.getBlocks.size(), + footer.getParquetMetadata.getBlocks.asScala.map(_.getCompressedSize).sum, + footer.getParquetMetadata.getBlocks.asScala.map(_.getTotalByteSize).sum, + footer.getParquetMetadata.getBlocks.asScala.map(_.getRowCount).sum, + footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.size(), + footer.getParquetMetadata.getBlocks.asScala.map(_.getColumns.map(_.getValueCount).sum).sum, + // when all columns have statistics, count the null values + Option( + footer.getParquetMetadata.getBlocks.asScala.flatMap(_.getColumns.map(c => Option(c.getStatistics))) + ) + .filter(_.forall(_.isDefined)) + .map(_.map(_.get.getNumNulls).sum), + footer.getParquetMetadata.getFileMetaData.getCreatedBy, + footer.getParquetMetadata.getFileMetaData.getSchema.toString, + FileMetaDataUtil.getEncryptionType(footer.getParquetMetadata.getFileMetaData), + footer.getParquetMetadata.getFileMetaData.getKeyValueMetaData.asScala, + ) + } } - }.toDF( - "filename", - "blocks", - "compressedBytes", - "uncompressedBytes", - "rows", - "columns", - "values", - "nulls", - "createdBy", - "schema", - "encryption", - "keyValues" - ) + .toDF( + "filename", + "blocks", + "compressedBytes", + "uncompressedBytes", + "rows", + "columns", + "values", + "nulls", + "createdBy", + "schema", + "encryption", + "keyValues" + ) } /** * Read the schema of Parquet files into a Dataframe. * - * The returned DataFrame has as many partitions as there are Parquet files, - * at most `spark.sparkContext.defaultParallelism` partitions. 
+ * The returned DataFrame has as many partitions as there are Parquet files, at most + * `spark.sparkContext.defaultParallelism` partitions. * * This provides the following per-file information: - * - filename (string): The Parquet file name - * - columnName (string): The column name - * - columnPath (string array): The column path - * - repetition (string): The repetition - * - type (string): The data type - * - length (int): The length of the type - * - originalType (string): The original type - * - isPrimitive (boolean: True if type is primitive - * - primitiveType (string: The primitive type - * - primitiveOrder (string: The order of the primitive type - * - maxDefinitionLevel (int): The max definition level - * - maxRepetitionLevel (int): The max repetition level + * - filename (string): The Parquet file name + * - columnName (string): The column name + * - columnPath (string array): The column path + * - repetition (string): The repetition + * - type (string): The data type + * - length (int): The length of the type + * - originalType (string): The original type + * - isPrimitive (boolean: True if type is primitive + * - primitiveType (string: The primitive type + * - primitiveOrder (string: The order of the primitive type + * - maxDefinitionLevel (int): The max definition level + * - maxRepetitionLevel (int): The max repetition level * - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet metadata + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet metadata */ @scala.annotation.varargs def parquetSchema(paths: String*): DataFrame = parquetSchema(None, paths) @@ -162,22 +175,25 @@ package object parquet { * The returned DataFrame has as many partitions as specified via `parallelism`. 
* * This provides the following per-file information: - * - filename (string): The Parquet file name - * - columnName (string): The column name - * - columnPath (string array): The column path - * - repetition (string): The repetition - * - type (string): The data type - * - length (int): The length of the type - * - originalType (string): The original type - * - isPrimitive (boolean: True if type is primitive - * - primitiveType (string: The primitive type - * - primitiveOrder (string: The order of the primitive type - * - maxDefinitionLevel (int): The max definition level - * - maxRepetitionLevel (int): The max repetition level + * - filename (string): The Parquet file name + * - columnName (string): The column name + * - columnPath (string array): The column path + * - repetition (string): The repetition + * - type (string): The data type + * - length (int): The length of the type + * - originalType (string): The original type + * - isPrimitive (boolean: True if type is primitive + * - primitiveType (string: The primitive type + * - primitiveOrder (string: The order of the primitive type + * - maxDefinitionLevel (int): The max definition level + * - maxRepetitionLevel (int): The max repetition level * - * @param parallelism number of partitions of returned DataFrame - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet metadata + * @param parallelism + * number of partitions of returned DataFrame + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet metadata */ @scala.annotation.varargs def parquetSchema(parallelism: Int, paths: String*): DataFrame = parquetSchema(Some(parallelism), paths) @@ -187,62 +203,66 @@ package object parquet { import files.sparkSession.implicits._ - files.flatMap { case (_, file) => - readFooters(file).flatMap { footer => - footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.map { column => - ( - footer.getFile.toString, - Option(column.getPrimitiveType).map(_.getName), - column.getPath, - Option(column.getPrimitiveType).flatMap(v => Option(v.getRepetition)).map(_.name), - Option(column.getPrimitiveType).flatMap(v => Option(v.getPrimitiveTypeName)).map(_.name), - Option(column.getPrimitiveType).map(_.getTypeLength), - Option(column.getPrimitiveType).flatMap(v => Option(v.getOriginalType)).map(_.name), - Option(column.getPrimitiveType).flatMap(PrimitiveTypeUtil.getLogicalTypeAnnotation), - column.getPrimitiveType.isPrimitive, - Option(column.getPrimitiveType).map(_.getPrimitiveTypeName.name), - Option(column.getPrimitiveType).flatMap(v => Option(v.columnOrder)).map(_.getColumnOrderName.name), - column.getMaxDefinitionLevel, - column.getMaxRepetitionLevel, - ) + files + .flatMap { case (_, file) => + readFooters(file).flatMap { footer => + footer.getParquetMetadata.getFileMetaData.getSchema.getColumns.map { column => + ( + footer.getFile.toString, + Option(column.getPrimitiveType).map(_.getName), + column.getPath, + Option(column.getPrimitiveType).flatMap(v => Option(v.getRepetition)).map(_.name), + Option(column.getPrimitiveType).flatMap(v => Option(v.getPrimitiveTypeName)).map(_.name), + Option(column.getPrimitiveType).map(_.getTypeLength), + Option(column.getPrimitiveType).flatMap(v => Option(v.getOriginalType)).map(_.name), + Option(column.getPrimitiveType).flatMap(PrimitiveTypeUtil.getLogicalTypeAnnotation), + column.getPrimitiveType.isPrimitive, + Option(column.getPrimitiveType).map(_.getPrimitiveTypeName.name), + 
Option(column.getPrimitiveType).flatMap(v => Option(v.columnOrder)).map(_.getColumnOrderName.name), + column.getMaxDefinitionLevel, + column.getMaxRepetitionLevel, + ) + } } } - }.toDF( - "filename", - "columnName", - "columnPath", - "repetition", - "type", - "length", - "originalType", - "logicalType", - "isPrimitive", - "primitiveType", - "primitiveOrder", - "maxDefinitionLevel", - "maxRepetitionLevel", - ) + .toDF( + "filename", + "columnName", + "columnPath", + "repetition", + "type", + "length", + "originalType", + "logicalType", + "isPrimitive", + "primitiveType", + "primitiveOrder", + "maxDefinitionLevel", + "maxRepetitionLevel", + ) } /** * Read the metadata of Parquet blocks into a Dataframe. * - * The returned DataFrame has as many partitions as there are Parquet files, - * at most `spark.sparkContext.defaultParallelism` partitions. + * The returned DataFrame has as many partitions as there are Parquet files, at most + * `spark.sparkContext.defaultParallelism` partitions. * * This provides the following per-block information: - * - filename (string): The file name - * - block (int): Block / RowGroup number starting at 1 - * - blockStart (long): Start position of the block in the Parquet file - * - compressedBytes (long): Number of compressed bytes in block - * - uncompressedBytes (long): Number of uncompressed bytes in block - * - rows (long): Number of rows in block - * - columns (int): Number of columns in block - * - values (long): Number of values in block - * - nulls (long): Number of null values in block + * - filename (string): The file name + * - block (int): Block / RowGroup number starting at 1 + * - blockStart (long): Start position of the block in the Parquet file + * - compressedBytes (long): Number of compressed bytes in block + * - uncompressedBytes (long): Number of uncompressed bytes in block + * - rows (long): Number of rows in block + * - columns (int): Number of columns in block + * - values (long): Number of values in block + * - nulls (long): Number of null values in block * - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet block metadata + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet block metadata */ @scala.annotation.varargs def parquetBlocks(paths: String*): DataFrame = parquetBlocks(None, paths) @@ -253,19 +273,22 @@ package object parquet { * The returned DataFrame has as many partitions as specified via `parallelism`. 
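// Editor's note, not part of the patch: a sketch of the per-file DataFrameReader extensions
// above (parquetMetadata, parquetSchema), assuming a SparkSession `spark` and Parquet files
// under the illustrative path.
import uk.co.gresearch.spark.parquet._

// one row per Parquet file: blocks, compressed/uncompressed bytes, rows, createdBy, ...
spark.read.parquetMetadata("/tmp/example.parquet").show()

// one row per schema column of each file; the Int overload sets the partition count
spark.read.parquetSchema(8, "/tmp/example.parquet").show()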
* * This provides the following per-block information: - * - filename (string): The file name - * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1) - * - blockStart (long): Start position of the block in the Parquet file - * - compressedBytes (long): Number of compressed bytes in block - * - uncompressedBytes (long): Number of uncompressed bytes in block - * - rows (long): Number of rows in block - * - columns (int): Number of columns in block - * - values (long): Number of values in block - * - nulls (long): Number of null values in block + * - filename (string): The file name + * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1) + * - blockStart (long): Start position of the block in the Parquet file + * - compressedBytes (long): Number of compressed bytes in block + * - uncompressedBytes (long): Number of uncompressed bytes in block + * - rows (long): Number of rows in block + * - columns (int): Number of columns in block + * - values (long): Number of values in block + * - nulls (long): Number of null values in block * - * @param parallelism number of partitions of returned DataFrame - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet block metadata + * @param parallelism + * number of partitions of returned DataFrame + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet block metadata */ @scala.annotation.varargs def parquetBlocks(parallelism: Int, paths: String*): DataFrame = parquetBlocks(Some(parallelism), paths) @@ -275,61 +298,65 @@ package object parquet { import files.sparkSession.implicits._ - files.flatMap { case (_, file) => - readFooters(file).flatMap { footer => - footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.map { case (block, idx) => - ( - footer.getFile.toString, - BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1, - block.getStartingPos, - block.getCompressedSize, - block.getTotalByteSize, - block.getRowCount, - block.getColumns.asScala.size, - block.getColumns.asScala.map(_.getValueCount).sum, - // when all columns have statistics, count the null values - Option(block.getColumns.asScala.map(c => Option(c.getStatistics))) - .filter(_.forall(_.isDefined)) - .map(_.map(_.get.getNumNulls).sum), - ) + files + .flatMap { case (_, file) => + readFooters(file).flatMap { footer => + footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.map { case (block, idx) => + ( + footer.getFile.toString, + BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1, + block.getStartingPos, + block.getCompressedSize, + block.getTotalByteSize, + block.getRowCount, + block.getColumns.asScala.size, + block.getColumns.asScala.map(_.getValueCount).sum, + // when all columns have statistics, count the null values + Option(block.getColumns.asScala.map(c => Option(c.getStatistics))) + .filter(_.forall(_.isDefined)) + .map(_.map(_.get.getNumNulls).sum), + ) + } } } - }.toDF( - "filename", - "block", - "blockStart", - "compressedBytes", - "uncompressedBytes", - "rows", - "columns", - "values", - "nulls" - ) + .toDF( + "filename", + "block", + "blockStart", + "compressedBytes", + "uncompressedBytes", + "rows", + "columns", + "values", + "nulls" + ) } /** * Read the metadata of Parquet block columns into a Dataframe. * - * The returned DataFrame has as many partitions as there are Parquet files, - * at most `spark.sparkContext.defaultParallelism` partitions. 
+ * The returned DataFrame has as many partitions as there are Parquet files, at most + * `spark.sparkContext.defaultParallelism` partitions. * * This provides the following per-block-column information: - * - filename (string): The file name - * - block (int): Block / RowGroup number starting at 1 - * - column (string): Block / RowGroup column name - * - codec (string): The coded used to compress the block column values - * - type (string): The data type of the block column - * - encodings (string): Encodings of the block column - * - minValue (string): Minimum value of this column in this block - * - maxValue (string): Maximum value of this column in this block - * - columnStart (long): Start position of the block column in the Parquet file - * - compressedBytes (long): Number of compressed bytes of this block column - * - uncompressedBytes (long): Number of uncompressed bytes of this block column - * - values (long): Number of values in this block column - * - nulls (long): Number of null values in block + * - filename (string): The file name + * - block (int): Block / RowGroup number starting at 1 + * - column (string): Block / RowGroup column name + * - codec (string): The coded used to compress the block column values + * - type (string): The data type of the block column + * - encodings (string): Encodings of the block column + * - minValue (string): Minimum value of this column in this block + * - maxValue (string): Maximum value of this column in this block + * - columnStart (long): Start position of the block column in the Parquet file + * - compressedBytes (long): Number of compressed bytes of this block column + * - uncompressedBytes (long): Number of uncompressed bytes of this block column + * - values (long): Number of values in this block column + * - nulls (long): Number of null values in block * - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet block metadata + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet block metadata */ @scala.annotation.varargs def parquetBlockColumns(paths: String*): DataFrame = parquetBlockColumns(None, paths) @@ -340,23 +367,26 @@ package object parquet { * The returned DataFrame has as many partitions as specified via `parallelism`. 
* * This provides the following per-block-column information: - * - filename (string): The file name - * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1) - * - column (string): Block / RowGroup column name - * - codec (string): The coded used to compress the block column values - * - type (string): The data type of the block column - * - encodings (string): Encodings of the block column - * - minValue (string): Minimum value of this column in this block - * - maxValue (string): Maximum value of this column in this block - * - columnStart (long): Start position of the block column in the Parquet file - * - compressedBytes (long): Number of compressed bytes of this block column - * - uncompressedBytes (long): Number of uncompressed bytes of this block column - * - values (long): Number of values in this block column - * - nulls (long): Number of null values in block + * - filename (string): The file name + * - block (int): Block / RowGroup number starting at 1 (block ordinal + 1) + * - column (string): Block / RowGroup column name + * - codec (string): The coded used to compress the block column values + * - type (string): The data type of the block column + * - encodings (string): Encodings of the block column + * - minValue (string): Minimum value of this column in this block + * - maxValue (string): Maximum value of this column in this block + * - columnStart (long): Start position of the block column in the Parquet file + * - compressedBytes (long): Number of compressed bytes of this block column + * - uncompressedBytes (long): Number of uncompressed bytes of this block column + * - values (long): Number of values in this block column + * - nulls (long): Number of null values in block * - * @param parallelism number of partitions of returned DataFrame - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Parquet block metadata + * @param parallelism + * number of partitions of returned DataFrame + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Parquet block metadata */ @scala.annotation.varargs def parquetBlockColumns(parallelism: Int, paths: String*): DataFrame = parquetBlockColumns(Some(parallelism), paths) @@ -366,67 +396,71 @@ package object parquet { import files.sparkSession.implicits._ - files.flatMap { case (_, file) => - readFooters(file).flatMap { footer => - footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.flatMap { case (block, idx) => - block.getColumns.asScala.map { column => - ( - footer.getFile.toString, - BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1, - column.getPath.toSeq, - column.getCodec.toString, - column.getPrimitiveType.toString, - column.getEncodings.asScala.toSeq.map(_.toString).sorted, - Option(column.getStatistics).map(_.minAsString), - Option(column.getStatistics).map(_.maxAsString), - column.getStartingPos, - column.getTotalSize, - column.getTotalUncompressedSize, - column.getValueCount, - Option(column.getStatistics).map(_.getNumNulls), - ) + files + .flatMap { case (_, file) => + readFooters(file).flatMap { footer => + footer.getParquetMetadata.getBlocks.asScala.zipWithIndex.flatMap { case (block, idx) => + block.getColumns.asScala.map { column => + ( + footer.getFile.toString, + BlockMetaDataUtil.getOrdinal(block).getOrElse(idx) + 1, + column.getPath.toSeq, + column.getCodec.toString, + column.getPrimitiveType.toString, + column.getEncodings.asScala.toSeq.map(_.toString).sorted, + 
Option(column.getStatistics).map(_.minAsString), + Option(column.getStatistics).map(_.maxAsString), + column.getStartingPos, + column.getTotalSize, + column.getTotalUncompressedSize, + column.getValueCount, + Option(column.getStatistics).map(_.getNumNulls), + ) + } } } } - }.toDF( - "filename", - "block", - "column", - "codec", - "type", - "encodings", - "minValue", - "maxValue", - "columnStart", - "compressedBytes", - "uncompressedBytes", - "values", - "nulls" - ) + .toDF( + "filename", + "block", + "column", + "codec", + "type", + "encodings", + "minValue", + "maxValue", + "columnStart", + "compressedBytes", + "uncompressedBytes", + "values", + "nulls" + ) } /** * Read the metadata of how Spark partitions Parquet files into a Dataframe. * - * The returned DataFrame has as many partitions as there are Parquet files, - * at most `spark.sparkContext.defaultParallelism` partitions. + * The returned DataFrame has as many partitions as there are Parquet files, at most + * `spark.sparkContext.defaultParallelism` partitions. * * This provides the following per-partition information: - * - partition (int): The Spark partition id - * - start (long): The start position of the partition - * - end (long): The end position of the partition - * - length (long): The length of the partition - * - blocks (int): The number of Parquet blocks / RowGroups in this partition - * - compressedBytes (long): The number of compressed bytes in this partition - * - uncompressedBytes (long): The number of uncompressed bytes in this partition - * - rows (long): The number of rows in this partition - * - columns (int): Number of columns in the file - * - values (long): The number of values in this partition - * - filename (string): The Parquet file name - * - fileLength (long): The length of the Parquet file + * - partition (int): The Spark partition id + * - start (long): The start position of the partition + * - end (long): The end position of the partition + * - length (long): The length of the partition + * - blocks (int): The number of Parquet blocks / RowGroups in this partition + * - compressedBytes (long): The number of compressed bytes in this partition + * - uncompressedBytes (long): The number of uncompressed bytes in this partition + * - rows (long): The number of rows in this partition + * - columns (int): Number of columns in the file + * - values (long): The number of values in this partition + * - filename (string): The Parquet file name + * - fileLength (long): The length of the Parquet file * - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Spark Parquet partition metadata + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Spark Parquet partition metadata */ @scala.annotation.varargs def parquetPartitions(paths: String*): DataFrame = parquetPartitions(None, paths) @@ -437,22 +471,25 @@ package object parquet { * The returned DataFrame has as many partitions as specified via `parallelism`. 
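// Editor's note, not part of the patch: a sketch of the block- and partition-level readers
// above, assuming a SparkSession `spark`; the path is illustrative.
import uk.co.gresearch.spark.parquet._

spark.read.parquetBlocks("/tmp/example.parquet").show()       // one row per block / RowGroup
spark.read.parquetBlockColumns("/tmp/example.parquet").show() // one row per block column
spark.read.parquetPartitions("/tmp/example.parquet").show()   // one row per Spark partition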
* * This provides the following per-partition information: - * - partition (int): The Spark partition id - * - start (long): The start position of the partition - * - end (long): The end position of the partition - * - length (long): The length of the partition - * - blocks (int): The number of Parquet blocks / RowGroups in this partition - * - compressedBytes (long): The number of compressed bytes in this partition - * - uncompressedBytes (long): The number of uncompressed bytes in this partition - * - rows (long): The number of rows in this partition - * - columns (int): Number of columns in the file - * - values (long): The number of values in this partition - * - filename (string): The Parquet file name - * - fileLength (long): The length of the Parquet file + * - partition (int): The Spark partition id + * - start (long): The start position of the partition + * - end (long): The end position of the partition + * - length (long): The length of the partition + * - blocks (int): The number of Parquet blocks / RowGroups in this partition + * - compressedBytes (long): The number of compressed bytes in this partition + * - uncompressedBytes (long): The number of uncompressed bytes in this partition + * - rows (long): The number of rows in this partition + * - columns (int): Number of columns in the file + * - values (long): The number of values in this partition + * - filename (string): The Parquet file name + * - fileLength (long): The length of the Parquet file * - * @param parallelism number of partitions of returned DataFrame - * @param paths one or more paths to Parquet files or directories - * @return dataframe with Spark Parquet partition metadata + * @param parallelism + * number of partitions of returned DataFrame + * @param paths + * one or more paths to Parquet files or directories + * @return + * dataframe with Spark Parquet partition metadata */ @scala.annotation.varargs def parquetPartitions(parallelism: Int, paths: String*): DataFrame = parquetPartitions(Some(parallelism), paths) @@ -462,56 +499,69 @@ package object parquet { import files.sparkSession.implicits._ - files.flatMap { case (part, file) => - readFooters(file) - .map(footer => (footer, getBlocks(footer, file.start, file.length))) - .map { case (footer, blocks) => ( - part, - file.start, - file.start + file.length, - file.length, - blocks.size, - blocks.map(_.getCompressedSize).sum, - blocks.map(_.getTotalByteSize).sum, - blocks.map(_.getRowCount).sum, - blocks.map(_.getColumns.map(_.getPath.mkString(".")).toSet).foldLeft(Set.empty[String])((left, right) => left.union(right)).size, - blocks.map(_.getColumns.asScala.map(_.getValueCount).sum).sum, - // when all columns have statistics, count the null values - Option(blocks.flatMap(_.getColumns.asScala.map(c => Option(c.getStatistics)))) - .filter(_.forall(_.isDefined)) - .map(_.map(_.get.getNumNulls).sum), - footer.getFile.toString, - file.fileSize, - )} - }.toDF( - "partition", - "start", - "end", - "length", - "blocks", - "compressedBytes", - "uncompressedBytes", - "rows", - "columns", - "values", - "nulls", - "filename", - "fileLength" - ) + files + .flatMap { case (part, file) => + readFooters(file) + .map(footer => (footer, getBlocks(footer, file.start, file.length))) + .map { case (footer, blocks) => + ( + part, + file.start, + file.start + file.length, + file.length, + blocks.size, + blocks.map(_.getCompressedSize).sum, + blocks.map(_.getTotalByteSize).sum, + blocks.map(_.getRowCount).sum, + blocks + .map(_.getColumns.map(_.getPath.mkString(".")).toSet) + 
.foldLeft(Set.empty[String])((left, right) => left.union(right)) + .size, + blocks.map(_.getColumns.asScala.map(_.getValueCount).sum).sum, + // when all columns have statistics, count the null values + Option(blocks.flatMap(_.getColumns.asScala.map(c => Option(c.getStatistics)))) + .filter(_.forall(_.isDefined)) + .map(_.map(_.get.getNumNulls).sum), + footer.getFile.toString, + file.fileSize, + ) + } + } + .toDF( + "partition", + "start", + "end", + "length", + "blocks", + "compressedBytes", + "uncompressedBytes", + "rows", + "columns", + "values", + "nulls", + "filename", + "fileLength" + ) } private def getFiles(parallelism: Option[Int], paths: Seq[String]): Dataset[(Int, SplitFile)] = { val df = reader.parquet(paths: _*) - val parts = df.rdd.partitions.flatMap(part => - part.asInstanceOf[FilePartition] - .files - .map(file => (part.index, SplitFile(file))) - ).toSeq.distinct + val parts = df.rdd.partitions + .flatMap(part => + part + .asInstanceOf[FilePartition] + .files + .map(file => (part.index, SplitFile(file))) + ) + .toSeq + .distinct import df.sparkSession.implicits._ - parts.toDS() - .when(parallelism.isDefined).call(_.repartition(parallelism.get)) + parts + .toDS() + .when(parallelism.isDefined) + .call(_.repartition(parallelism.get)) } } diff --git a/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala b/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala index 5a7dca51..d473965a 100644 --- a/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/GroupBySuite.scala @@ -92,7 +92,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession { testGroupByIdSortBySeq(ds.groupByKeySorted(v => v.id)(v => (v.seq, v.value))) testGroupByIdSortBySeqDesc(ds.groupByKeySorted(v => v.id)(v => (v.seq, v.value), reverse = true)) testGroupByIdSortBySeqWithPartitionNum(ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value))) - testGroupByIdSortBySeqDescWithPartitionNum(ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value), reverse = true)) + testGroupByIdSortBySeqDescWithPartitionNum( + ds.groupByKeySorted(v => v.id, partitions = Some(10))(v => (v.seq, v.value), reverse = true) + ) testGroupByIdSeqSortByValue(ds.groupByKeySorted(v => (v.id, v.seq))(v => v.value)) } @@ -106,14 +108,19 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession { describe("df.groupByKeySorted") { testGroupByIdSortBySeq(df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2)))) - testGroupByIdSortBySeqDesc(df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2)), reverse = true)) - testGroupByIdSortBySeqWithPartitionNum(df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)))) - testGroupByIdSortBySeqDescWithPartitionNum(df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)), reverse = true)) + testGroupByIdSortBySeqDesc( + df.groupByKeySorted(v => v.getInt(0))(v => (v.getInt(1), v.getDouble(2)), reverse = true) + ) + testGroupByIdSortBySeqWithPartitionNum( + df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2))) + ) + testGroupByIdSortBySeqDescWithPartitionNum( + df.groupByKeySorted(v => v.getInt(0), partitions = Some(10))(v => (v.getInt(1), v.getDouble(2)), reverse = true) + ) testGroupByIdSeqSortByValue(df.groupByKeySorted(v => (v.getInt(0), v.getInt(1)))(v => v.getDouble(2))) } - def testGroupByIdSortBySeq[T](ds: SortedGroupByDataset[Int, T]) - (implicit 
asTuple: T => (Int, Int, Double)): Unit = { + def testGroupByIdSortBySeq[T](ds: SortedGroupByDataset[Int, T])(implicit asTuple: T => (Int, Int, Double)): Unit = { it("should flatMapSortedGroups") { val actual = ds @@ -167,8 +174,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession { } - def testGroupByIdSortBySeqDesc[T](ds: SortedGroupByDataset[Int, T]) - (implicit asTuple: T => (Int, Int, Double)): Unit = { + def testGroupByIdSortBySeqDesc[T]( + ds: SortedGroupByDataset[Int, T] + )(implicit asTuple: T => (Int, Int, Double)): Unit = { it("should flatMapSortedGroups reverse") { val actual = ds .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1)))) @@ -196,8 +204,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession { } - def testGroupByIdSortBySeqWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10) - (implicit asTuple: T => (Int, Int, Double)): Unit = { + def testGroupByIdSortBySeqWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)(implicit + asTuple: T => (Int, Int, Double) + ): Unit = { it("should flatMapSortedGroups with partition num") { val grouped = ds @@ -230,8 +239,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession { } - def testGroupByIdSortBySeqDescWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10) - (implicit asTuple: T => (Int, Int, Double)): Unit = { + def testGroupByIdSortBySeqDescWithPartitionNum[T](ds: SortedGroupByDataset[Int, T], partitions: Int = 10)(implicit + asTuple: T => (Int, Int, Double) + ): Unit = { it("should flatMapSortedGroups with partition num and reverse") { val grouped = ds .flatMapSortedGroups((key, it) => it.zipWithIndex.map(v => (key, v._2, asTuple(v._1)))) @@ -262,8 +272,9 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession { } } - def testGroupByIdSeqSortByValue[T](ds: SortedGroupByDataset[(Int, Int), T]) - (implicit asTuple: T => (Int, Int, Double)): Unit = { + def testGroupByIdSeqSortByValue[T]( + ds: SortedGroupByDataset[(Int, Int), T] + )(implicit asTuple: T => (Int, Int, Double)): Unit = { it("should flatMapSortedGroups with tuple key") { val actual = ds @@ -323,7 +334,6 @@ class GroupBySuite extends AnyFunSpec with SparkTestSession { } - object GroupBySortedSuite { implicit def valueToTuple(value: Val): (Int, Int, Double) = (value.id, value.seq, value.value) implicit def valueRowToTuple(value: Row): (Int, Int, Double) = (value.getInt(0), value.getInt(1), value.getDouble(2)) diff --git a/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala b/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala index 0cd1d4ac..04084dcb 100644 --- a/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/HistogramSuite.scala @@ -63,34 +63,40 @@ class HistogramSuite extends AnyFunSuite with SparkTestSession { Seq(3, 0, 1, 0, 1, 0, 1), Seq(4, 0, 0, 1, 0, 0, 0) ) - val expectedSchema: StructType = StructType(Seq( - StructField("id", IntegerType, nullable = false), - StructField("≤-200", LongType, nullable = true), - StructField("≤-100", LongType, nullable = true), - StructField("≤0", LongType, nullable = true), - StructField("≤100", LongType, nullable = true), - StructField("≤200", LongType, nullable = true), - StructField(">200", LongType, nullable = true) - )) - val expectedSchema2: StructType = StructType(Seq( - StructField("id", IntegerType, nullable = false), - StructField("title", StringType, nullable = true), - StructField("≤-200", LongType, nullable = 
true), - StructField("≤-100", LongType, nullable = true), - StructField("≤0", LongType, nullable = true), - StructField("≤100", LongType, nullable = true), - StructField("≤200", LongType, nullable = true), - StructField(">200", LongType, nullable = true) - )) - val expectedDoubleSchema: StructType = StructType(Seq( - StructField("id", IntegerType, nullable = false), - StructField("≤-200.0", LongType, nullable = true), - StructField("≤-100.0", LongType, nullable = true), - StructField("≤0.0", LongType, nullable = true), - StructField("≤100.0", LongType, nullable = true), - StructField("≤200.0", LongType, nullable = true), - StructField(">200.0", LongType, nullable = true) - )) + val expectedSchema: StructType = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("≤-200", LongType, nullable = true), + StructField("≤-100", LongType, nullable = true), + StructField("≤0", LongType, nullable = true), + StructField("≤100", LongType, nullable = true), + StructField("≤200", LongType, nullable = true), + StructField(">200", LongType, nullable = true) + ) + ) + val expectedSchema2: StructType = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("title", StringType, nullable = true), + StructField("≤-200", LongType, nullable = true), + StructField("≤-100", LongType, nullable = true), + StructField("≤0", LongType, nullable = true), + StructField("≤100", LongType, nullable = true), + StructField("≤200", LongType, nullable = true), + StructField(">200", LongType, nullable = true) + ) + ) + val expectedDoubleSchema: StructType = StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("≤-200.0", LongType, nullable = true), + StructField("≤-100.0", LongType, nullable = true), + StructField("≤0.0", LongType, nullable = true), + StructField("≤100.0", LongType, nullable = true), + StructField("≤200.0", LongType, nullable = true), + StructField(">200.0", LongType, nullable = true) + ) + ) test("histogram with no aggregate columns") { val histogram = ints.histogram(intThresholds, $"value") @@ -110,12 +116,14 @@ class HistogramSuite extends AnyFunSuite with SparkTestSession { val histogram = ints.histogram(intThresholds, $"value", $"id", $"title") val actual = histogram.orderBy($"id").collect().toSeq.map(_.toSeq) assert(histogram.schema === expectedSchema2) - assert(actual === Seq( - Seq(1, "one", 0, 0, 0, 3, 0, 0), - Seq(2, "two", 0, 0, 0, 4, 0, 0), - Seq(3, "three", 0, 1, 0, 1, 0, 1), - Seq(4, "four", 0, 0, 1, 0, 0, 0) - )) + assert( + actual === Seq( + Seq(1, "one", 0, 0, 0, 3, 0, 0), + Seq(2, "two", 0, 0, 0, 4, 0, 0), + Seq(3, "three", 0, 1, 0, 1, 0, 1), + Seq(4, "four", 0, 0, 1, 0, 0, 0) + ) + ) } test("histogram with int values") { @@ -156,18 +164,23 @@ class HistogramSuite extends AnyFunSuite with SparkTestSession { test("histogram with one threshold") { val histogram = ints.histogram(Seq(0), $"value", $"id") val actual = histogram.orderBy($"id").collect().toSeq.map(_.toSeq) - assert(histogram.schema === StructType(Seq( - StructField("id", IntegerType, nullable = false), - StructField("≤0", LongType, nullable = true), - StructField(">0", LongType, nullable = true) - )) + assert( + histogram.schema === StructType( + Seq( + StructField("id", IntegerType, nullable = false), + StructField("≤0", LongType, nullable = true), + StructField(">0", LongType, nullable = true) + ) + ) + ) + assert( + actual === Seq( + Seq(1, 0, 3), + Seq(2, 0, 4), + Seq(3, 1, 2), + Seq(4, 1, 0) + ) ) - assert(actual === Seq( - Seq(1, 0, 3), 
- Seq(2, 0, 4), - Seq(3, 1, 2), - Seq(4, 1, 0) - )) } test("histogram with duplicate thresholds") { diff --git a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala index fbb12e40..78c90b48 100644 --- a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala @@ -39,7 +39,11 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { val emptyDataFrame: DataFrame = spark.createDataFrame(Seq.empty[Value]) test("Get Spark version") { - assert(VersionString.contains(s"-$BuildSparkCompatVersionString-") || VersionString.endsWith(s"-$BuildSparkCompatVersionString")) + assert( + VersionString.contains(s"-$BuildSparkCompatVersionString-") || VersionString.endsWith( + s"-$BuildSparkCompatVersionString" + ) + ) assert(spark.version.startsWith(s"$BuildSparkCompatVersionString.")) assert(SparkVersion === BuildSparkVersion) @@ -121,7 +125,9 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { test("UnpersistHandle throws on unpersist with blocking if no DataFrame is set") { val unpersist = UnpersistHandle() - assert(intercept[IllegalStateException] { unpersist(blocking = true) }.getMessage === s"DataFrame has to be set first") + assert(intercept[IllegalStateException] { + unpersist(blocking = true) + }.getMessage === s"DataFrame has to be set first") } test("SilentUnpersistHandle does not throw on unpersist if no DataFrame is set") { @@ -145,7 +151,14 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { test("count_null") { val df = Seq( - (1, "some"), (2, "text"), (3, "and"), (4, "some"), (5, "null"), (6, "values"), (7, null), (8, null) + (1, "some"), + (2, "text"), + (3, "and"), + (4, "some"), + (5, "null"), + (6, "values"), + (7, null), + (8, null) ).toDF("id", "str") val actual = df.select( @@ -153,7 +166,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { count($"str").as("strs"), count_null($"id").as("null ids"), count_null($"str").as("null strs") - ).collect().head + ).collect() + .head assert(actual === Row(8, 6, 0, 2)) } @@ -260,7 +274,6 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { assertIsDataset[Value](spark.createDataFrame(Seq.empty[Value]).call(_.as[Value])) } - Seq(true, false).foreach { condition => test(s"call on $condition condition dataset-to-dataset transformation") { assertIsGenericType[Dataset[Value]]( @@ -298,10 +311,10 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { ) } - test(s"call on $condition condition either dataset-to-dataset transformation") { assertIsGenericType[Dataset[Value]]( - spark.emptyDataset[Value] + spark + .emptyDataset[Value] .transform( _.on(condition) .either(_.sort()) @@ -312,7 +325,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { test(s"call on $condition condition either dataset-to-dataframe transformation") { assertIsGenericType[DataFrame]( - spark.emptyDataset[Value] + spark + .emptyDataset[Value] .transform( _.on(condition) .either(_.drop("string")) @@ -323,7 +337,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { test(s"call on $condition condition either dataframe-to-dataset transformation") { assertIsGenericType[Dataset[Value]]( - spark.createDataFrame(Seq.empty[Value]) + spark + .createDataFrame(Seq.empty[Value]) .transform( _.on(condition) .either(_.as[Value]) @@ -334,7 +349,8 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { test(s"call on $condition condition either dataframe-to-dataframe transformation") { 
assertIsGenericType[DataFrame]( - spark.createDataFrame(Seq.empty[Value]) + spark + .createDataFrame(Seq.empty[Value]) .transform( _.on(condition) .either(_.drop("string")) @@ -344,7 +360,6 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { } } - test("on true condition call either writer-to-writer methods") { assertIsGenericType[DataFrameWriter[Value]]( spark @@ -396,7 +411,7 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { } test("global row number preserves order") { - doTestWithRowNumbers()(){ df => + doTestWithRowNumbers()() { df => assert(df.columns === Seq("id", "rand", "row_number")) } } @@ -415,7 +430,9 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, OFF_HEAP, NONE).foreach { level => test(s"global row number with $level") { - if (level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)) { + if ( + level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5) + ) { assertThrows[IllegalArgumentException] { doTestWithRowNumbers(storageLevel = level)($"id")() } @@ -427,7 +444,9 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { Seq(MEMORY_AND_DISK, MEMORY_ONLY, DISK_ONLY, OFF_HEAP, NONE).foreach { level => test(s"global row number allows to unpersist with $level") { - if (level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)) { + if ( + level.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5) + ) { assertThrows[IllegalArgumentException] { doTestWithRowNumbers(storageLevel = level)($"id")() } @@ -447,59 +466,60 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { test("global row number with existing row_number column") { // this overwrites the existing column 'row_number' (formerly 'rand') with the row numbers - doTestWithRowNumbers { df => df.withColumnRenamed("rand", "row_number") }(){ df => + doTestWithRowNumbers { df => df.withColumnRenamed("rand", "row_number") }() { df => assert(df.columns === Seq("id", "row_number")) } } test("global row number with custom row_number column") { // this puts the row numbers in the column "row", which is not the default column name - doTestWithRowNumbers(df => df.withColumnRenamed("rand", "row_number"), - rowNumberColumnName = "row" )(){ df => + doTestWithRowNumbers(df => df.withColumnRenamed("rand", "row_number"), rowNumberColumnName = "row")() { df => assert(df.columns === Seq("id", "row_number", "row")) } } test("global row number with internal column names") { - val cols = Seq("mono_id", "partition_id", "local_row_number", "max_local_row_number", - "cum_row_numbers", "partition_offset") + val cols = + Seq("mono_id", "partition_id", "local_row_number", "max_local_row_number", "cum_row_numbers", "partition_offset") var prefix: String = null doTestWithRowNumbers { df => prefix = distinctPrefixFor(df.columns) - cols.foldLeft(df){ (df, name) => df.withColumn(prefix + name, rand()) } - }(){ df => + cols.foldLeft(df) { (df, name) => df.withColumn(prefix + name, rand()) } + }() { df => assert(df.columns === Seq("id", "rand") ++ cols.map(prefix + _) :+ "row_number") } } - def doTestWithRowNumbers(transform: DataFrame => DataFrame = identity, - rowNumberColumnName: String = "row_number", - storageLevel: StorageLevel = MEMORY_AND_DISK, - unpersistHandle: UnpersistHandle = UnpersistHandle.Noop) - (columns: Column*) - (handle: 
DataFrame => Unit = identity[DataFrame]): Unit = { + def doTestWithRowNumbers( + transform: DataFrame => DataFrame = identity, + rowNumberColumnName: String = "row_number", + storageLevel: StorageLevel = MEMORY_AND_DISK, + unpersistHandle: UnpersistHandle = UnpersistHandle.Noop + )(columns: Column*)(handle: DataFrame => Unit = identity[DataFrame]): Unit = { val partitions = 10 val rowsPerPartition = 1000 val rows = partitions * rowsPerPartition assert(partitions > 1) assert(rowsPerPartition > 1) - val df = spark.range(1, rows + 1, 1, partitions) + val df = spark + .range(1, rows + 1, 1, partitions) .withColumn("rand", rand()) .transform(transform) .withRowNumbers( - rowNumberColumnName=rowNumberColumnName, - storageLevel=storageLevel, - unpersistHandle=unpersistHandle, - columns: _*) + rowNumberColumnName = rowNumberColumnName, + storageLevel = storageLevel, + unpersistHandle = unpersistHandle, + columns: _* + ) .cache() try { // testing with descending order is only supported for a single column val desc = columns.map(_.expr) match { case Seq(SortOrder(_, Descending, _, _)) => true - case _ => false + case _ => false } // assert row numbers are correct @@ -519,7 +539,7 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { } val correctRowNumbers = df.where(expect).count() - val incorrectRowNumbers = df.where(! expect).count() + val incorrectRowNumbers = df.where(!expect).count() assert(correctRowNumbers === rows) assert(incorrectRowNumbers === 0) } @@ -536,50 +556,68 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { (7, 3155378975999999999L) ).toDF("id", "ts") - val plan = df.select( - $"id", - dotNetTicksToTimestamp($"ts"), - dotNetTicksToTimestamp("ts"), - dotNetTicksToUnixEpoch($"ts"), - dotNetTicksToUnixEpoch("ts"), - dotNetTicksToUnixEpochNanos($"ts"), - dotNetTicksToUnixEpochNanos("ts") - ).orderBy($"id") - assert(plan.schema.fields.map(_.dataType) === Seq( - IntegerType, TimestampType, TimestampType, DecimalType(29, 9), DecimalType(29, 9), LongType, LongType - )) + val plan = df + .select( + $"id", + dotNetTicksToTimestamp($"ts"), + dotNetTicksToTimestamp("ts"), + dotNetTicksToUnixEpoch($"ts"), + dotNetTicksToUnixEpoch("ts"), + dotNetTicksToUnixEpochNanos($"ts"), + dotNetTicksToUnixEpochNanos("ts") + ) + .orderBy($"id") + assert( + plan.schema.fields.map(_.dataType) === Seq( + IntegerType, + TimestampType, + TimestampType, + DecimalType(29, 9), + DecimalType(29, 9), + LongType, + LongType + ) + ) val actual = plan.collect() - assert(actual.map(_.getTimestamp(1)) === Seq( - Timestamp.from(Instant.parse("1900-01-01T00:00:00Z")), - Timestamp.from(Instant.parse("1970-01-01T00:00:00Z")), - Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")), - Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")), - Timestamp.from(Instant.parse("2023-03-27T19:16:14.895931Z")), - // largest possible unix epoch nanos - Timestamp.from(Instant.parse("2262-04-11T23:47:16.854775Z")), - Timestamp.from(Instant.parse("9999-12-31T23:59:59.999999Z")), - )) + assert( + actual.map(_.getTimestamp(1)) === Seq( + Timestamp.from(Instant.parse("1900-01-01T00:00:00Z")), + Timestamp.from(Instant.parse("1970-01-01T00:00:00Z")), + Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")), + Timestamp.from(Instant.parse("2023-03-27T19:16:14.89593Z")), + Timestamp.from(Instant.parse("2023-03-27T19:16:14.895931Z")), + // largest possible unix epoch nanos + Timestamp.from(Instant.parse("2262-04-11T23:47:16.854775Z")), + Timestamp.from(Instant.parse("9999-12-31T23:59:59.999999Z")), + ) + ) 
assert(actual.map(_.getTimestamp(2)) === actual.map(_.getTimestamp(1))) - assert(actual.map(_.getDecimal(3)).map(BigDecimal(_)) === Array( - BigDecimal(-2208988800000000000L, 9), - BigDecimal(0, 9), - BigDecimal(1679944574895930800L, 9), - BigDecimal(1679944574895930900L, 9), - BigDecimal(1679944574895931000L, 9), - // largest possible unix epoch nanos - BigDecimal(9223372036854775800L, 9), - BigDecimal(2534023007999999999L, 7).setScale(9), - )) + assert( + actual.map(_.getDecimal(3)).map(BigDecimal(_)) === Array( + BigDecimal(-2208988800000000000L, 9), + BigDecimal(0, 9), + BigDecimal(1679944574895930800L, 9), + BigDecimal(1679944574895930900L, 9), + BigDecimal(1679944574895931000L, 9), + // largest possible unix epoch nanos + BigDecimal(9223372036854775800L, 9), + BigDecimal(2534023007999999999L, 7).setScale(9), + ) + ) assert(actual.map(_.getDecimal(4)) === actual.map(_.getDecimal(3))) - assert(actual.map(row => - if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) row.getLong(5) else null - ) === actual.map(row => - if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) row.getDecimal(3).multiply(new java.math.BigDecimal(1000000000)).longValue() else null - )) + assert( + actual.map(row => + if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) row.getLong(5) else null + ) === actual.map(row => + if (BigDecimal(row.getDecimal(3)) <= BigDecimal(9223372036854775800L, 9)) + row.getDecimal(3).multiply(new java.math.BigDecimal(1000000000)).longValue() + else null + ) + ) assert(actual.map(_.get(6)) === actual.map(_.get(5))) } @@ -596,24 +634,32 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { df.select(timestampToDotNetTicks($"ts")) } } else { - val plan = df.select( - $"id", - timestampToDotNetTicks($"ts"), - timestampToDotNetTicks("ts"), - ).orderBy($"id") + val plan = df + .select( + $"id", + timestampToDotNetTicks($"ts"), + timestampToDotNetTicks("ts"), + ) + .orderBy($"id") - assert(plan.schema.fields.map(_.dataType) === Seq( - IntegerType, LongType, LongType - )) + assert( + plan.schema.fields.map(_.dataType) === Seq( + IntegerType, + LongType, + LongType + ) + ) val actual = plan.collect() - assert(actual.map(_.getLong(1)) === Seq( - 599266080000000000L, - 621355968000000000L, - 638155413748959310L, - 3155378975999999990L - )) + assert( + actual.map(_.getLong(1)) === Seq( + 599266080000000000L, + 621355968000000000L, + 638155413748959310L, + 3155378975999999990L + ) + ) assert(actual.map(_.getLong(2)) === actual.map(_.getLong(1))) val message = intercept[AnalysisException] { @@ -621,17 +667,29 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { }.getMessage if (SparkMajorVersion == 3 && SparkMinorVersion == 1) { - assert(message.startsWith("cannot resolve 'unix_micros(`ts`)' due to data type mismatch: argument 1 requires timestamp type, however, '`ts`' is of bigint type.;")) + assert( + message.startsWith( + "cannot resolve 'unix_micros(`ts`)' due to data type mismatch: argument 1 requires timestamp type, however, '`ts`' is of bigint type.;" + ) + ) } else if (SparkMajorVersion == 3 && SparkMinorVersion < 4) { - assert(message.startsWith("cannot resolve 'unix_micros(ts)' due to data type mismatch: argument 1 requires timestamp type, however, 'ts' is of bigint type.;")) + assert( + message.startsWith( + "cannot resolve 'unix_micros(ts)' due to data type mismatch: argument 1 requires timestamp type, however, 'ts' is of bigint type.;" + ) + ) } else if (SparkMajorVersion == 3 && SparkMinorVersion >= 
4 || SparkMajorVersion > 3) { - assert(message.startsWith("[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve \"unix_micros(ts)\" due to data type mismatch: Parameter 1 requires the \"TIMESTAMP\" type, however \"ts\" has the type \"BIGINT\".")) + assert( + message.startsWith( + "[DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve \"unix_micros(ts)\" due to data type mismatch: Parameter 1 requires the \"TIMESTAMP\" type, however \"ts\" has the type \"BIGINT\"." + ) + ) } } } test("Unix epoch to .Net ticks") { - def df[T : Encoder](v: T): DataFrame = + def df[T: Encoder](v: T): DataFrame = spark.createDataset(Seq(v)).withColumnRenamed("value", "ts") Seq( @@ -667,7 +725,7 @@ class SparkSuite extends AnyFunSuite with SparkTestSession { } test("Unix epoch nanos to .Net ticks") { - def df[T : Encoder](v: T): DataFrame = + def df[T: Encoder](v: T): DataFrame = spark.createDataset(Seq(v)).withColumnRenamed("value", "ts") Seq( @@ -713,7 +771,9 @@ object SparkSuite { import spark.implicits._ spark .range(0, 3, 1, 3) - .mapPartitions(it => it.map(id => (id, TaskContext.get().partitionId(), TaskContext.get().getLocalProperty("spark.job.description")))) + .mapPartitions(it => + it.map(id => (id, TaskContext.get().partitionId(), TaskContext.get().getLocalProperty("spark.job.description"))) + ) .as[(Long, Long, String)] .sort() .collect() diff --git a/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala b/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala index 57f7fcf7..62db45eb 100644 --- a/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/WritePartitionedSuite.scala @@ -44,7 +44,6 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { Value(4, Date.valueOf("2020-07-01"), "four") ).toDS() - test("write partitionedBy requires caching with AQE enabled") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { Some(spark.version) @@ -108,7 +107,9 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { test("write with one partition") { withTempPath { dir => withUnpersist() { handle => - values.writePartitionedBy(Seq($"id"), Seq($"date"), partitions = Some(1), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath) + values + .writePartitionedBy(Seq($"id"), Seq($"date"), partitions = Some(1), unpersistHandle = Some(handle)) + .csv(dir.getAbsolutePath) } val partitions = dir.list().filter(_.startsWith("id=")).sorted @@ -123,7 +124,9 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { test("write with partition order") { withTempPath { dir => withUnpersist() { handle => - values.writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date"), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath) + values + .writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date"), unpersistHandle = Some(handle)) + .csv(dir.getAbsolutePath) } val partitions = dir.list().filter(_.startsWith("id=")).sorted @@ -134,26 +137,40 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { assert(files.length === 1) val source = Source.fromFile(new File(file, files(0))) - val lines = try source.getLines().toList finally source.close() + val lines = + try source.getLines().toList + finally source.close() partition match { - case "id=1" => assert(lines === Seq( - "2020-07-01,one", - "2020-07-02,One", - "2020-07-03,ONE", - "2020-07-04,one" - )) - case "id=2" => assert(lines === Seq( - "2020-07-01,two", - "2020-07-02,Two", - "2020-07-03,TWO", - "2020-07-04,two" - )) 
- case "id=3" => assert(lines === Seq( - "2020-07-01,three" - )) - case "id=4" => assert(lines === Seq( - "2020-07-01,four" - )) + case "id=1" => + assert( + lines === Seq( + "2020-07-01,one", + "2020-07-02,One", + "2020-07-03,ONE", + "2020-07-04,one" + ) + ) + case "id=2" => + assert( + lines === Seq( + "2020-07-01,two", + "2020-07-02,Two", + "2020-07-03,TWO", + "2020-07-04,two" + ) + ) + case "id=3" => + assert( + lines === Seq( + "2020-07-01,three" + ) + ) + case "id=4" => + assert( + lines === Seq( + "2020-07-01,four" + ) + ) } } } @@ -162,7 +179,9 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { test("write with desc partition order") { withTempPath { dir => withUnpersist() { handle => - values.writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date".desc), unpersistHandle = Some(handle)).csv(dir.getAbsolutePath) + values + .writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date".desc), unpersistHandle = Some(handle)) + .csv(dir.getAbsolutePath) } val partitions = dir.list().filter(_.startsWith("id=")).sorted @@ -173,26 +192,40 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { assert(files.length === 1) val source = Source.fromFile(new File(file, files(0))) - val lines = try source.getLines().toList finally source.close() + val lines = + try source.getLines().toList + finally source.close() partition match { - case "id=1" => assert(lines === Seq( - "2020-07-04,one", - "2020-07-03,ONE", - "2020-07-02,One", - "2020-07-01,one" - )) - case "id=2" => assert(lines === Seq( - "2020-07-04,two", - "2020-07-03,TWO", - "2020-07-02,Two", - "2020-07-01,two" - )) - case "id=3" => assert(lines === Seq( - "2020-07-01,three" - )) - case "id=4" => assert(lines === Seq( - "2020-07-01,four" - )) + case "id=1" => + assert( + lines === Seq( + "2020-07-04,one", + "2020-07-03,ONE", + "2020-07-02,One", + "2020-07-01,one" + ) + ) + case "id=2" => + assert( + lines === Seq( + "2020-07-04,two", + "2020-07-03,TWO", + "2020-07-02,Two", + "2020-07-01,two" + ) + ) + case "id=3" => + assert( + lines === Seq( + "2020-07-01,three" + ) + ) + case "id=4" => + assert( + lines === Seq( + "2020-07-01,four" + ) + ) } } } @@ -202,7 +235,15 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { val projection = Some(Seq(col("id"), reverse(col("value")))) withTempPath { path => withUnpersist() { handle => - values.writePartitionedBy(Seq($"id"), Seq.empty, Seq($"date"), writtenProjection = projection, unpersistHandle = Some(handle)).csv(path.getAbsolutePath) + values + .writePartitionedBy( + Seq($"id"), + Seq.empty, + Seq($"date"), + writtenProjection = projection, + unpersistHandle = Some(handle) + ) + .csv(path.getAbsolutePath) } val partitions = path.list().filter(_.startsWith("id=")).sorted @@ -214,7 +255,8 @@ class WritePartitionedSuite extends AnyFunSuite with SparkTestSession { val lines = files.flatMap { file => val source = Source.fromFile(new File(dir, file)) - try source.getLines().toList finally source.close() + try source.getLines().toList + finally source.close() } partition match { diff --git a/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala index 855460b6..4fc9e48a 100644 --- a/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/diff/AppSuite.scala @@ -22,7 +22,6 @@ import uk.co.gresearch.spark.SparkTestSession import java.io.File - class AppSuite extends AnyFunSuite with SparkTestSession { import spark.implicits._ @@ -38,15 +37,21 @@ 
class AppSuite extends AnyFunSuite with SparkTestSession { // launch app val jsonPath = new File(path, "diff.json").getAbsolutePath - App.main(Array( - "--left-format", "csv", - "--left-schema", "id int, value string", - "--output-format", "json", - "--id", "id", - leftPath, - "right_parquet", - jsonPath - )) + App.main( + Array( + "--left-format", + "csv", + "--left-schema", + "id int, value string", + "--output-format", + "json", + "--id", + "id", + leftPath, + "right_parquet", + jsonPath + ) + ) // assert written diff val actual = spark.read.json(jsonPath) @@ -67,14 +72,18 @@ class AppSuite extends AnyFunSuite with SparkTestSession { // launch app val outputPath = new File(path, "diff.parquet").getAbsolutePath - App.main(Array( - "--format", "parquet", - "--id", "id", - ) ++ filter.toSeq.flatMap(f => Array("--filter", f)) ++ Array( - leftPath, - rightPath, - outputPath - )) + App.main( + Array( + "--format", + "parquet", + "--id", + "id", + ) ++ filter.toSeq.flatMap(f => Array("--filter", f)) ++ Array( + leftPath, + rightPath, + outputPath + ) + ) // assert written diff val actual = spark.read.parquet(outputPath).orderBy($"id").collect() @@ -98,14 +107,19 @@ class AppSuite extends AnyFunSuite with SparkTestSession { // launch app val outputPath = new File(path, "diff.parquet").getAbsolutePath assertThrows[RuntimeException]( - App.main(Array( - "--format", "parquet", - "--id", "id", - "--filter", "A", - leftPath, - rightPath, - outputPath - )) + App.main( + Array( + "--format", + "parquet", + "--id", + "id", + "--filter", + "A", + leftPath, + rightPath, + outputPath + ) + ) ) } } @@ -122,14 +136,18 @@ class AppSuite extends AnyFunSuite with SparkTestSession { // launch app val outputPath = new File(path, "diff.parquet").getAbsolutePath - App.main(Array( - "--format", "parquet", - "--statistics", - "--id", "id", - leftPath, - rightPath, - outputPath - )) + App.main( + Array( + "--format", + "parquet", + "--statistics", + "--id", + "id", + leftPath, + rightPath, + outputPath + ) + ) // assert written diff val actual = spark.read.parquet(outputPath).as[(String, Long)].collect().toMap diff --git a/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala index c2401bdd..f69ac4ef 100644 --- a/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/diff/DiffComparatorSuite.scala @@ -24,13 +24,25 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.scalatest.funsuite.AnyFunSuite import uk.co.gresearch.spark.SparkTestSession -import uk.co.gresearch.spark.diff.DiffComparatorSuite.{decimalEnc, optionsWithRelaxedComparators, optionsWithTightComparators} +import uk.co.gresearch.spark.diff.DiffComparatorSuite.{ + decimalEnc, + optionsWithRelaxedComparators, + optionsWithTightComparators +} import uk.co.gresearch.spark.diff.comparator._ import java.sql.{Date, Timestamp} import java.time.Duration -case class Numbers(id: Int, longValue: Long, floatValue: Float, doubleValue: Double, decimalValue: Decimal, someInt: Option[Int], someLong: Option[Long]) +case class Numbers( + id: Int, + longValue: Long, + floatValue: Float, + doubleValue: Double, + decimalValue: Decimal, + someInt: Option[Int], + someLong: Option[Long] +) case class Strings(id: Int, string: String) case class Dates(id: Int, date: Date) case class Times(id: Int, time: Timestamp) @@ -123,10 +135,12 @@ class DiffComparatorSuite extends AnyFunSuite with 
SparkTestSession { Maps(7, Map(3 -> 4L, 2 -> 2L, 1 -> 1L)), ).toDS() - def doTest(optionsWithTightComparators: DiffOptions, - optionsWithRelaxedComparators: DiffOptions, - left: DataFrame = this.left.toDF(), - right: DataFrame = this.right.toDF()): Unit = { + def doTest( + optionsWithTightComparators: DiffOptions, + optionsWithRelaxedComparators: DiffOptions, + left: DataFrame = this.left.toDF(), + right: DataFrame = this.right.toDF() + ): Unit = { // left and right numbers have some differences val actualWithoutComparators = left.diff(right, "id").orderBy($"id") @@ -178,22 +192,44 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { .withComparator((left: Decimal, right: Decimal) => left.abs == right.abs), "default any equiv" -> DiffOptions.default .withDefaultComparator((_: Any, _: Any) => true), - "typed diff comparator" -> DiffOptions.default .withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs)) - .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), LongType, FloatType, DoubleType, DecimalType(38, 18)), + .withComparator( + (left: Column, right: Column) => abs(left) <=> abs(right), + LongType, + FloatType, + DoubleType, + DecimalType(38, 18) + ), "typed diff comparator for type" -> DiffOptions.default // only works if data type is equal to input type of typed diff comparator .withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs), IntegerType) - .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), LongType, FloatType, DoubleType, DecimalType(38, 18)), - + .withComparator( + (left: Column, right: Column) => abs(left) <=> abs(right), + LongType, + FloatType, + DoubleType, + DecimalType(38, 18) + ), "diff comparator for type" -> DiffOptions.default .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), IntegerType) - .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), LongType, FloatType, DoubleType, DecimalType(38, 18)), + .withComparator( + (left: Column, right: Column) => abs(left) <=> abs(right), + LongType, + FloatType, + DoubleType, + DecimalType(38, 18) + ), "diff comparator for name" -> DiffOptions.default .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), "someInt") - .withComparator((left: Column, right: Column) => abs(left) <=> abs(right), "longValue", "floatValue", "doubleValue", "someLong", "decimalValue"), - + .withComparator( + (left: Column, right: Column) => abs(left) <=> abs(right), + "longValue", + "floatValue", + "doubleValue", + "someLong", + "decimalValue" + ), "encoder equiv" -> DiffOptions.default .withComparator((left: Int, right: Int) => left.abs == right.abs) .withComparator((left: Long, right: Long) => left.abs == right.abs) @@ -211,12 +247,14 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { .withComparator((left: Long, right: Long) => left.abs == right.abs, Encoders.scalaLong, "longValue", "someLong") .withComparator((left: Float, right: Float) => left.abs == right.abs, Encoders.scalaFloat, "floatValue") .withComparator((left: Double, right: Double) => left.abs == right.abs, Encoders.scalaDouble, "doubleValue") - .withComparator((left: Decimal, right: Decimal) => left.abs == right.abs, ExpressionEncoder[Decimal](), "decimalValue"), - + .withComparator( + (left: Decimal, right: Decimal) => left.abs == right.abs, + ExpressionEncoder[Decimal](), + "decimalValue" + ), "typed equiv for type" -> DiffOptions.default .withComparator((left: Int, right: 
Int) => left.abs == right.abs, IntegerType) .withComparator(alwaysTrueEquiv, LongType, FloatType, DoubleType, DecimalType(38, 18)), - "any equiv for column name" -> DiffOptions.default .withComparator(alwaysTrueEquiv, "someInt") .withComparator(alwaysTrueEquiv, "longValue", "floatValue", "doubleValue", "someLong", "decimalValue") @@ -228,7 +266,8 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { val allValuesEqual = Set("default any equiv", "any equiv for type", "any equiv for column name").contains(label) val unchangedIds = if (allValuesEqual) Seq(2, 3) else Seq(2) - val expected = diffWithoutComparators.withColumn("diff", when($"id".isin(unchangedIds: _*), lit("N")).otherwise($"diff")) + val expected = + diffWithoutComparators.withColumn("diff", when($"id".isin(unchangedIds: _*), lit("N")).otherwise($"diff")) assert(expected.where($"diff" === "C").count() === 3 - unchangedIds.size) val actual = left.diff(rightSign, options, "id").orderBy($"id").collect() @@ -240,44 +279,45 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { test("null-aware comparator") { val options = DiffOptions.default.withComparator( // only if this method is called with nulls, the expected result can occur - (x: Column, y: Column) => x.isNull || y.isNull || x === y, StringType) + (x: Column, y: Column) => x.isNull || y.isNull || x === y, + StringType + ) val diff = leftStrings.diff(rightStrings, options, "id").orderBy($"id").collect() - assert(diff === Seq( - Row("N", 1, "1", "1"), - Row("N", 2, null, "2"), - Row("N", 3, "3", null), - Row("N", 4, null, null), - )) + assert( + diff === Seq( + Row("N", 1, "1", "1"), + Row("N", 2, null, "2"), + Row("N", 3, "3", null), + Row("N", 4, null, null), + ) + ) } Seq( "diff comparator" -> (DiffOptions.default .withDefaultComparator((_: Column, _: Column) => lit(1)), - Seq( - "'(1 AND 1)' requires boolean type, not int", // until Spark 3.3 - "\"(1 AND 1)\" due to data type mismatch: " + // Spark 3.4 and beyond - "the binary operator requires the input type \"BOOLEAN\", not \"INT\"." - ) - ), + Seq( + "'(1 AND 1)' requires boolean type, not int", // until Spark 3.3 + "\"(1 AND 1)\" due to data type mismatch: " + // Spark 3.4 and beyond + "the binary operator requires the input type \"BOOLEAN\", not \"INT\"." + )), "encoder equiv" -> (DiffOptions.default .withDefaultComparator((_: Int, _: Int) => true), - Seq( - "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1 - "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3 - "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond - "the binary operator requires the input type \"INT\", not \"BIGINT\"." - ) - ), + Seq( + "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1 + "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3 + "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond + "the binary operator requires the input type \"INT\", not \"BIGINT\"." 
+ )), "typed equiv" -> (DiffOptions.default .withDefaultComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs, IntegerType)), - Seq( - "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1 - "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3 - "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond - "the binary operator requires the input type \"INT\", not \"BIGINT\"." - ) - ) + Seq( + "'(`longValue` ≡ `longValue`)' requires int type, not bigint", // Spark 3.0 and 3.1 + "'(longValue ≡ longValue)' requires int type, not bigint", // Spark 3.2 and 3.3 + "\"(longValue ≡ longValue)\" due to data type mismatch: " + // Spark 3.4 and beyond + "the binary operator requires the input type \"INT\", not \"BIGINT\"." + )) ).foreach { case (label, (options, expecteds)) => test(s"with comparator of incompatible type - $label") { val exception = intercept[AnalysisException] { @@ -289,34 +329,45 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { } test("absolute epsilon comparator (inclusive)") { - val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.5).asAbsolute().asInclusive()) - val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asInclusive()) + val optionsWithTightComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.5).asAbsolute().asInclusive()) + val optionsWithRelaxedComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asInclusive()) doTest(optionsWithTightComparator, optionsWithRelaxedComparator) } test("absolute epsilon comparator (exclusive)") { - val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asExclusive()) - val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.001).asAbsolute().asExclusive()) + val optionsWithTightComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.0).asAbsolute().asExclusive()) + val optionsWithRelaxedComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1.001).asAbsolute().asExclusive()) doTest(optionsWithTightComparator, optionsWithRelaxedComparator) } test("relative epsilon comparator (inclusive)") { - val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.1).asRelative().asInclusive()) - val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1/3.0).asRelative().asInclusive()) + val optionsWithTightComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(0.1).asRelative().asInclusive()) + val optionsWithRelaxedComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0).asRelative().asInclusive()) doTest(optionsWithTightComparator, optionsWithRelaxedComparator) } test("relative epsilon comparator (exclusive)") { - val optionsWithTightComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1/3.0).asRelative().asExclusive()) - val optionsWithRelaxedComparator = DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1/3.0 + .001).asRelative().asExclusive()) + val optionsWithTightComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0).asRelative().asExclusive()) + val 
optionsWithRelaxedComparator = + DiffOptions.default.withDefaultComparator(DiffComparators.epsilon(1 / 3.0 + .001).asRelative().asExclusive()) doTest(optionsWithTightComparator, optionsWithRelaxedComparator) } test("whitespace agnostic string comparator") { val left = Seq(Strings(1, "one"), Strings(2, "two spaces "), Strings(3, "three"), Strings(4, "four")).toDF() - val right = Seq(Strings(1, "one"), Strings(2, " two \t\nspaces"), Strings(3, "three\nspaces"), Strings(5, "five")).toDF() - val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = false)) - val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = true)) + val right = + Seq(Strings(1, "one"), Strings(2, " two \t\nspaces"), Strings(3, "three\nspaces"), Strings(5, "five")).toDF() + val optionsWithTightComparator = + DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = false)) + val optionsWithRelaxedComparator = + DiffOptions.default.withComparator(DiffComparators.string(whitespaceAgnostic = true)) doTest(optionsWithTightComparator, optionsWithRelaxedComparator, left, right) } @@ -331,26 +382,34 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { } } else { test("duration comparator with date (inclusive)") { - val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(23)).asInclusive(), "date") - val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asInclusive(), "date") + val optionsWithTightComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(23)).asInclusive(), "date") + val optionsWithRelaxedComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asInclusive(), "date") doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftDates.toDF, rightDates.toDF) } test("duration comparator with date (exclusive)") { - val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asExclusive(), "date") - val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(25)).asExclusive(), "date") + val optionsWithTightComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(24)).asExclusive(), "date") + val optionsWithRelaxedComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofHours(25)).asExclusive(), "date") doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftDates.toDF, rightDates.toDF) } test("duration comparator with time (inclusive)") { - val optionsWithTightComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(59)).asInclusive(), "time") - val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asInclusive(), "time") + val optionsWithTightComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(59)).asInclusive(), "time") + val optionsWithRelaxedComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asInclusive(), "time") doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftTimes.toDF, rightTimes.toDF) } test("duration comparator with time (exclusive)") { - val optionsWithTightComparator = 
DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asExclusive(), "time") - val optionsWithRelaxedComparator = DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(61)).asExclusive(), "time") + val optionsWithTightComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(60)).asExclusive(), "time") + val optionsWithRelaxedComparator = + DiffOptions.default.withComparator(DiffComparators.duration(Duration.ofSeconds(61)).asExclusive(), "time") doTest(optionsWithTightComparator, optionsWithRelaxedComparator, leftTimes.toDF, rightTimes.toDF) } } @@ -365,12 +424,15 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { val options = DiffOptions.default.withComparator(DiffComparators.map[Int, Long](sensitive), "map") val actual = leftMaps.diff(rightMaps, options, "id").orderBy($"id").collect() - val diffs = Seq((1, "N"), (2, "C"), (3, "C"), (4, "D"), (5, "I"), (6, if (sensitive) "C" else "N"), (7, "C")).toDF("id", "diff") - val expected = leftMaps.withColumnRenamed("map", "left_map") + val diffs = Seq((1, "N"), (2, "C"), (3, "C"), (4, "D"), (5, "I"), (6, if (sensitive) "C" else "N"), (7, "C")) + .toDF("id", "diff") + val expected = leftMaps + .withColumnRenamed("map", "left_map") .join(rightMaps.withColumnRenamed("map", "right_map"), Seq("id"), "fullouter") .join(diffs, "id") .select($"diff", $"id", $"left_map", $"right_map") - .orderBy($"id").collect() + .orderBy($"id") + .collect() assert(actual === expected) } } @@ -386,16 +448,27 @@ class DiffComparatorSuite extends AnyFunSuite with SparkTestSession { } val diffComparatorMethodTests: Seq[(String, (() => DiffComparator, DiffComparator))] = - if(DurationDiffComparator.isSupportedBySpark) { - Seq("duration" -> (() => DiffComparators.duration(Duration.ofSeconds(1)).asExclusive(), DurationDiffComparator(Duration.ofSeconds(1), inclusive = false))) - } else { Seq.empty } ++ Seq( - "default" -> (() => DiffComparators.default(), DefaultDiffComparator), - "nullSafeEqual" -> (() => DiffComparators.nullSafeEqual(), NullSafeEqualDiffComparator), - "equiv with encoder" -> (() => DiffComparators.equiv(IntEquiv), EquivDiffComparator(IntEquiv)), - "equiv with type" -> (() => DiffComparators.equiv(IntEquiv, IntegerType), EquivDiffComparator(IntEquiv, IntegerType)), - "equiv with any" -> (() => DiffComparators.equiv(AnyEquiv), EquivDiffComparator(AnyEquiv)), - "epsilon" -> (() => DiffComparators.epsilon(1.0).asAbsolute().asExclusive(), EpsilonDiffComparator(1.0, relative = false, inclusive = false)) - ) + if (DurationDiffComparator.isSupportedBySpark) { + Seq( + "duration" -> (() => DiffComparators.duration(Duration.ofSeconds(1)).asExclusive(), DurationDiffComparator( + Duration.ofSeconds(1), + inclusive = false + )) + ) + } else + { Seq.empty } ++ Seq( + "default" -> (() => DiffComparators.default(), DefaultDiffComparator), + "nullSafeEqual" -> (() => DiffComparators.nullSafeEqual(), NullSafeEqualDiffComparator), + "equiv with encoder" -> (() => DiffComparators.equiv(IntEquiv), EquivDiffComparator(IntEquiv)), + "equiv with type" -> (() => + DiffComparators.equiv(IntEquiv, IntegerType), EquivDiffComparator(IntEquiv, IntegerType)), + "equiv with any" -> (() => DiffComparators.equiv(AnyEquiv), EquivDiffComparator(AnyEquiv)), + "epsilon" -> (() => DiffComparators.epsilon(1.0).asAbsolute().asExclusive(), EpsilonDiffComparator( + 1.0, + relative = false, + inclusive = false + )) + ) diffComparatorMethodTests.foreach { case (label, (method, expected)) => 
test(s"DiffComparator.$label") { @@ -414,9 +487,12 @@ object DiffComparatorSuite { val tightIntComparator: EquivDiffComparator[Int] = EquivDiffComparator((x: Int, y: Int) => math.abs(x - y) < 1) val tightLongComparator: EquivDiffComparator[Long] = EquivDiffComparator((x: Long, y: Long) => math.abs(x - y) < 1) - val tightFloatComparator: EquivDiffComparator[Float] = EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) < 0.001) - val tightDoubleComparator: EquivDiffComparator[Double] = EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) < 0.001) - val tightDecimalComparator: EquivDiffComparator[Decimal] = EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs < Decimal(0.001)) + val tightFloatComparator: EquivDiffComparator[Float] = + EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) < 0.001) + val tightDoubleComparator: EquivDiffComparator[Double] = + EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) < 0.001) + val tightDecimalComparator: EquivDiffComparator[Decimal] = + EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs < Decimal(0.001)) val optionsWithTightComparators: DiffOptions = DiffOptions.default .withComparator(tightIntComparator, IntegerType) @@ -427,9 +503,12 @@ object DiffComparatorSuite { val relaxedIntComparator: EquivDiffComparator[Int] = EquivDiffComparator((x: Int, y: Int) => math.abs(x - y) <= 1) val relaxedLongComparator: EquivDiffComparator[Long] = EquivDiffComparator((x: Long, y: Long) => math.abs(x - y) <= 1) - val relaxedFloatComparator: EquivDiffComparator[Float] = EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) <= 0.001) - val relaxedDoubleComparator: EquivDiffComparator[Double] = EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) <= 0.001) - val relaxedDecimalComparator: EquivDiffComparator[Decimal] = EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs <= Decimal(0.001)) + val relaxedFloatComparator: EquivDiffComparator[Float] = + EquivDiffComparator((x: Float, y: Float) => math.abs(x - y) <= 0.001) + val relaxedDoubleComparator: EquivDiffComparator[Double] = + EquivDiffComparator((x: Double, y: Double) => math.abs(x - y) <= 0.001) + val relaxedDecimalComparator: EquivDiffComparator[Decimal] = + EquivDiffComparator[Decimal]((x: Decimal, y: Decimal) => (x - y).abs <= Decimal(0.001)) val optionsWithRelaxedComparators: DiffOptions = DiffOptions.default .withComparator(relaxedIntComparator, IntegerType) diff --git a/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala index 64629f93..191375ab 100644 --- a/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/diff/DiffOptionsSuite.scala @@ -38,17 +38,19 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession { test("diff options left and right prefixes") { // test the copy method (constructor), not the fluent methods val default = DiffOptions.default - doTestRequirement(default.copy(leftColumnPrefix = ""), - "Left column prefix must not be empty") - doTestRequirement(default.copy(rightColumnPrefix = ""), - "Right column prefix must not be empty") + doTestRequirement(default.copy(leftColumnPrefix = ""), "Left column prefix must not be empty") + doTestRequirement(default.copy(rightColumnPrefix = ""), "Right column prefix must not be empty") val prefix = "prefix" - doTestRequirement(default.copy(leftColumnPrefix = prefix, rightColumnPrefix = prefix), - s"Left and right 
column prefix must be distinct: $prefix") + doTestRequirement( + default.copy(leftColumnPrefix = prefix, rightColumnPrefix = prefix), + s"Left and right column prefix must be distinct: $prefix" + ) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - doTestRequirement(default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase), - s"Left and right column prefix must be distinct: $prefix") + doTestRequirement( + default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase), + s"Left and right column prefix must be distinct: $prefix" + ) } withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { default.copy(leftColumnPrefix = prefix.toLowerCase, rightColumnPrefix = prefix.toUpperCase) @@ -69,18 +71,30 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession { assert(emptyNochangeDiffValueOpts.nochangeDiffValue.isEmpty) Seq("value", "").foreach { value => - doTestRequirement(default.copy(insertDiffValue = value, changeDiffValue = value), - s"Diff values must be distinct: List($value, $value, D, N)") - doTestRequirement(default.copy(insertDiffValue = value, deleteDiffValue = value), - s"Diff values must be distinct: List($value, C, $value, N)") - doTestRequirement(default.copy(insertDiffValue = value, nochangeDiffValue = value), - s"Diff values must be distinct: List($value, C, D, $value)") - doTestRequirement(default.copy(changeDiffValue = value, deleteDiffValue = value), - s"Diff values must be distinct: List(I, $value, $value, N)") - doTestRequirement(default.copy(changeDiffValue = value, nochangeDiffValue = value), - s"Diff values must be distinct: List(I, $value, D, $value)") - doTestRequirement(default.copy(deleteDiffValue = value, nochangeDiffValue = value), - s"Diff values must be distinct: List(I, C, $value, $value)") + doTestRequirement( + default.copy(insertDiffValue = value, changeDiffValue = value), + s"Diff values must be distinct: List($value, $value, D, N)" + ) + doTestRequirement( + default.copy(insertDiffValue = value, deleteDiffValue = value), + s"Diff values must be distinct: List($value, C, $value, N)" + ) + doTestRequirement( + default.copy(insertDiffValue = value, nochangeDiffValue = value), + s"Diff values must be distinct: List($value, C, D, $value)" + ) + doTestRequirement( + default.copy(changeDiffValue = value, deleteDiffValue = value), + s"Diff values must be distinct: List(I, $value, $value, N)" + ) + doTestRequirement( + default.copy(changeDiffValue = value, nochangeDiffValue = value), + s"Diff values must be distinct: List(I, $value, D, $value)" + ) + doTestRequirement( + default.copy(deleteDiffValue = value, nochangeDiffValue = value), + s"Diff values must be distinct: List(I, C, $value, $value)" + ) } } @@ -177,11 +191,16 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession { DiffOptions.default .withComparator(EquivDiffComparator((left: Int, right: Int) => left.abs == right.abs), LongType, FloatType) } - assert(exceptionMulti.getMessage.contains("Comparator with input type int cannot be used for data type bigint, float")) + assert( + exceptionMulti.getMessage.contains("Comparator with input type int cannot be used for data type bigint, float") + ) } test("fluent methods of diff options") { - assert(DiffMode.Default != DiffMode.LeftSide, "test assumption on default diff mode must hold, otherwise test is trivial") + assert( + DiffMode.Default != DiffMode.LeftSide, + "test assumption on default diff mode must hold, otherwise test is trivial" + ) val cmp1 = new DiffComparator { 
override def equiv(left: Column, right: Column): Column = lit(true) @@ -211,7 +230,21 @@ class DiffOptionsSuite extends AnyFunSuite with SparkTestSession { val dexpectedDefCmp = cmp1 val expectedDtCmps = Map(IntegerType.asInstanceOf[DataType] -> cmp2) val expectedColCmps = Map("col1" -> cmp3) - val expected = DiffOptions("d", "l", "r", "i", "c", "d", "n", Some("change"), DiffMode.LeftSide, sparseMode = true, dexpectedDefCmp, expectedDtCmps, expectedColCmps) + val expected = DiffOptions( + "d", + "l", + "r", + "i", + "c", + "d", + "n", + Some("change"), + DiffMode.LeftSide, + sparseMode = true, + dexpectedDefCmp, + expectedDtCmps, + expectedColCmps + ) assert(options === expected) } diff --git a/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala b/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala index c47551c0..1fd8f2c4 100644 --- a/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/diff/DiffSuite.scala @@ -38,60 +38,50 @@ case class Value9up(ID: Int, SEQ: Option[Int], VALUE: Option[String], INFO: Opti case class ValueLeft(left_id: Int, value: Option[String]) case class ValueRight(right_id: Int, value: Option[String]) -case class DiffAs(diff: String, - id: Int, - left_value: Option[String], - right_value: Option[String]) -case class DiffAs8(diff: String, - id: Int, - seq: Option[Int], - left_value: Option[String], - right_value: Option[String], - left_meta: Option[String], - right_meta: Option[String]) -case class DiffAs8SideBySide(diff: String, - id: Int, - seq: Option[Int], - left_value: Option[String], - left_meta: Option[String], - right_value: Option[String], - right_meta: Option[String]) -case class DiffAs8OneSide(diff: String, - id: Int, - seq: Option[Int], - value: Option[String], - meta: Option[String]) -case class DiffAs8changes(diff: String, - changed: Array[String], - id: Int, - seq: Option[Int], - left_value: Option[String], - right_value: Option[String], - left_meta: Option[String], - right_meta: Option[String]) -case class DiffAs8and9(diff: String, - id: Int, - seq: Option[Int], - left_value: Option[String], - right_value: Option[String], - left_meta: Option[String], - right_info: Option[String]) - -case class DiffAsCustom(action: String, - id: Int, - before_value: Option[String], - after_value: Option[String]) -case class DiffAsSubset(diff: String, - id: Int, - left_value: Option[String]) -case class DiffAsExtra(diff: String, - id: Int, - left_value: Option[String], - right_value: Option[String], - extra: String) -case class DiffAsOneSide(diff: String, - id: Int, - value: Option[String]) +case class DiffAs(diff: String, id: Int, left_value: Option[String], right_value: Option[String]) +case class DiffAs8( + diff: String, + id: Int, + seq: Option[Int], + left_value: Option[String], + right_value: Option[String], + left_meta: Option[String], + right_meta: Option[String] +) +case class DiffAs8SideBySide( + diff: String, + id: Int, + seq: Option[Int], + left_value: Option[String], + left_meta: Option[String], + right_value: Option[String], + right_meta: Option[String] +) +case class DiffAs8OneSide(diff: String, id: Int, seq: Option[Int], value: Option[String], meta: Option[String]) +case class DiffAs8changes( + diff: String, + changed: Array[String], + id: Int, + seq: Option[Int], + left_value: Option[String], + right_value: Option[String], + left_meta: Option[String], + right_meta: Option[String] +) +case class DiffAs8and9( + diff: String, + id: Int, + seq: Option[Int], + left_value: Option[String], + 
right_value: Option[String], + left_meta: Option[String], + right_info: Option[String] +) + +case class DiffAsCustom(action: String, id: Int, before_value: Option[String], after_value: Option[String]) +case class DiffAsSubset(diff: String, id: Int, left_value: Option[String]) +case class DiffAsExtra(diff: String, id: Int, left_value: Option[String], right_value: Option[String], extra: String) +case class DiffAsOneSide(diff: String, id: Int, value: Option[String]) object DiffSuite { def left(spark: SparkSession): Dataset[Value] = { @@ -170,7 +160,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Value8(3, None, None, None) ).toDS() - lazy val right9: Dataset[Value9] = right8.withColumn("info", regexp_replace($"meta", "user", "info")).drop("meta").as[Value9] + lazy val right9: Dataset[Value9] = + right8.withColumn("info", regexp_replace($"meta", "user", "info")).drop("meta").as[Value9] lazy val expectedDiffColumns: Seq[String] = Seq("diff", "id", "left_value", "right_value") @@ -183,9 +174,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Row("D", 4, "four", null) ) - lazy val expectedDiffAs: Seq[DiffAs] = expectedDiff.map(r => - DiffAs(r.getString(0), r.getInt(1), Option(r.getString(2)), Option(r.getString(3))) - ) + lazy val expectedDiffAs: Seq[DiffAs] = + expectedDiff.map(r => DiffAs(r.getString(0), r.getInt(1), Option(r.getString(2)), Option(r.getString(3)))) lazy val expectedDiff7: Seq[Row] = Seq( Row("C", 1, "one", "One", "one label", "one label"), @@ -200,9 +190,13 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Row("I", 10, null, null, null, null) ) - lazy val expectedSideBySideDiff7: Seq[Row] = expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5))) - lazy val expectedLeftSideDiff7: Seq[Row] = expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4))) - lazy val expectedRightSideDiff7: Seq[Row] = expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5))) + lazy val expectedSideBySideDiff7: Seq[Row] = expectedDiff7.map(row => + Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5)) + ) + lazy val expectedLeftSideDiff7: Seq[Row] = + expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4))) + lazy val expectedRightSideDiff7: Seq[Row] = + expectedDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5))) lazy val expectedSparseDiff7: Seq[Row] = Seq( Row("C", 1, "one", "One", null, null), @@ -217,9 +211,13 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Row("I", 10, null, null, null, null) ) - lazy val expectedSideBySideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5))) - lazy val expectedLeftSideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4))) - lazy val expectedRightSideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5))) + lazy val expectedSideBySideSparseDiff7: Seq[Row] = expectedSparseDiff7.map(row => + Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4), row.getString(3), row.getString(5)) + ) + lazy val expectedLeftSideSparseDiff7: Seq[Row] = + 
expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(2), row.getString(4))) + lazy val expectedRightSideSparseDiff7: Seq[Row] = + expectedSparseDiff7.map(row => Row(row.getString(0), row.getInt(1), row.getString(3), row.getString(5))) lazy val expectedDiff7WithChanges: Seq[Row] = Seq( Row("C", Seq("value"), 1, "one", "One", "one label", "one label"), @@ -256,9 +254,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Row("N", 3, null, null, null, null, null) ) - lazy val expectedSideBySideDiff8: Seq[Row] = expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6))) - lazy val expectedLeftSideDiff8: Seq[Row] = expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5))) - lazy val expectedRightSideDiff8: Seq[Row] = expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6))) + lazy val expectedSideBySideDiff8: Seq[Row] = + expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6))) + lazy val expectedLeftSideDiff8: Seq[Row] = + expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5))) + lazy val expectedRightSideDiff8: Seq[Row] = + expectedDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6))) lazy val expectedSparseDiff8: Seq[Row] = Seq( Row("N", 1, 1, null, null, "user1", "user2"), @@ -271,50 +272,66 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Row("N", 3, null, null, null, null, null) ) - lazy val expectedSideBySideSparseDiff8: Seq[Row] = expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6))) - lazy val expectedLeftSideSparseDiff8: Seq[Row] = expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5))) - lazy val expectedRightSideSparseDiff8: Seq[Row] = expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6))) + lazy val expectedSideBySideSparseDiff8: Seq[Row] = + expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5), r.get(4), r.get(6))) + lazy val expectedLeftSideSparseDiff8: Seq[Row] = + expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(3), r.get(5))) + lazy val expectedRightSideSparseDiff8: Seq[Row] = + expectedSparseDiff8.map(r => Row(r.get(0), r.get(1), r.get(2), r.get(4), r.get(6))) lazy val expectedDiffAs8: Seq[DiffAs8] = expectedDiff8.map(r => - DiffAs8(r.getString(0), - r.getInt(1), Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)), - Option(r.getString(3)), Option(r.getString(4)), - Option(r.getString(5)), Option(r.getString(6)) + DiffAs8( + r.getString(0), + r.getInt(1), + Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)), + Option(r.getString(3)), + Option(r.getString(4)), + Option(r.getString(5)), + Option(r.getString(6)) ) ) lazy val expectedDiff8WithChanges: Seq[Row] = expectedDiff8.map(r => - Row(r.get(0), + Row( + r.get(0), r.get(0) match { case "N" => Seq.empty case "I" => null case "C" => Seq("value") case "D" => null }, - r.get(1), r.get(2), - r.getString(3), r.getString(4), - r.getString(5), r.getString(6) + r.get(1), + r.get(2), + r.getString(3), + r.getString(4), + r.getString(5), + r.getString(6) ) ) lazy val expectedDiffAs8and9: Seq[DiffAs8and9] = expectedDiff8and9.map(r => - DiffAs8and9(r.getString(0), - r.getInt(1), Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)), - Option(r.getString(3)), Option(r.getString(4)), - Option(r.getString(5)), Option(r.getString(6)) + DiffAs8and9( + r.getString(0), + r.getInt(1), + 
Some(r).filterNot(_.isNullAt(2)).map(_.getInt(2)), + Option(r.getString(3)), + Option(r.getString(4)), + Option(r.getString(5)), + Option(r.getString(6)) ) ) - lazy val expectedDiffWith8and9: Seq[(String, Value8, Value9)] = expectedDiffAs8and9.map(v => ( - v.diff, - if (v.diff == "I") null else Value8(v.id, v.seq, v.left_value, v.left_meta), - if (v.diff == "D") null else Value9(v.id, v.seq, v.right_value, v.right_info) - )) - - lazy val expectedDiffWith8and9up: Seq[(String, Value8, Value9up)] = expectedDiffWith8and9.map(t => - t.copy(_3 = Option(t._3).map(v => Value9up(v.id, v.seq, v.value, v.info)).orNull) + lazy val expectedDiffWith8and9: Seq[(String, Value8, Value9)] = expectedDiffAs8and9.map(v => + ( + v.diff, + if (v.diff == "I") null else Value8(v.id, v.seq, v.left_value, v.left_meta), + if (v.diff == "D") null else Value9(v.id, v.seq, v.right_value, v.right_info) + ) ) + lazy val expectedDiffWith8and9up: Seq[(String, Value8, Value9up)] = + expectedDiffWith8and9.map(t => t.copy(_3 = Option(t._3).map(v => Value9up(v.id, v.seq, v.value, v.info)).orNull)) + test("distinct prefix for") { assert(distinctPrefixFor(Seq.empty[String]) === "_") assert(distinctPrefixFor(Seq("a")) === "_") @@ -328,9 +345,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { test("diff dataframe with duplicate columns") { val df = Seq(1).toDF("id").select($"id", $"id") - doTestRequirement(df.diff(df, "id"), + doTestRequirement( + df.diff(df, "id"), "The datasets have duplicate columns.\n" + - "Left column names: id, id\nRight column names: id, id") + "Left column names: id, id\nRight column names: id, id" + ) } test("diff with no id column") { @@ -383,8 +402,7 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { test("diff with one id column case-sensitive") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - doTestRequirement(left.diff(right, "ID"), - "Some id columns do not exist: ID missing among id, value") + doTestRequirement(left.diff(right, "ID"), "Some id columns do not exist: ID missing among id, value") val actual = left.diff(right, "id").orderBy("id") val reverse = right.diff(left, "id").orderBy("id") @@ -526,10 +544,14 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { ) val expectedColumns = Seq( "diff", - "id", "seq", - "left_value1", "right_value1", - "left_value2", "right_value2", - "left_value3", "right_value3" + "id", + "seq", + "left_value1", + "right_value1", + "left_value2", + "right_value2", + "left_value3", + "right_value3" ) val actual = left.diff(right, "id", "seq").orderBy("id", "seq") @@ -548,10 +570,14 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { ) val expectedColumns = Seq( "diff", - "id", "seq", - "left_value2", "right_value2", - "left_value3", "right_value3", - "left_value1", "right_value1" + "id", + "seq", + "left_value2", + "right_value2", + "left_value3", + "right_value3", + "left_value1", + "right_value1" ) val actual = right.diff(left, "id", "seq").orderBy("id", "seq") @@ -570,7 +596,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Row("I", "val2.2.1", 2, "val2.2.2", 2, "val2.2.3") ) val expectedColumns = Seq( - "diff", "value1", "id", "value2", "seq", "value3" + "diff", + "value1", + "id", + "value2", + "seq", + "value3" ) val actual = left.diff(right).orderBy("id", "seq", "diff") @@ -589,7 +620,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { Row("I", "val2.2.1", 2, "val2.2.2", 2, "val2.2.3") ) val expectedColumns = Seq( - "diff", "value1", "id", "value2", "seq", "value3" + "diff", + 
"value1", + "id", + "value2", + "seq", + "value3" ) val actual = left.diff(right).orderBy("id", "seq", "diff") @@ -617,9 +653,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val expectedColumns = Seq( "diff", "id", - "left_left_value", "right_left_value", - "left_right_value", "right_right_value", - "left_value", "right_value" + "left_left_value", + "right_left_value", + "left_right_value", + "right_right_value", + "left_value", + "right_value" ) val expectedDiff = Seq( Row("C", 1, "left", "Left", "right", "Right", "value", "Value") @@ -633,21 +672,25 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = Seq(Value4(1, "diff")).toDS() val right = Seq(Value4(1, "Diff")).toDS() - doTestRequirement(left.diff(right), - "The id columns must not contain the diff column name 'diff': id, diff") - doTestRequirement(left.diff(right, "diff"), - "The id columns must not contain the diff column name 'diff': diff") - doTestRequirement(left.diff(right, "diff", "id"), - "The id columns must not contain the diff column name 'diff': diff, id") + doTestRequirement(left.diff(right), "The id columns must not contain the diff column name 'diff': id, diff") + doTestRequirement(left.diff(right, "diff"), "The id columns must not contain the diff column name 'diff': diff") + doTestRequirement( + left.diff(right, "diff", "id"), + "The id columns must not contain the diff column name 'diff': diff, id" + ) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - doTestRequirement(left.withColumnRenamed("diff", "Diff") - .diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id"), - "The id columns must not contain the diff column name 'diff': Diff, id") + doTestRequirement( + left + .withColumnRenamed("diff", "Diff") + .diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id"), + "The id columns must not contain the diff column name 'diff': Diff, id" + ) } withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - left.withColumnRenamed("diff", "Diff") + left + .withColumnRenamed("diff", "Diff") .diff(right.withColumnRenamed("diff", "Diff"), "Diff", "id") } } @@ -660,7 +703,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val expectedColumns = Seq( "diff", "id", - "left_diff", "right_diff" + "left_diff", + "right_diff" ) val expectedDiff = Seq( Row("C", 1, "diff", "Diff") @@ -676,9 +720,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { .withLeftColumnPrefix("a") .withRightColumnPrefix("b") - doTestRequirement(left.diff(right, options, "id"), + doTestRequirement( + left.diff(right, options, "id"), "The column prefixes 'a' and 'b', together with these non-id columns " + - "must not produce the diff column name 'a_value': value") + "must not produce the diff column name 'a_value': value" + ) } test("diff with left-side mode where non-id column would produce diff column name") { @@ -708,9 +754,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { .withLeftColumnPrefix("A") .withRightColumnPrefix("B") - doTestRequirement(left.diff(right, options, "id"), + doTestRequirement( + left.diff(right, options, "id"), "The column prefixes 'A' and 'B', together with these non-id columns " + - "must not produce the diff column name 'a_value': value") + "must not produce the diff column name 'a_value': value" + ) } } @@ -747,7 +795,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val actual = left.diff(right, options, "id").orderBy("id") val expectedColumns = Seq( - "a_value", "id", "A_value", "B_value" + "a_value", + "id", + "A_value", + 
"B_value" ) assert(actual.columns === expectedColumns) @@ -761,9 +812,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { .withLeftColumnPrefix("a") .withRightColumnPrefix("b") - doTestRequirement(left.diff(right, options, "id"), + doTestRequirement( + left.diff(right, options, "id"), "The column prefixes 'a' and 'b', together with these non-id columns " + - "must not produce the change column name 'a_value': value") + "must not produce the change column name 'a_value': value" + ) } test("diff where case-insensitive non-id column produces change column name") { @@ -773,9 +826,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { .withLeftColumnPrefix("A") .withRightColumnPrefix("B") - doTestRequirement(left.diff(right, options, "id"), + doTestRequirement( + left.diff(right, options, "id"), "The column prefixes 'A' and 'B', together with these non-id columns " + - "must not produce the change column name 'a_value': value") + "must not produce the change column name 'a_value': value" + ) } } @@ -788,7 +843,13 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val actual = left7.diff(right7, options, "id").orderBy("id") val expectedColumns = Seq( - "diff", "a_value", "id", "A_value", "B_value", "A_label", "B_label" + "diff", + "a_value", + "id", + "A_value", + "B_value", + "A_label", + "B_label" ) assert(actual.columns === expectedColumns) @@ -804,9 +865,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = Seq(Value5(1, "value")).toDS() val right = Seq(Value5(1, "Value")).toDS() - doTestRequirement(left.diff(right, options, "first_id"), + doTestRequirement( + left.diff(right, options, "first_id"), "The column prefixes 'first' and 'second', together with these non-id columns " + - "must not produce any id column name 'first_id': id") + "must not produce any id column name 'first_id': id" + ) } test("diff where case-insensitive non-id column produces id column name") { @@ -818,9 +881,11 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = Seq(Value5(1, "value")).toDS() val right = Seq(Value5(1, "Value")).toDS() - doTestRequirement(left.diff(right, options, "first_id"), + doTestRequirement( + left.diff(right, options, "first_id"), "The column prefixes 'FIRST' and 'SECOND', together with these non-id columns " + - "must not produce any id column name 'first_id': id") + "must not produce any id column name 'first_id': id" + ) } } @@ -835,7 +900,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val actual = left.diff(right, options, "first_id") val expectedColumns = Seq( - "diff", "first_id", "FIRST_id", "SECOND_id" + "diff", + "first_id", + "FIRST_id", + "SECOND_id" ) assert(actual.columns === expectedColumns) @@ -870,8 +938,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = Seq((1, "info")).toDF("id", "info") val right = Seq((1, "meta")).toDF("id", "meta") - doTestRequirement(left.diff(right, Seq.empty, Seq("id", "info", "meta")), - "The schema except ignored columns must not be empty") + doTestRequirement( + left.diff(right, Seq.empty, Seq("id", "info", "meta")), + "The schema except ignored columns must not be empty" + ) } test("diff with different types") { @@ -879,10 +949,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = Seq((1, "str")).toDF("id", "value") val right = Seq((1, 2)).toDF("id", "value") - doTestRequirement(left.diff(right), + doTestRequirement( + left.diff(right), "The datasets do not have the same schema.\n" + "Left extra 
columns: value (StringType)\n" + - "Right extra columns: value (IntegerType)") + "Right extra columns: value (IntegerType)" + ) } test("diff with ignored columns of different types") { @@ -891,12 +963,16 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val right = Seq((1, 2)).toDF("id", "value") val actual = left.diff(right, Seq.empty, Seq("value")) - assert(ignoreNullable(actual.schema) === StructType(Seq( - StructField("diff", StringType), - StructField("id", IntegerType), - StructField("left_value", StringType), - StructField("right_value", IntegerType), - ))) + assert( + ignoreNullable(actual.schema) === StructType( + Seq( + StructField("diff", StringType), + StructField("id", IntegerType), + StructField("left_value", StringType), + StructField("right_value", IntegerType), + ) + ) + ) assert(actual.collect() === Seq(Row("N", 1, "str", 2))) } @@ -922,10 +998,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = Seq((1, "str")).toDF("id", "value") val right = Seq((1, "str")).toDF("id", "comment") - doTestRequirement(left.diff(right, "id"), + doTestRequirement( + left.diff(right, "id"), "The datasets do not have the same schema.\n" + "Left extra columns: value (StringType)\n" + - "Right extra columns: comment (StringType)") + "Right extra columns: comment (StringType)" + ) } test("diff with case-insensitive column names") { @@ -950,16 +1028,20 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val right = this.right.toDF("ID", "VaLuE") withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - doTestRequirement(left.diff(right, "id"), + doTestRequirement( + left.diff(right, "id"), "The datasets do not have the same schema.\n" + "Left extra columns: id (IntegerType), value (StringType)\n" + - "Right extra columns: ID (IntegerType), VaLuE (StringType)") + "Right extra columns: ID (IntegerType), VaLuE (StringType)" + ) } } test("diff of non-existing id column") { - doTestRequirement(left.diff(right, "does not exists"), - "Some id columns do not exist: does not exists missing among id, value") + doTestRequirement( + left.diff(right, "does not exists"), + "Some id columns do not exist: does not exists missing among id, value" + ) } test("diff with different number of columns") { @@ -967,20 +1049,24 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = Seq((1, "str")).toDF("id", "value") val right = Seq((1, 1, "str")).toDF("id", "seq", "value") - doTestRequirement(left.diff(right, "id"), + doTestRequirement( + left.diff(right, "id"), "The number of columns doesn't match.\n" + "Left column names (2): id, value\n" + - "Right column names (3): id, seq, value") + "Right column names (3): id, seq, value" + ) } test("diff similar with ignored column and different number of columns") { val left = Seq((1, "str", "meta")).toDF("id", "value", "meta") val right = Seq((1, 1, "str")).toDF("id", "seq", "value") - doTestRequirement(left.diff(right, Seq("id"), Seq("meta")), + doTestRequirement( + left.diff(right, Seq("id"), Seq("meta")), "The number of columns doesn't match.\n" + "Left column names except ignored columns (2): id, value\n" + - "Right column names except ignored columns (3): id, seq, value") + "Right column names except ignored columns (3): id, seq, value" + ) } test("diff as U") { @@ -1010,7 +1096,8 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { (DiffOptions.default.nochangeDiffValue, "eq") ).toDF("diff", "action") - val expected = expectedDiffAs.toDS() + val expected = expectedDiffAs + .toDS() .join(actions, 
"diff") .select($"action", $"id", $"left_value".as("before_value"), $"right_value".as("after_value")) .as[DiffAsCustom] @@ -1032,8 +1119,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { } test("diff as U with extra column") { - doTestRequirement(left.diffAs[DiffAsExtra](right, "id"), - "Diff encoder's columns must be part of the diff result schema, these columns are unexpected: extra") + doTestRequirement( + left.diffAs[DiffAsExtra](right, "id"), + "Diff encoder's columns must be part of the diff result schema, these columns are unexpected: extra" + ) } test("diff with change column") { @@ -1041,15 +1130,19 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val actual = left7.diff(right7, options, "id").orderBy("id") assert(actual.columns === Seq("diff", "changes", "id", "left_value", "right_value", "left_label", "right_label")) - assert(actual.schema === StructType(Seq( - StructField("diff", StringType, nullable = false), - StructField("changes", ArrayType(StringType, containsNull = false), nullable = true), - StructField("id", IntegerType, nullable = true), - StructField("left_value", StringType, nullable = true), - StructField("right_value", StringType, nullable = true), - StructField("left_label", StringType, nullable = true), - StructField("right_label", StringType, nullable = true) - ))) + assert( + actual.schema === StructType( + Seq( + StructField("diff", StringType, nullable = false), + StructField("changes", ArrayType(StringType, containsNull = false), nullable = true), + StructField("id", IntegerType, nullable = true), + StructField("left_value", StringType, nullable = true), + StructField("right_value", StringType, nullable = true), + StructField("left_label", StringType, nullable = true), + StructField("right_label", StringType, nullable = true) + ) + ) + ) assert(actual.collect() === expectedDiff7WithChanges) } @@ -1058,15 +1151,21 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val actual = left7.diff(right7, options) assert(actual.columns === Seq("diff", "changes", "id", "value", "label")) - assert(actual.schema === StructType(Seq( - StructField("diff", StringType, nullable = false), - StructField("changes", ArrayType(StringType, containsNull = false), nullable = true), - StructField("id", IntegerType, nullable = true), - StructField("value", StringType, nullable = true), - StructField("label", StringType, nullable = true) - ))) - assert(actual.select($"diff", $"changes").distinct().orderBy($"diff").collect() === - Seq(Row("D", null), Row("I", null), Row("N", Seq.empty[String]))) + assert( + actual.schema === StructType( + Seq( + StructField("diff", StringType, nullable = false), + StructField("changes", ArrayType(StringType, containsNull = false), nullable = true), + StructField("id", IntegerType, nullable = true), + StructField("value", StringType, nullable = true), + StructField("label", StringType, nullable = true) + ) + ) + ) + assert( + actual.select($"diff", $"changes").distinct().orderBy($"diff").collect() === + Seq(Row("D", null), Row("I", null), Row("N", Seq.empty[String])) + ) } test("diff with change column name in non-id columns") { @@ -1079,8 +1178,10 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { test("diff with change column name in id columns") { val options = DiffOptions.default.withChangeColumn("value") - doTestRequirement(left.diff(right, options, "id", "value"), - "The id columns must not contain the change column name 'value': id, value") + doTestRequirement( + left.diff(right, options, "id", 
"value"), + "The id columns must not contain the change column name 'value': id, value" + ) } test("diff with column-by-column diff mode") { @@ -1145,26 +1246,34 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { test("diff with left-side diff mode and diff column name in value columns") { val options = DiffOptions.default.withDiffColumn("value").withDiffMode(DiffMode.LeftSide) - doTestRequirement(left.diff(right, options, "id"), - "The left non-id columns must not contain the diff column name 'value': value") + doTestRequirement( + left.diff(right, options, "id"), + "The left non-id columns must not contain the diff column name 'value': value" + ) } test("diff with right-side diff mode and diff column name in value columns") { val options = DiffOptions.default.withDiffColumn("value").withDiffMode(DiffMode.RightSide) - doTestRequirement(right.diff(right, options, "id"), - "The right non-id columns must not contain the diff column name 'value': value") + doTestRequirement( + right.diff(right, options, "id"), + "The right non-id columns must not contain the diff column name 'value': value" + ) } test("diff with left-side diff mode and change column name in value columns") { val options = DiffOptions.default.withChangeColumn("value").withDiffMode(DiffMode.LeftSide) - doTestRequirement(left.diff(right, options, "id"), - "The left non-id columns must not contain the change column name 'value': value") + doTestRequirement( + left.diff(right, options, "id"), + "The left non-id columns must not contain the change column name 'value': value" + ) } test("diff with right-side diff mode and change column name in value columns") { val options = DiffOptions.default.withChangeColumn("value").withDiffMode(DiffMode.RightSide) - doTestRequirement(right.diff(right, options, "id"), - "The right non-id columns must not contain the change column name 'value': value") + doTestRequirement( + right.diff(right, options, "id"), + "The right non-id columns must not contain the change column name 'value': value" + ) } test("diff with dots in diff column") { @@ -1283,10 +1392,12 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { schema.copy(fields = schema.fields .map(_.copy(nullable = true)) - .map(field => field.dataType match { - case a: ArrayType => field.copy(dataType = a.copy(containsNull = false)) - case _ => field - }) + .map(field => + field.dataType match { + case a: ArrayType => field.copy(dataType = a.copy(containsNull = false)) + case _ => field + } + ) ) } @@ -1297,24 +1408,63 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { } test("diff with ignored columns") { - assertIgnoredColumns(left8.diff(right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema) - assertIgnoredColumns(Diff.of(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema) - assertIgnoredColumns(Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema) - - assertIgnoredColumns[DiffAs8](left8.diffAs(right8, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, Encoders.product[DiffAs8].schema) - assertIgnoredColumns[DiffAs8](Diff.ofAs(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, Encoders.product[DiffAs8].schema) - assertIgnoredColumns[DiffAs8](Diff.default.diffAs(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, Encoders.product[DiffAs8].schema) - - val expected = expectedDiff8.map(row => ( - row.getString(0), - Value8(row.getInt(1), 
Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3)), Option(row.getString(5))), - Value8(row.getInt(1), Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(4)), Option(row.getString(6))) - )).map { case (diff, left, right) => ( - diff, - if (diff == "I") null else left, - if (diff == "D") null else right + assertIgnoredColumns( + left8.diff(right8, Seq("id", "seq"), Seq("meta")), + expectedDiff8, + Encoders.product[DiffAs8].schema ) - } + assertIgnoredColumns( + Diff.of(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedDiff8, + Encoders.product[DiffAs8].schema + ) + assertIgnoredColumns( + Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedDiff8, + Encoders.product[DiffAs8].schema + ) + + assertIgnoredColumns[DiffAs8]( + left8.diffAs(right8, Seq("id", "seq"), Seq("meta")), + expectedDiffAs8, + Encoders.product[DiffAs8].schema + ) + assertIgnoredColumns[DiffAs8]( + Diff.ofAs(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedDiffAs8, + Encoders.product[DiffAs8].schema + ) + assertIgnoredColumns[DiffAs8]( + Diff.default.diffAs(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedDiffAs8, + Encoders.product[DiffAs8].schema + ) + + val expected = expectedDiff8 + .map(row => + ( + row.getString(0), + Value8( + row.getInt(1), + Option(row.get(2)).map(_.asInstanceOf[Int]), + Option(row.getString(3)), + Option(row.getString(5)) + ), + Value8( + row.getInt(1), + Option(row.get(2)).map(_.asInstanceOf[Int]), + Option(row.getString(4)), + Option(row.getString(6)) + ) + ) + ) + .map { case (diff, left, right) => + ( + diff, + if (diff == "I") null else left, + if (diff == "D") null else right + ) + } assertDiffWith(left8.diffWith(right8, Seq("id", "seq"), Seq("meta")).collect(), expected) assertDiffWith(Diff.ofWith(left8, right8, Seq("id", "seq"), Seq("meta")).collect(), expected) @@ -1325,112 +1475,230 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val options = DiffOptions.default.withChangeColumn("changed") val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedDiff8WithChanges, Encoders.product[DiffAs8changes].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8WithChanges, Encoders.product[DiffAs8changes].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedDiff8WithChanges, + Encoders.product[DiffAs8changes].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedDiff8WithChanges, + Encoders.product[DiffAs8changes].schema + ) } test("diff with ignored columns and column-by-column diff mode") { val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn) val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedDiff8, Encoders.product[DiffAs8].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedDiff8, + Encoders.product[DiffAs8].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedDiff8, + Encoders.product[DiffAs8].schema + ) } test("diff with ignored columns and side-by-side diff mode") { val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide) val differ = new 
Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedSideBySideDiff8, Encoders.product[DiffAs8SideBySide].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedSideBySideDiff8, Encoders.product[DiffAs8SideBySide].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedSideBySideDiff8, + Encoders.product[DiffAs8SideBySide].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedSideBySideDiff8, + Encoders.product[DiffAs8SideBySide].schema + ) } test("diff with ignored columns and left-side diff mode") { val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide) val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedLeftSideDiff8, Encoders.product[DiffAs8OneSide].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedLeftSideDiff8, Encoders.product[DiffAs8OneSide].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedLeftSideDiff8, + Encoders.product[DiffAs8OneSide].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedLeftSideDiff8, + Encoders.product[DiffAs8OneSide].schema + ) } test("diff with ignored columns and right-side diff mode") { val options = DiffOptions.default.withDiffMode(DiffMode.RightSide) val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedRightSideDiff8, Encoders.product[DiffAs8OneSide].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedRightSideDiff8, Encoders.product[DiffAs8OneSide].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedRightSideDiff8, + Encoders.product[DiffAs8OneSide].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedRightSideDiff8, + Encoders.product[DiffAs8OneSide].schema + ) } test("diff with ignored columns, column-by-column diff and sparse mode") { val options = DiffOptions.default.withDiffMode(DiffMode.ColumnByColumn).withSparseMode(true) val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedSparseDiff8, Encoders.product[DiffAs8].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedSparseDiff8, Encoders.product[DiffAs8].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedSparseDiff8, + Encoders.product[DiffAs8].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedSparseDiff8, + Encoders.product[DiffAs8].schema + ) } test("diff with ignored columns, side-by-side diff and sparse mode") { val options = DiffOptions.default.withDiffMode(DiffMode.SideBySide).withSparseMode(true) val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedSideBySideSparseDiff8, Encoders.product[DiffAs8SideBySide].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedSideBySideSparseDiff8, Encoders.product[DiffAs8SideBySide].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + 
expectedSideBySideSparseDiff8, + Encoders.product[DiffAs8SideBySide].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedSideBySideSparseDiff8, + Encoders.product[DiffAs8SideBySide].schema + ) } test("diff with ignored columns, left-side diff and sparse mode") { val options = DiffOptions.default.withDiffMode(DiffMode.LeftSide).withSparseMode(true) val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedLeftSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedLeftSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedLeftSideSparseDiff8, + Encoders.product[DiffAs8OneSide].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedLeftSideSparseDiff8, + Encoders.product[DiffAs8OneSide].schema + ) } test("diff with ignored columns, right-side diff and sparse mode") { val options = DiffOptions.default.withDiffMode(DiffMode.RightSide).withSparseMode(true) val differ = new Differ(options) - assertIgnoredColumns(left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), expectedRightSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema) - assertIgnoredColumns(differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), expectedRightSideSparseDiff8, Encoders.product[DiffAs8OneSide].schema) + assertIgnoredColumns( + left8.diff(right8, options, Seq("id", "seq"), Seq("meta")), + expectedRightSideSparseDiff8, + Encoders.product[DiffAs8OneSide].schema + ) + assertIgnoredColumns( + differ.diff(left8, right8, Seq("id", "seq"), Seq("meta")), + expectedRightSideSparseDiff8, + Encoders.product[DiffAs8OneSide].schema + ) } test("diff similar with ignored columns") { - val expectedSchema = StructType(Seq( - StructField("diff", StringType), - StructField("id", IntegerType), - StructField("seq", IntegerType), - StructField("left_value", StringType), - StructField("right_value", StringType), - StructField("left_meta", StringType), - StructField("right_info", StringType), - )) + val expectedSchema = StructType( + Seq( + StructField("diff", StringType), + StructField("id", IntegerType), + StructField("seq", IntegerType), + StructField("left_value", StringType), + StructField("right_value", StringType), + StructField("left_meta", StringType), + StructField("right_info", StringType), + ) + ) assertIgnoredColumns(left8.diff(right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiff8and9, expectedSchema) - assertIgnoredColumns(Diff.of(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiff8and9, expectedSchema) - assertIgnoredColumns(Diff.default.diff(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiff8and9, expectedSchema) - - assertIgnoredColumns[DiffAs8and9](left8.diffAs(right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema) - assertIgnoredColumns[DiffAs8and9](Diff.ofAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema) - assertIgnoredColumns[DiffAs8and9](Diff.default.diffAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema) - - val expectedSchemaWith = StructType(Seq( - StructField("_1", StringType), - StructField("_2", StructType(Seq( - StructField("id", IntegerType, nullable = true), - StructField("seq", 
IntegerType, nullable = true), - StructField("value", StringType, nullable = true), - StructField("meta", StringType, nullable = true) - ))), - StructField("_3", StructType(Seq( - StructField("id", IntegerType, nullable = true), - StructField("seq", IntegerType, nullable = true), - StructField("value", StringType, nullable = true), - StructField("info", StringType, nullable = true) - ))), - )) - - assertDiffWithSchema(left8.diffWith(right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffWith8and9, expectedSchemaWith) - assertDiffWithSchema(Diff.ofWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffWith8and9, expectedSchemaWith) - assertDiffWithSchema(Diff.default.diffWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")), expectedDiffWith8and9, expectedSchemaWith) + assertIgnoredColumns( + Diff.of(left8, right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiff8and9, + expectedSchema + ) + assertIgnoredColumns( + Diff.default.diff(left8, right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiff8and9, + expectedSchema + ) + + assertIgnoredColumns[DiffAs8and9]( + left8.diffAs(right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiffAs8and9, + expectedSchema + ) + assertIgnoredColumns[DiffAs8and9]( + Diff.ofAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiffAs8and9, + expectedSchema + ) + assertIgnoredColumns[DiffAs8and9]( + Diff.default.diffAs(left8, right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiffAs8and9, + expectedSchema + ) + + val expectedSchemaWith = StructType( + Seq( + StructField("_1", StringType), + StructField( + "_2", + StructType( + Seq( + StructField("id", IntegerType, nullable = true), + StructField("seq", IntegerType, nullable = true), + StructField("value", StringType, nullable = true), + StructField("meta", StringType, nullable = true) + ) + ) + ), + StructField( + "_3", + StructType( + Seq( + StructField("id", IntegerType, nullable = true), + StructField("seq", IntegerType, nullable = true), + StructField("value", StringType, nullable = true), + StructField("info", StringType, nullable = true) + ) + ) + ), + ) + ) + + assertDiffWithSchema( + left8.diffWith(right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiffWith8and9, + expectedSchemaWith + ) + assertDiffWithSchema( + Diff.ofWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiffWith8and9, + expectedSchemaWith + ) + assertDiffWithSchema( + Diff.default.diffWith(left8, right9, Seq("id", "seq"), Seq("meta", "info")), + expectedDiffWith8and9, + expectedSchemaWith + ) } test("diff similar with ignored columns of different type") { @@ -1443,23 +1711,45 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val right = right8.toDF("ID", "SEQ", "VALUE", "META") def expectedSchema(id: String, seq: String): StructType = - StructType(Seq( - StructField("diff", StringType), - StructField(id, IntegerType), - StructField(seq, IntegerType), - StructField("left_value", StringType), - StructField("right_VALUE", StringType), - StructField("left_meta", StringType), - StructField("right_META", StringType), - )) + StructType( + Seq( + StructField("diff", StringType), + StructField(id, IntegerType), + StructField(seq, IntegerType), + StructField("left_value", StringType), + StructField("right_VALUE", StringType), + StructField("left_meta", StringType), + StructField("right_META", StringType), + ) + ) assertIgnoredColumns(left.diff(right, Seq("iD", "sEq"), Seq("MeTa")), expectedDiff8, expectedSchema("iD", "sEq")) - 
assertIgnoredColumns(Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA")), expectedDiff8, expectedSchema("Id", "SeQ")) - assertIgnoredColumns(Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META")), expectedDiff8, expectedSchema("ID", "SEQ")) + assertIgnoredColumns( + Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA")), + expectedDiff8, + expectedSchema("Id", "SeQ") + ) + assertIgnoredColumns( + Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META")), + expectedDiff8, + expectedSchema("ID", "SEQ") + ) - assertIgnoredColumns[DiffAs8](left.diffAs(right, Seq("id", "seq"), Seq("MeTa")), expectedDiffAs8, expectedSchema("id", "seq")) - assertIgnoredColumns[DiffAs8](Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA")), expectedDiffAs8, expectedSchema("id", "seq")) - assertIgnoredColumns[DiffAs8](Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta")), expectedDiffAs8, expectedSchema("id", "seq")) + assertIgnoredColumns[DiffAs8]( + left.diffAs(right, Seq("id", "seq"), Seq("MeTa")), + expectedDiffAs8, + expectedSchema("id", "seq") + ) + assertIgnoredColumns[DiffAs8]( + Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA")), + expectedDiffAs8, + expectedSchema("id", "seq") + ) + assertIgnoredColumns[DiffAs8]( + Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta")), + expectedDiffAs8, + expectedSchema("id", "seq") + ) // TODO: add diffWith } @@ -1470,26 +1760,44 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = left8.toDF("id", "seq", "value", "meta") val right = right8.toDF("ID", "SEQ", "VALUE", "META") - doTestRequirement(left.diff(right, Seq("Id", "SeQ"), Seq("MeTa")), - "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)") - doTestRequirement(Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa")), - "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)") - doTestRequirement(Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa")), - "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)") - - doTestRequirement(left8.diff(right8, Seq("Id", "SeQ"), Seq("MeTa")), - "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta") - doTestRequirement(Diff.of(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")), - "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta") - doTestRequirement(Diff.default.diff(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")), - "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta") - - doTestRequirement(left8.diff(right8, Seq("id", "seq"), Seq("MeTa")), - "Some ignore columns do not exist: MeTa missing among id, meta, seq, value") - doTestRequirement(Diff.of(left8, right8, Seq("id", "seq"), Seq("MeTa")), - "Some ignore columns do not exist: MeTa missing among id, meta, seq, value") - doTestRequirement(Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("MeTa")), - "Some ignore columns do not exist: MeTa missing among id, meta, seq, value") + doTestRequirement( + left.diff(right, Seq("Id", "SeQ"), Seq("MeTa")), + "The datasets do not have 
the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)" + ) + doTestRequirement( + Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa")), + "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)" + ) + doTestRequirement( + Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa")), + "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), META (StringType)" + ) + + doTestRequirement( + left8.diff(right8, Seq("Id", "SeQ"), Seq("MeTa")), + "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta" + ) + doTestRequirement( + Diff.of(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")), + "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta" + ) + doTestRequirement( + Diff.default.diff(left8, right8, Seq("Id", "SeQ"), Seq("MeTa")), + "Some id columns do not exist: Id, SeQ missing among id, seq, value, meta" + ) + + doTestRequirement( + left8.diff(right8, Seq("id", "seq"), Seq("MeTa")), + "Some ignore columns do not exist: MeTa missing among id, meta, seq, value" + ) + doTestRequirement( + Diff.of(left8, right8, Seq("id", "seq"), Seq("MeTa")), + "Some ignore columns do not exist: MeTa missing among id, meta, seq, value" + ) + doTestRequirement( + Diff.default.diff(left8, right8, Seq("id", "seq"), Seq("MeTa")), + "Some ignore columns do not exist: MeTa missing among id, meta, seq, value" + ) } } @@ -1499,45 +1807,97 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val right = right9.toDF("ID", "SEQ", "VALUE", "INFO").as[Value9up] def expectedSchema(id: String, seq: String): StructType = - StructType(Seq( - StructField("diff", StringType), - StructField(id, IntegerType), - StructField(seq, IntegerType), - StructField("left_value", StringType), - StructField("right_VALUE", StringType), - StructField("left_meta", StringType), - StructField("right_INFO", StringType), - )) - - assertIgnoredColumns(left.diff(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")), expectedDiff8and9, expectedSchema("iD", "sEq")) - assertIgnoredColumns(Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")), expectedDiff8and9, expectedSchema("Id", "SeQ")) - assertIgnoredColumns(Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")), expectedDiff8and9, expectedSchema("ID", "SEQ")) + StructType( + Seq( + StructField("diff", StringType), + StructField(id, IntegerType), + StructField(seq, IntegerType), + StructField("left_value", StringType), + StructField("right_VALUE", StringType), + StructField("left_meta", StringType), + StructField("right_INFO", StringType), + ) + ) + + assertIgnoredColumns( + left.diff(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")), + expectedDiff8and9, + expectedSchema("iD", "sEq") + ) + assertIgnoredColumns( + Diff.of(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")), + expectedDiff8and9, + expectedSchema("Id", "SeQ") + ) + assertIgnoredColumns( + Diff.default.diff(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")), + expectedDiff8and9, + expectedSchema("ID", "SEQ") + ) // TODO: remove generic type - assertIgnoredColumns[DiffAs8and9](left.diffAs(right, 
Seq("id", "seq"), Seq("MeTa", "InFo")), expectedDiffAs8and9, expectedSchema("id", "seq")) - assertIgnoredColumns[DiffAs8and9](Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA", "iNfO")), expectedDiffAs8and9, expectedSchema("id", "seq")) - assertIgnoredColumns[DiffAs8and9](Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta", "info")), expectedDiffAs8and9, expectedSchema("id", "seq")) + assertIgnoredColumns[DiffAs8and9]( + left.diffAs(right, Seq("id", "seq"), Seq("MeTa", "InFo")), + expectedDiffAs8and9, + expectedSchema("id", "seq") + ) + assertIgnoredColumns[DiffAs8and9]( + Diff.ofAs(left, right, Seq("id", "seq"), Seq("mEtA", "iNfO")), + expectedDiffAs8and9, + expectedSchema("id", "seq") + ) + assertIgnoredColumns[DiffAs8and9]( + Diff.default.diffAs(left, right, Seq("id", "seq"), Seq("meta", "info")), + expectedDiffAs8and9, + expectedSchema("id", "seq") + ) def expectedSchemaWith(id: String, seq: String): StructType = - StructType(Seq( - StructField("_1", StringType, nullable = false), - StructField("_2", StructType(Seq( - StructField(id, IntegerType), - StructField(seq, IntegerType), - StructField("value", StringType), - StructField("meta", StringType) - )), nullable = true), - StructField("_3", StructType(Seq( - StructField(id, IntegerType), - StructField(seq, IntegerType), - StructField("VALUE", StringType), - StructField("INFO", StringType) - )), nullable = true), - )) - - assertIgnoredColumns[(String, Value8, Value9up)](left.diffWith(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")), expectedDiffWith8and9up, expectedSchemaWith("iD", "sEq")) - assertIgnoredColumns[(String, Value8, Value9up)](Diff.ofWith(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")), expectedDiffWith8and9up, expectedSchemaWith("Id", "SeQ")) - assertIgnoredColumns[(String, Value8, Value9up)](Diff.default.diffWith(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")), expectedDiffWith8and9up, expectedSchemaWith("ID", "SEQ")) + StructType( + Seq( + StructField("_1", StringType, nullable = false), + StructField( + "_2", + StructType( + Seq( + StructField(id, IntegerType), + StructField(seq, IntegerType), + StructField("value", StringType), + StructField("meta", StringType) + ) + ), + nullable = true + ), + StructField( + "_3", + StructType( + Seq( + StructField(id, IntegerType), + StructField(seq, IntegerType), + StructField("VALUE", StringType), + StructField("INFO", StringType) + ) + ), + nullable = true + ), + ) + ) + + assertIgnoredColumns[(String, Value8, Value9up)]( + left.diffWith(right, Seq("iD", "sEq"), Seq("MeTa", "InFo")), + expectedDiffWith8and9up, + expectedSchemaWith("iD", "sEq") + ) + assertIgnoredColumns[(String, Value8, Value9up)]( + Diff.ofWith(left, right, Seq("Id", "SeQ"), Seq("mEtA", "iNfO")), + expectedDiffWith8and9up, + expectedSchemaWith("Id", "SeQ") + ) + assertIgnoredColumns[(String, Value8, Value9up)]( + Diff.default.diffWith(left, right, Seq("ID", "SEQ"), Seq("META", "INFO")), + expectedDiffWith8and9up, + expectedSchemaWith("ID", "SEQ") + ) } } @@ -1546,26 +1906,44 @@ class DiffSuite extends AnyFunSuite with SparkTestSession { val left = left8.toDF("id", "seq", "value", "meta").as[Value8] val right = right9.toDF("ID", "SEQ", "VALUE", "INFO").as[Value9up] - doTestRequirement(left.diff(right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), - "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)") - 
doTestRequirement(Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), - "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)") - doTestRequirement(Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), - "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)") - - doTestRequirement(left8.diff(right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), - "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)") - doTestRequirement(Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), - "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)") - doTestRequirement(Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), - "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)") - - doTestRequirement(left8.diff(right9, Seq("Id", "SeQ"), Seq("meta", "info")), - "Some id columns do not exist: Id, SeQ missing among id, seq, value") - doTestRequirement(Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")), - "Some id columns do not exist: Id, SeQ missing among id, seq, value") - doTestRequirement(Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")), - "Some id columns do not exist: Id, SeQ missing among id, seq, value") + doTestRequirement( + left.diff(right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), + "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)" + ) + doTestRequirement( + Diff.of(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), + "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)" + ) + doTestRequirement( + Diff.default.diff(left, right, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), + "The datasets do not have the same schema.\nLeft extra columns: id (IntegerType), seq (IntegerType), value (StringType), meta (StringType)\nRight extra columns: ID (IntegerType), SEQ (IntegerType), VALUE (StringType), INFO (StringType)" + ) + + doTestRequirement( + left8.diff(right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), + "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)" + ) + doTestRequirement( + Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), + "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)" + ) + doTestRequirement( + Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("MeTa", "InFo")), + "The datasets do not have the same schema.\nLeft extra columns: meta (StringType)\nRight extra columns: info (StringType)" + ) + + doTestRequirement( + left8.diff(right9, Seq("Id", "SeQ"), Seq("meta", "info")), + "Some id columns do not exist: Id, SeQ missing among id, seq, value" + 
) + doTestRequirement( + Diff.of(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")), + "Some id columns do not exist: Id, SeQ missing among id, seq, value" + ) + doTestRequirement( + Diff.default.diff(left8, right9, Seq("Id", "SeQ"), Seq("meta", "info")), + "Some id columns do not exist: Id, SeQ missing among id, seq, value" + ) } } diff --git a/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala b/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala index 5c69813a..14664ad9 100644 --- a/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala +++ b/src/test/scala/uk/co/gresearch/spark/diff/examples/Examples.scala @@ -26,10 +26,12 @@ class Examples extends AnyFunSuite with SparkTestSession { test("issue") { import spark.implicits._ - val originalDF = Seq((1,"gaurav","jaipur",550,70000),(2,"sunil","noida",600,80000),(3,"rishi","ahmedabad",510,65000)) - .toDF("id","name","city","credit_score","credit_limit") - val changedDF= Seq((1,"gaurav","jaipur",550,70000),(2,"sunil","noida",650,90000),(4,"Joshua","cochin",612,85000)) - .toDF("id","name","city","credit_score","credit_limit") + val originalDF = + Seq((1, "gaurav", "jaipur", 550, 70000), (2, "sunil", "noida", 600, 80000), (3, "rishi", "ahmedabad", 510, 65000)) + .toDF("id", "name", "city", "credit_score", "credit_limit") + val changedDF = + Seq((1, "gaurav", "jaipur", 550, 70000), (2, "sunil", "noida", 650, 90000), (4, "Joshua", "cochin", 612, 85000)) + .toDF("id", "name", "city", "credit_score", "credit_limit") val options = DiffOptions.default.withChangeColumn("changes") val diff = originalDF.diff(changedDF, options, "id") diff.show(false) @@ -56,7 +58,7 @@ class Examples extends AnyFunSuite with SparkTestSession { { Seq(DiffMode.ColumnByColumn, DiffMode.SideBySide, DiffMode.LeftSide, DiffMode.RightSide).foreach { mode => - Seq(false, true).foreach{ sparse => + Seq(false, true).foreach { sparse => val options = DiffOptions.default.withDiffMode(mode) left.diff(right, options, "id").orderBy("id").show(false) } diff --git a/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala b/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala index b2fb22c8..4e04c36c 100644 --- a/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/group/GroupSuite.scala @@ -79,16 +79,22 @@ class GroupSuite extends AnyFunSpec { def testUnconsumedKey(unconsumedKey: K, func: () => Iterator[(K, V)]): Unit = { // we expect all tuples (k, Some(v)), except for k == unconsumedKey, where we expect (k, None) // here we consume all groups (it.toList), which is tested elsewhere to work - val expected = new GroupedIterator(func()).map { - case (key, it) if key == unconsumedKey => it.toList; (key, Iterator(None)) - case (key, it) => (key, it.map(Some(_))) - }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList + val expected = new GroupedIterator(func()) + .map { + case (key, it) if key == unconsumedKey => it.toList; (key, Iterator(None)) + case (key, it) => (key, it.map(Some(_))) + } + .flatMap { case (k, it) => it.map(v => (k, v)) } + .toList // here we do not consume the group with key `unconsumedKey` - val actual = new GroupedIterator(func()).map { - case (key, _) if key == unconsumedKey => (key, Iterator(None)) - case (key, it) => (key, it.map(Some(_))) - }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList + val actual = new GroupedIterator(func()) + .map { + case (key, _) if key == unconsumedKey => (key, Iterator(None)) + case (key, it) => (key, it.map(Some(_))) + } + 
.flatMap { case (k, it) => it.map(v => (k, v)) } + .toList assert(actual === expected) } @@ -135,16 +141,22 @@ class GroupSuite extends AnyFunSpec { // we expect all tuples (k, v), except for k == unconsumedKey, // where we expect only the first tuple with k == partiallyConsumedKey // here we consume all groups (it.toList), which is tested elsewhere to work - val expected = new GroupedIterator(func()).map { - case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.toList.head)) - case (key, it) => (key, it) - }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList + val expected = new GroupedIterator(func()) + .map { + case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.toList.head)) + case (key, it) => (key, it) + } + .flatMap { case (k, it) => it.map(v => (k, v)) } + .toList // here we only consume the first element of the group with key `unconsumedKey` - val actual = new GroupedIterator(func()).map { - case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.next())) - case (key, it) => (key, it) - }.flatMap { case (k, it) => it.map(v => (k, v)) }.toList + val actual = new GroupedIterator(func()) + .map { + case (key, it) if key == partiallyConsumedKey => (key, Iterator(it.next())) + case (key, it) => (key, it) + } + .flatMap { case (k, it) => it.map(v => (k, v)) } + .toList assert(actual === expected) } @@ -178,8 +190,8 @@ class GroupSuite extends AnyFunSpec { // this consumes all group iterators it("and fully consumed groups") { val expected = func().toList - val actual = new GroupedIterator(func()).flatMap { - case (k, it) => it.map(v => (k, v)) + val actual = new GroupedIterator(func()).flatMap { case (k, it) => + it.map(v => (k, v)) }.toList assert(actual === expected) } @@ -211,7 +223,7 @@ class GroupSuite extends AnyFunSpec { } describe("should iterate only over current key") { - def test[K : Ordering, V](it: Seq[(K, V)], expectedValues: Seq[V]): Unit = { + def test[K: Ordering, V](it: Seq[(K, V)], expectedValues: Seq[V]): Unit = { val git = new GroupIterator[K, V](it.iterator.buffered) assert(git.toList === expectedValues) } diff --git a/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala b/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala index b8b6cbb4..67ce3a16 100644 --- a/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala +++ b/src/test/scala/uk/co/gresearch/spark/parquet/ParquetSuite.scala @@ -39,12 +39,14 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { val parallelisms = Seq(None, Some(1), Some(2), Some(8)) - def assertDf(actual: DataFrame, - order: Seq[Column], - expectedSchema: StructType, - expectedRows: Seq[Row], - expectedParallelism: Option[Int], - postProcess: DataFrame => DataFrame = identity): Unit = { + def assertDf( + actual: DataFrame, + order: Seq[Column], + expectedSchema: StructType, + expectedRows: Seq[Row], + expectedParallelism: Option[Int], + postProcess: DataFrame => DataFrame = identity + ): Unit = { assert(actual.schema === expectedSchema) if (expectedParallelism.isDefined) { @@ -56,7 +58,10 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { val replaced = actual .orderBy(order: _*) - .withColumn("filename", regexp_replace(regexp_replace($"filename", ".*/test.parquet/", ""), ".*/nested.parquet", "nested.parquet")) + .withColumn( + "filename", + regexp_replace(regexp_replace($"filename", ".*/test.parquet/", ""), ".*/nested.parquet", "nested.parquet") + ) .when(actual.columns.contains("schema")) 
.call(_.withColumn("schema", regexp_replace($"schema", "\n", "\\\\n"))) .call(postProcess) @@ -81,20 +86,22 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { .either(_.parquetMetadata(parallelism.get, testFile)) .or(_.parquetMetadata(testFile)), Seq($"filename"), - StructType(Seq( - StructField("filename", StringType, nullable = true), - StructField("blocks", IntegerType, nullable = false), - StructField("compressedBytes", LongType, nullable = false), - StructField("uncompressedBytes", LongType, nullable = false), - StructField("rows", LongType, nullable = false), - StructField("columns", IntegerType, nullable = false), - StructField("values", LongType, nullable = false), - StructField("nulls", LongType, nullable = true), - StructField("createdBy", StringType, nullable = true), - StructField("schema", StringType, nullable = true), - StructField("encryption", StringType, nullable = true), - StructField("keyValues", MapType(StringType, StringType, valueContainsNull = true), nullable = true), - )), + StructType( + Seq( + StructField("filename", StringType, nullable = true), + StructField("blocks", IntegerType, nullable = false), + StructField("compressedBytes", LongType, nullable = false), + StructField("uncompressedBytes", LongType, nullable = false), + StructField("rows", LongType, nullable = false), + StructField("columns", IntegerType, nullable = false), + StructField("values", LongType, nullable = false), + StructField("nulls", LongType, nullable = true), + StructField("createdBy", StringType, nullable = true), + StructField("schema", StringType, nullable = true), + StructField("encryption", StringType, nullable = true), + StructField("keyValues", MapType(StringType, StringType, valueContainsNull = true), nullable = true), + ) + ), Seq( Row("file1.parquet", 1, 1268, 1652, 100, 2, 200, 0, createdBy, schema, UNENCRYPTED, keyValues), Row("file2.parquet", 2, 2539, 3302, 200, 2, 400, 0, createdBy, schema, UNENCRYPTED, keyValues), @@ -116,21 +123,23 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { .either(_.parquetSchema(parallelism.get, nestedFile)) .or(_.parquetSchema(nestedFile)), Seq($"filename", $"columnPath"), - StructType(Seq( - StructField("filename", StringType, nullable = true), - StructField("columnName", StringType, nullable = true), - StructField("columnPath", ArrayType(StringType, containsNull = true), nullable = true), - StructField("repetition", StringType, nullable = true), - StructField("type", StringType, nullable = true), - StructField("length", IntegerType, nullable = true), - StructField("originalType", StringType, nullable = true), - StructField("logicalType", StringType, nullable = true), - StructField("isPrimitive", BooleanType, nullable = false), - StructField("primitiveType", StringType, nullable = true), - StructField("primitiveOrder", StringType, nullable = true), - StructField("maxDefinitionLevel", IntegerType, nullable = false), - StructField("maxRepetitionLevel", IntegerType, nullable = false), - )), + StructType( + Seq( + StructField("filename", StringType, nullable = true), + StructField("columnName", StringType, nullable = true), + StructField("columnPath", ArrayType(StringType, containsNull = true), nullable = true), + StructField("repetition", StringType, nullable = true), + StructField("type", StringType, nullable = true), + StructField("length", IntegerType, nullable = true), + StructField("originalType", StringType, nullable = true), + StructField("logicalType", StringType, nullable = 
true), + StructField("isPrimitive", BooleanType, nullable = false), + StructField("primitiveType", StringType, nullable = true), + StructField("primitiveOrder", StringType, nullable = true), + StructField("maxDefinitionLevel", IntegerType, nullable = false), + StructField("maxRepetitionLevel", IntegerType, nullable = false), + ) + ), // format: off Seq( Row("nested.parquet", "a", Seq("a"), "REQUIRED", "INT64", 0, null, null, true, "INT64", "TYPE_DEFINED_ORDER", 0, 0), @@ -153,17 +162,19 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { .either(_.parquetBlocks(parallelism.get, testFile)) .or(_.parquetBlocks(testFile)), Seq($"filename", $"block"), - StructType(Seq( - StructField("filename", StringType, nullable = true), - StructField("block", IntegerType, nullable = false), - StructField("blockStart", LongType, nullable = false), - StructField("compressedBytes", LongType, nullable = false), - StructField("uncompressedBytes", LongType, nullable = false), - StructField("rows", LongType, nullable = false), - StructField("columns", IntegerType, nullable = false), - StructField("values", LongType, nullable = false), - StructField("nulls", LongType, nullable = true), - )), + StructType( + Seq( + StructField("filename", StringType, nullable = true), + StructField("block", IntegerType, nullable = false), + StructField("blockStart", LongType, nullable = false), + StructField("compressedBytes", LongType, nullable = false), + StructField("uncompressedBytes", LongType, nullable = false), + StructField("rows", LongType, nullable = false), + StructField("columns", IntegerType, nullable = false), + StructField("values", LongType, nullable = false), + StructField("nulls", LongType, nullable = true), + ) + ), Seq( Row("file1.parquet", 1, 4, 1268, 1652, 100, 2, 200, 0), Row("file2.parquet", 1, 4, 1269, 1651, 100, 2, 200, 0), @@ -182,21 +193,23 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { .either(_.parquetBlockColumns(parallelism.get, testFile)) .or(_.parquetBlockColumns(testFile)), Seq($"filename", $"block", $"column"), - StructType(Seq( - StructField("filename", StringType, nullable = true), - StructField("block", IntegerType, nullable = false), - StructField("column", ArrayType(StringType), nullable = true), - StructField("codec", StringType, nullable = true), - StructField("type", StringType, nullable = true), - StructField("encodings", ArrayType(StringType), nullable = true), - StructField("minValue", StringType, nullable = true), - StructField("maxValue", StringType, nullable = true), - StructField("columnStart", LongType, nullable = false), - StructField("compressedBytes", LongType, nullable = false), - StructField("uncompressedBytes", LongType, nullable = false), - StructField("values", LongType, nullable = false), - StructField("nulls", LongType, nullable = true), - )), + StructType( + Seq( + StructField("filename", StringType, nullable = true), + StructField("block", IntegerType, nullable = false), + StructField("column", ArrayType(StringType), nullable = true), + StructField("codec", StringType, nullable = true), + StructField("type", StringType, nullable = true), + StructField("encodings", ArrayType(StringType), nullable = true), + StructField("minValue", StringType, nullable = true), + StructField("maxValue", StringType, nullable = true), + StructField("columnStart", LongType, nullable = false), + StructField("compressedBytes", LongType, nullable = false), + StructField("uncompressedBytes", LongType, nullable = false), + 
StructField("values", LongType, nullable = false), + StructField("nulls", LongType, nullable = true), + ) + ), // format: off Seq( Row("file1.parquet", 1, "[id]", "SNAPPY", "required int64 id", "[BIT_PACKED, PLAIN]", "0", "99", 4, 437, 826, 100, 0), @@ -208,9 +221,10 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { ), // format: on parallelism, - (df: DataFrame) => df - .withColumn("column", $"column".cast(StringType)) - .withColumn("encodings", $"encodings".cast(StringType)) + (df: DataFrame) => + df + .withColumn("column", $"column".cast(StringType)) + .withColumn("encodings", $"encodings".cast(StringType)) ) } } @@ -230,7 +244,14 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { .join(rows, Seq("partition"), "left") .select($"partition", $"start", $"end", $"length", $"rows", $"actual_rows", $"filename") - if (partitions.where($"rows" =!= $"actual_rows" || ($"rows" =!= 0 || $"actual_rows" =!= 0) && $"length" =!= partitionSize).head(1).nonEmpty) { + if ( + partitions + .where( + $"rows" =!= $"actual_rows" || ($"rows" =!= 0 || $"actual_rows" =!= 0) && $"length" =!= partitionSize + ) + .head(1) + .nonEmpty + ) { partitions .orderBy($"start") .where($"rows" =!= 0 || $"actual_rows" =!= 0) @@ -276,8 +297,11 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { ), ).foreach { case (partitionSize, expectedRows) => parallelisms.foreach { parallelism => - test(s"read parquet partitions (${partitionSize.getOrElse("default")} bytes) (parallelism=${parallelism.map(_.toString).getOrElse("None")})") { - withSQLConf(partitionSize.map(size => Seq("spark.sql.files.maxPartitionBytes" -> size.toString)).getOrElse(Seq.empty): _*) { + test(s"read parquet partitions (${partitionSize + .getOrElse("default")} bytes) (parallelism=${parallelism.map(_.toString).getOrElse("None")})") { + withSQLConf( + partitionSize.map(size => Seq("spark.sql.files.maxPartitionBytes" -> size.toString)).getOrElse(Seq.empty): _* + ) { val expected = expectedRows.map { case row if SparkMajorVersion > 3 || SparkMinorVersion >= 3 => row case row => Row(unapplySeq(row).get.updated(11, null): _*) @@ -296,21 +320,23 @@ class ParquetSuite extends AnyFunSuite with SparkTestSession with SparkVersion { assert(Seq(0, 0) === partitions) } - val schema = StructType(Seq( - StructField("partition", IntegerType, nullable = false), - StructField("start", LongType, nullable = false), - StructField("end", LongType, nullable = false), - StructField("length", LongType, nullable = false), - StructField("blocks", IntegerType, nullable = false), - StructField("compressedBytes", LongType, nullable = false), - StructField("uncompressedBytes", LongType, nullable = false), - StructField("rows", LongType, nullable = false), - StructField("columns", IntegerType, nullable = false), - StructField("values", LongType, nullable = false), - StructField("nulls", LongType, nullable = true), - StructField("filename", StringType, nullable = true), - StructField("fileLength", LongType, nullable = true), - )) + val schema = StructType( + Seq( + StructField("partition", IntegerType, nullable = false), + StructField("start", LongType, nullable = false), + StructField("end", LongType, nullable = false), + StructField("length", LongType, nullable = false), + StructField("blocks", IntegerType, nullable = false), + StructField("compressedBytes", LongType, nullable = false), + StructField("uncompressedBytes", LongType, nullable = false), + StructField("rows", LongType, nullable = 
false), + StructField("columns", IntegerType, nullable = false), + StructField("values", LongType, nullable = false), + StructField("nulls", LongType, nullable = true), + StructField("filename", StringType, nullable = true), + StructField("fileLength", LongType, nullable = true), + ) + ) assertDf(actual, Seq($"filename", $"start"), schema, expected, parallelism, df => df.drop("partition")) actual.unpersist()