
Commit

Apply scalafmt changes
EnricoMi committed Nov 30, 2023
1 parent ed8fe18 commit 3598b31
Showing 32 changed files with 3,091 additions and 1,963 deletions.
10 changes: 10 additions & 0 deletions pom.xml
@@ -244,6 +244,16 @@
</scalafmt>
</scala>
</configuration>
<executions>
<execution>
<!-- Runs in compile phase to fail fast in case of formatting issues.-->
<id>spotless-check</id>
<phase>compile</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- run java tests -->
<plugin>
33 changes: 19 additions & 14 deletions src/main/scala/uk/co/gresearch/package.scala
@@ -20,28 +20,28 @@ package object gresearch {

trait ConditionalCall[T] {
def call(f: T => T): T
def either[R](f: T => R): ConditionalCallOr[T,R]
def either[R](f: T => R): ConditionalCallOr[T, R]
}

trait ConditionalCallOr[T,R] {
trait ConditionalCallOr[T, R] {
def or(f: T => R): R
}

case class TrueCall[T](t: T) extends ConditionalCall[T] {
override def call(f: T => T): T = f(t)
override def either[R](f: T => R): ConditionalCallOr[T,R] = TrueCallOr[T,R](f(t))
override def either[R](f: T => R): ConditionalCallOr[T, R] = TrueCallOr[T, R](f(t))
}

case class FalseCall[T](t: T) extends ConditionalCall[T] {
override def call(f: T => T): T = t
override def either[R](f: T => R): ConditionalCallOr[T,R] = FalseCallOr[T,R](t)
override def either[R](f: T => R): ConditionalCallOr[T, R] = FalseCallOr[T, R](t)
}

case class TrueCallOr[T,R](r: R) extends ConditionalCallOr[T,R] {
case class TrueCallOr[T, R](r: R) extends ConditionalCallOr[T, R] {
override def or(f: T => R): R = r
}

case class FalseCallOr[T,R](t: T) extends ConditionalCallOr[T,R] {
case class FalseCallOr[T, R](t: T) extends ConditionalCallOr[T, R] {
override def or(f: T => R): R = f(t)
}

@@ -71,16 +71,17 @@ package object gresearch {
*
* which either needs many temporary variables or duplicate code.
*
* @param condition condition
* @return the function result
* @param condition
* condition
* @return
* the function result
*/
def on(condition: Boolean): ConditionalCall[T] = {
if (condition) TrueCall[T](t) else FalseCall[T](t)
}

/**
* Allows to call a function on the decorated instance conditionally.
* This is an alias for the `on` function.
* Allows to call a function on the decorated instance conditionally. This is an alias for the `on` function.
*
* This allows fluent code like
*
@@ -103,8 +104,10 @@
*
* which either needs many temporary variables or duplicate code.
*
* @param condition condition
* @return the function result
* @param condition
* condition
* @return
* the function result
*/
def when(condition: Boolean): ConditionalCall[T] = on(condition)

@@ -131,8 +134,10 @@
*
* where the effective sequence of operations is not clear.
*
* @param f function
* @return the function result
* @param f
* function
* @return
* the function result
*/
def call[R](f: T => R): R = f(t)
}
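
A usage sketch of the fluent conditional API documented above. This is illustrative only: it assumes the implicit class ExtendedAny defined alongside this package object exposes on/when/call and either/or on arbitrary values (the import of uk.co.gresearch.ExtendedAny in Histogram.scala below relies on exactly that), and the DataFrame and column names are hypothetical.

    import org.apache.spark.sql.DataFrame
    import uk.co.gresearch.ExtendedAny

    // Apply the filter only when dropNulls is true, otherwise pass df through unchanged.
    def cleaned(df: DataFrame, dropNulls: Boolean): DataFrame =
      df.when(dropNulls).call(_.filter("value IS NOT NULL"))

    // either/or picks one of two functions of the same input, depending on the condition.
    def describe(df: DataFrame, verbose: Boolean): String =
      df.on(verbose).either(_.schema.treeString).or(_.columns.mkString(", "))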
9 changes: 6 additions & 3 deletions src/main/scala/uk/co/gresearch/spark/Backticks.scala
@@ -34,12 +34,15 @@ object Backticks {
* col(Backticks.column_name("a.column", "a.field")) // produces "`a.column`.`a.field`"
* }}}
*
* @param string a string
* @param strings more strings
* @param string
* a string
* @param strings
* more strings
* @return
*/
@scala.annotation.varargs
def column_name(string: String, strings: String*): String = (string +: strings)
.map(s => if (s.contains(".") && !s.startsWith("`") && !s.endsWith("`")) s"`$s`" else s).mkString(".")
.map(s => if (s.contains(".") && !s.startsWith("`") && !s.endsWith("`")) s"`$s`" else s)
.mkString(".")

}
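
A brief usage sketch of column_name, based on the signature and scaladoc shown above; only the first call is taken from the scaladoc, the other lines are hypothetical.

    import org.apache.spark.sql.functions.col
    import uk.co.gresearch.spark.Backticks

    Backticks.column_name("a.column", "a.field")   // "`a.column`.`a.field`"
    Backticks.column_name("a", "b")                 // "a.b" (no dots inside a part, so no quoting)
    col(Backticks.column_name("a.column"))          // a safe Column reference for a dotted name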
2 changes: 1 addition & 1 deletion src/main/scala/uk/co/gresearch/spark/BuildVersion.scala
@@ -37,7 +37,7 @@ trait BuildVersion {
}

lazy val VersionString: String = props.getProperty("project.version")

lazy val BuildSparkMajorVersion: Int = props.getProperty("spark.major.version").toInt
lazy val BuildSparkMinorVersion: Int = props.getProperty("spark.minor.version").toInt
lazy val BuildSparkPatchVersion: Int = props.getProperty("spark.patch.version").split("-").head.toInt
71 changes: 42 additions & 29 deletions src/main/scala/uk/co/gresearch/spark/Histogram.scala
@@ -23,20 +23,25 @@ import uk.co.gresearch.ExtendedAny
import scala.collection.JavaConverters

object Histogram {

/**
* Compute the histogram of a column when aggregated by aggregate columns.
* Thresholds are expected to be provided in ascending order.
* The result dataframe contains the aggregate and histogram columns only.
* For each threshold value in thresholds, there will be a column named s"≤threshold".
* There will also be a final column called s">last_threshold", that counts the remaining
* values that exceed the last threshold.
* Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
* ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
* in thresholds, there will be a column named s"≤threshold". There will also be a final column called
* s">last_threshold", that counts the remaining values that exceed the last threshold.
*
* @param df dataset to compute histogram from
* @param thresholds sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
* @param valueColumn histogram is computed for values of this column
* @param aggregateColumns histogram is computed against these columns
* @tparam T type of histogram thresholds
* @return dataframe with aggregate and histogram columns
* @param df
* dataset to compute histogram from
* @param thresholds
* sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
* @param valueColumn
* histogram is computed for values of this column
* @param aggregateColumns
* histogram is computed against these columns
* @tparam T
* type of histogram thresholds
* @return
* dataframe with aggregate and histogram columns
*/
def of[D, T](df: Dataset[D], thresholds: Seq[T], valueColumn: Column, aggregateColumns: Column*): DataFrame = {
if (thresholds.isEmpty)
@@ -49,10 +54,9 @@

df.toDF()
.withColumn(s"${thresholds.head}", when(valueColumn <= thresholds.head, 1).otherwise(0))
.call(
bins.foldLeft(_) { case (df, bin) =>
df.withColumn(s"${bin.last}", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0))
})
.call(bins.foldLeft(_) { case (df, bin) =>
df.withColumn(s"${bin.last}", when(valueColumn > bin.head && valueColumn <= bin.last, 1).otherwise(0))
})
.withColumn(s">${thresholds.last}", when(valueColumn > thresholds.last, 1).otherwise(0))
.groupBy(aggregateColumns: _*)
.agg(
@@ -63,22 +67,31 @@
}
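
A hedged usage sketch of the Seq-based overload above; the SparkSession, data and column names are hypothetical, and the result column names follow the scaladoc ("≤2", "≤10", ">10").

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col
    import uk.co.gresearch.spark.Histogram

    val spark = SparkSession.builder.appName("histogram-example").master("local[*]").getOrCreate()
    import spark.implicits._

    // hypothetical (user, score) data
    val df = Seq(("a", 1), ("a", 5), ("a", 12), ("b", 7)).toDF("user", "score")

    // one row per user with bucket counts: user, ≤2, ≤10, >10
    val hist = Histogram.of(df, Seq(2, 10), col("score"), col("user"))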

/**
* Compute the histogram of a column when aggregated by aggregate columns.
* Thresholds are expected to be provided in ascending order.
* The result dataframe contains the aggregate and histogram columns only.
* For each threshold value in thresholds, there will be a column named s"≤threshold".
* There will also be a final column called s">last_threshold", that counts the remaining
* values that exceed the last threshold.
* Compute the histogram of a column when aggregated by aggregate columns. Thresholds are expected to be provided in
* ascending order. The result dataframe contains the aggregate and histogram columns only. For each threshold value
* in thresholds, there will be a column named s"≤threshold". There will also be a final column called
* s">last_threshold", that counts the remaining values that exceed the last threshold.
*
* @param df dataset to compute histogram from
* @param thresholds sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
* @param valueColumn histogram is computed for values of this column
* @param aggregateColumns histogram is computed against these columns
* @tparam T type of histogram thresholds
* @return dataframe with aggregate and histogram columns
* @param df
* dataset to compute histogram from
* @param thresholds
* sequence of thresholds in ascending order, must implement <= and > operators w.r.t. valueColumn
* @param valueColumn
* histogram is computed for values of this column
* @param aggregateColumns
* histogram is computed against these columns
* @tparam T
* type of histogram thresholds
* @return
* dataframe with aggregate and histogram columns
*/
@scala.annotation.varargs
def of[D, T](df: Dataset[D], thresholds: java.util.List[T], valueColumn: Column, aggregateColumns: Column*): DataFrame =
def of[D, T](
df: Dataset[D],
thresholds: java.util.List[T],
valueColumn: Column,
aggregateColumns: Column*
): DataFrame =
of(df, JavaConverters.iterableAsScalaIterable(thresholds).toSeq, valueColumn, aggregateColumns: _*)

}
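
The second overload above is the Java-friendly variant taking a java.util.List of thresholds. A minimal sketch, reusing the hypothetical df from the Seq-based sketch above:

    import java.util.Arrays
    import org.apache.spark.sql.functions.col
    import uk.co.gresearch.spark.Histogram

    // equivalent to Histogram.of(df, Seq(2, 10), col("score"), col("user"))
    val histFromJavaList = Histogram.of(df, Arrays.asList(2, 10), col("score"), col("user"))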
31 changes: 22 additions & 9 deletions src/main/scala/uk/co/gresearch/spark/RowNumbers.scala
@@ -21,10 +21,12 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, functions}
import org.apache.spark.sql.functions.{coalesce, col, lit, max, monotonically_increasing_id, spark_partition_id, sum}
import org.apache.spark.storage.StorageLevel

case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
unpersistHandle: UnpersistHandle = UnpersistHandle.Noop,
orderColumns: Seq[Column] = Seq.empty) {
case class RowNumbersFunc(
rowNumberColumnName: String = "row_number",
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
unpersistHandle: UnpersistHandle = UnpersistHandle.Noop,
orderColumns: Seq[Column] = Seq.empty
) {

def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc =
this.copy(rowNumberColumnName = rowNumberColumnName)
@@ -39,7 +41,11 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
this.copy(orderColumns = orderColumns)

def of[D](df: Dataset[D]): DataFrame = {
if (storageLevel.equals(StorageLevel.NONE) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)) {
if (
storageLevel.equals(
StorageLevel.NONE
) && (SparkMajorVersion > 3 || SparkMajorVersion == 3 && SparkMinorVersion >= 5)
) {
throw new IllegalArgumentException(s"Storage level $storageLevel not supported with Spark 3.5.0 and above.")
}

@@ -53,7 +59,9 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
val partitionOffsetColumnName = prefix + "partition_offset"

// if no order is given, we preserve existing order
val dfOrdered = if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id()) else df.orderBy(orderColumns: _*)
val dfOrdered =
if (orderColumns.isEmpty) df.withColumn(monoIdColumnName, monotonically_increasing_id())
else df.orderBy(orderColumns: _*)
val order = if (orderColumns.isEmpty) Seq(col(monoIdColumnName)) else orderColumns

// add partition ids and local row numbers
@@ -66,17 +74,22 @@ case class RowNumbersFunc(rowNumberColumnName: String = "row_number",
.withColumn(localRowNumberColumnName, functions.row_number().over(localRowNumberWindow))

// compute row offset for the partitions
val cumRowNumbersWindow = Window.orderBy(partitionIdColumnName)
val cumRowNumbersWindow = Window
.orderBy(partitionIdColumnName)
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
val partitionOffsets = dfWithLocalRowNumbers
.groupBy(partitionIdColumnName)
.agg(max(localRowNumberColumnName).alias(maxLocalRowNumberColumnName))
.withColumn(cumRowNumbersColumnName, sum(maxLocalRowNumberColumnName).over(cumRowNumbersWindow))
.select(col(partitionIdColumnName) + 1 as partitionIdColumnName, col(cumRowNumbersColumnName).as(partitionOffsetColumnName))
.select(
col(partitionIdColumnName) + 1 as partitionIdColumnName,
col(cumRowNumbersColumnName).as(partitionOffsetColumnName)
)

// compute global row number by adding local row number with partition offset
val partitionOffsetColumn = coalesce(col(partitionOffsetColumnName), lit(0))
dfWithLocalRowNumbers.join(partitionOffsets, Seq(partitionIdColumnName), "left")
dfWithLocalRowNumbers
.join(partitionOffsets, Seq(partitionIdColumnName), "left")
.withColumn(rowNumberColumnName, col(localRowNumberColumnName) + partitionOffsetColumn)
.drop(monoIdColumnName, partitionIdColumnName, localRowNumberColumnName, partitionOffsetColumnName)
}
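
A hedged usage sketch of the row-numbering logic above, constructing the RowNumbersFunc case class directly with the parameters shown in this diff (the library presumably also offers a more convenient entry point, but it is not visible here); df is the hypothetical (user, score) DataFrame from the Histogram sketch above.

    import org.apache.spark.sql.functions.col
    import uk.co.gresearch.spark.RowNumbersFunc

    // Contiguous global row numbers without moving all rows into a single partition:
    // per-partition row numbers are shifted by cumulative per-partition offsets (see above).
    val numbered = RowNumbersFunc(orderColumns = Seq(col("score")))
      .withRowNumberColumnName("row_num")
      .of(df)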
2 changes: 1 addition & 1 deletion src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala
@@ -51,7 +51,7 @@ case class SilentUnpersistHandle() extends UnpersistHandle {
}
}

case class NoopUnpersistHandle() extends UnpersistHandle{
case class NoopUnpersistHandle() extends UnpersistHandle {
override def setDataFrame(dataframe: DataFrame): DataFrame = dataframe
override def apply(): Unit = {}
override def apply(blocking: Boolean): Unit = {}