Skip to content

Commit

Permalink
Move spark imports into trait
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Nov 3, 2024
1 parent b57627b commit e18259d
Show file tree
Hide file tree
Showing 14 changed files with 191 additions and 120 deletions.
22 changes: 18 additions & 4 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>uk.co.gresearch.spark</groupId>
<artifactId>spark-extension_2.13</artifactId>
<version>2.13.0-3.5-SNAPSHOT</version>
<version>2.13.0-4.0-SNAPSHOT</version>
<name>Spark Extension</name>
<description>A library that provides useful extensions to Apache Spark.</description>
<inceptionYear>2020</inceptionYear>
Expand Down Expand Up @@ -44,9 +44,9 @@
<scala.compat.version>${scala.major.version}.${scala.minor.version}</scala.compat.version>
<scala.version>${scala.compat.version}.${scala.patch.version}</scala.version>
<!-- keep in-sync with python/requirements-3.3_2.13.txt -->
<spark.major.version>3</spark.major.version>
<spark.minor.version>5</spark.minor.version>
<spark.patch.version>1</spark.patch.version>
<spark.major.version>4</spark.major.version>
<spark.minor.version>0</spark.minor.version>
<spark.patch.version>0-SNAPSHOT</spark.patch.version>
<spark.compat.version>${spark.major.version}.${spark.minor.version}</spark.compat.version>
<spark.version>${spark.compat.version}.${spark.patch.version}</spark.version>
</properties>
Expand Down Expand Up @@ -181,16 +181,30 @@
<version>3.5.0</version>
<executions>
<execution>
<id>main</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>src/main/scala-spark-${spark.major.version}</source>
<source>src/main/scala-spark-${spark.compat.version}</source>
</sources>
</configuration>
</execution>
<execution>
<id>test</id>
<phase>generate-test-sources</phase>
<goals>
<goal>add-test-source</goal>
</goals>
<configuration>
<sources>
<source>src/test/scala-spark-${spark.major.version}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pathlib import Path
from setuptools import setup

jar_version = '2.13.0-3.5-SNAPSHOT'
jar_version = '2.13.0-4.0-SNAPSHOT'
scala_version = '2.13.8'
scala_compat_version = '.'.join(scala_version.split('.')[:2])
spark_compat_version = jar_version.split('-')[1]
Expand Down
34 changes: 34 additions & 0 deletions src/main/scala-spark-3/uk/co/gresearch/spark/SparkImports.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2024 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package uk.co.gresearch.spark

import org.apache.spark.sql
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.extension.ColumnExtension

trait SparkImports {
type Dataset[T] = sql.Dataset[T]
type DataFrame = sql.Dataset[Row]
type KeyValueGroupedDataset[K, V] = sql.KeyValueGroupedDataset[K, V]

def sql(col: Column): String = col.sql
def plan(ds: Dataset[_]): LogicalPlan = ds.queryExecution.logical
def output(ds: Dataset[_]): Seq[Attribute] = ds.queryExecution.analyzed.output
}
33 changes: 33 additions & 0 deletions src/main/scala-spark-4/uk/co/gresearch/spark/SparkImports.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2024 G-Research
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package uk.co.gresearch.spark

import org.apache.spark.sql.{Column, Row, api}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.extension.ColumnExtension

trait SparkImports {
type Dataset[T] = api.Dataset[T]
type DataFrame = api.Dataset[Row]
type KeyValueGroupedDataset[K, V] = api.KeyValueGroupedDataset[K, V]

def sql(col: Column): String = col.sql
def plan(ds: Dataset[_]): LogicalPlan = ds.queryExecution.logical
def output(ds: Dataset[_]): Seq[Attribute] = ds.queryExecution.analyzed.output
}
2 changes: 1 addition & 1 deletion src/main/scala/uk/co/gresearch/spark/Histogram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package uk.co.gresearch.spark

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{sum, when}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import uk.co.gresearch.ExtendedAny

import scala.collection.JavaConverters
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/uk/co/gresearch/spark/RowNumbers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package uk.co.gresearch.spark

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{Column, DataFrame, Dataset, functions}
import org.apache.spark.sql.{Column, functions}
import org.apache.spark.sql.functions.{coalesce, col, lit, max, monotonically_increasing_id, spark_partition_id, sum}
import org.apache.spark.storage.StorageLevel

Expand All @@ -26,7 +26,7 @@ case class RowNumbersFunc(
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
unpersistHandle: UnpersistHandle = UnpersistHandle.Noop,
orderColumns: Seq[Column] = Seq.empty
) {
) extends SparkImports {

def withRowNumberColumnName(rowNumberColumnName: String): RowNumbersFunc =
this.copy(rowNumberColumnName = rowNumberColumnName)
Expand Down
2 changes: 0 additions & 2 deletions src/main/scala/uk/co/gresearch/spark/UnpersistHandle.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package uk.co.gresearch.spark

import org.apache.spark.sql.DataFrame

/**
* Handle to call `DataFrame.unpersist` on a `DataFrame` that is not known to the caller. The [[RowNumbers.of]]
* constructs a `DataFrame` that is based ony an intermediate cached `DataFrame`, for witch `unpersist` must be called.
Expand Down
7 changes: 4 additions & 3 deletions src/main/scala/uk/co/gresearch/spark/diff/App.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
package uk.co.gresearch.spark.diff

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.sql.{SaveMode, SparkSession}
import scopt.OptionParser
import uk.co.gresearch._
import uk.co.gresearch.spark.SparkImports

object App {
object App extends SparkImports {
// define available options
case class Options(
master: Option[String] = None,
Expand Down Expand Up @@ -262,7 +263,7 @@ object App {
.when(schema.isDefined)
.call(_.schema(schema.get))
.when(format.isDefined)
.either(_.load(path))
.either(_.load(path).asInstanceOf[DataFrame])
.or(_.table(path))

def write(
Expand Down
4 changes: 1 addition & 3 deletions src/main/scala/uk/co/gresearch/spark/diff/Diff.scala
Original file line number Diff line number Diff line change
Expand Up @@ -642,15 +642,13 @@ class Differ(options: DiffOptions) {
.otherwise(struct(rightColumns: _*))
.as("_3")

val plan = diff.select(diffColumn, leftStruct, rightStruct).queryExecution.logical

val encoder: Encoder[(String, T, U)] = Encoders.tuple(
Encoders.STRING,
implicitly[Encoder[T]],
implicitly[Encoder[U]]
)

new Dataset(diff.sparkSession, plan, encoder)
diff.select(diffColumn, leftStruct, rightStruct).as(encoder)
}

}
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/uk/co/gresearch/spark/diff/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package uk.co.gresearch.spark

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.{DataFrame, Dataset, Encoder}
import org.apache.spark.sql.Encoder

import java.util.Locale

package object diff {
package object diff extends SparkImports {

implicit class DatasetDiff[T](ds: Dataset[T]) {

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/uk/co/gresearch/spark/group/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package uk.co.gresearch.spark

import org.apache.spark.sql.functions.{col, struct}
import org.apache.spark.sql.{Column, Dataset, Encoder, Encoders}
import org.apache.spark.sql.{Column, Encoder, Encoders}
import uk.co.gresearch.ExtendedAny

package object group {
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/uk/co/gresearch/spark/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import uk.co.gresearch.spark.group.SortedGroupByDataset

import java.nio.file.{Files, Paths}

package object spark extends Logging with SparkVersion with BuildVersion {
package object spark extends Logging with SparkVersion with BuildVersion with SparkImports {

/**
* Provides a prefix that makes any string distinct w.r.t. the given strings.
Expand Down Expand Up @@ -741,7 +741,7 @@ package object spark extends Logging with SparkVersion with BuildVersion {
case _ =>
}
// resolve partition column names
val partitionColumnNames = ds.select(partitionColumns: _*).queryExecution.analyzed.output.map(_.name)
val partitionColumnNames = output(ds.select(partitionColumns: _*)).map(_.name)
val partitionColumnsMap = partitionColumnNames.zip(partitionColumns).toMap
val rangeColumns = partitionColumnNames.map(col) ++ moreFileColumns
val sortColumns = partitionColumnNames.map(col) ++ moreFileColumns ++ moreFileOrder
Expand Down
Loading

0 comments on commit e18259d

Please sign in to comment.