[SPARK-44720][CONNECT] Make Dataset use Encoder instead of AgnosticEncoder

### What changes were proposed in this pull request?
Make the Spark Connect `Dataset` use `Encoder` instead of `AgnosticEncoder` in its public API.

### Why are the changes needed?
We want to improve binary compatibility between the Spark Connect Scala Client and the original sql/core APIs.

### Does this PR introduce _any_ user-facing change?
Yes. It changes the type of `Dataset.encoder` from `AgnosticEncoder` to `Encoder`.
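A minimal illustration of what this means for user code (a sketch, not an excerpt from the client; it only relies on the public `Encoder[T]` members `schema` and `clsTag`):

```scala
import org.apache.spark.sql.{Dataset, Encoder}

// Code written against sql/core's `Dataset.encoder: Encoder[T]` now compiles and
// links against the Spark Connect Scala Client unchanged, because the client's
// `encoder` has the same declared type.
def describeEncoder[T](ds: Dataset[T]): String = {
  val enc: Encoder[T] = ds.encoder // previously typed as AgnosticEncoder[T] in the client
  s"class=${enc.clsTag.runtimeClass.getName}, schema=${enc.schema.simpleString}"
}
```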

### How was this patch tested?
Existing tests.

Closes apache#42396 from hvanhovell/SPARK-44720.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Aug 9, 2023
1 parent 27c5a1f commit be9ffb3
Showing 4 changed files with 50 additions and 48 deletions.
@@ -128,11 +128,13 @@ import org.apache.spark.util.SparkClassUtils
class Dataset[T] private[sql] (
val sparkSession: SparkSession,
@DeveloperApi val plan: proto.Plan,
- val encoder: AgnosticEncoder[T])
+ val encoder: Encoder[T])
extends Serializable {
// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)

+ private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder)
+
override def toString: String = {
try {
val builder = new mutable.StringBuilder
@@ -828,7 +830,7 @@ class Dataset[T] private[sql] (
}

private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
- sparkSession.newDataset(encoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSortBuilder
.setInput(plan.getRoot)
.setIsGlobal(global)
@@ -878,8 +880,8 @@
ProductEncoder[(T, U)](
ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
Seq(
EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty),
EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty)))
EncoderField(s"_1", this.agnosticEncoder, leftNullable, Metadata.empty),
EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)))

sparkSession.newDataset(tupleEncoder) { builder =>
val joinBuilder = builder.getJoinBuilder
@@ -889,8 +891,8 @@
.setJoinType(joinTypeValue)
.setJoinCondition(condition.expr)
.setJoinDataType(joinBuilder.getJoinDataTypeBuilder
- .setIsLeftStruct(this.encoder.isStruct)
- .setIsRightStruct(other.encoder.isStruct))
+ .setIsLeftStruct(this.agnosticEncoder.isStruct)
+ .setIsRightStruct(other.agnosticEncoder.isStruct))
}
}

@@ -1010,13 +1012,13 @@
* @since 3.4.0
*/
@scala.annotation.varargs
- def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) {
- builder =>
+ def hint(name: String, parameters: Any*): Dataset[T] =
+ sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getHintBuilder
.setInput(plan.getRoot)
.setName(name)
.addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
}
}

private def getPlanId: Option[Long] =
if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
@@ -1056,7 +1058,7 @@
* @group typedrel
* @since 3.4.0
*/
- def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ def as(alias: String): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSubqueryAliasBuilder
.setInput(plan.getRoot)
.setAlias(alias)
@@ -1238,8 +1240,9 @@
* @group typedrel
* @since 3.4.0
*/
- def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
- builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+ def filter(condition: Column): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
+ builder =>
+ builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
}

/**
@@ -1355,12 +1358,12 @@
def reduce(func: (T, T) => T): T = {
val udf = ScalarUserDefinedFunction(
function = func,
- inputEncoders = encoder :: encoder :: Nil,
- outputEncoder = encoder)
+ inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
+ outputEncoder = agnosticEncoder)
val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr

val result = sparkSession
- .newDataset(encoder) { builder =>
+ .newDataset(agnosticEncoder) { builder =>
builder.getAggregateBuilder
.setInput(plan.getRoot)
.addAggregateExpressions(reduceExpr)
@@ -1718,7 +1721,7 @@
* @group typedrel
* @since 3.4.0
*/
- def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getLimitBuilder
.setInput(plan.getRoot)
.setLimit(n)
@@ -1730,7 +1733,7 @@
* @group typedrel
* @since 3.4.0
*/
- def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+ def offset(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getOffsetBuilder
.setInput(plan.getRoot)
.setOffset(n)
@@ -1739,7 +1742,7 @@
private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
f: proto.SetOperation.Builder => Unit): Dataset[T] = {
checkSameSparkSession(right)
- sparkSession.newDataset(encoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder) { builder =>
f(
builder.getSetOpBuilder
.setSetOpType(setOpType)
@@ -2012,7 +2015,7 @@
* @since 3.4.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
- sparkSession.newDataset(encoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSampleBuilder
.setInput(plan.getRoot)
.setWithReplacement(withReplacement)
@@ -2080,7 +2083,7 @@
normalizedCumWeights
.sliding(2)
.map { case Array(low, high) =>
- sparkSession.newDataset(encoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getSampleBuilder
.setInput(sortedInput)
.setWithReplacement(false)
@@ -2401,15 +2404,16 @@

private def buildDropDuplicates(
columns: Option[Seq[String]],
- withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
- val dropBuilder = builder.getDeduplicateBuilder
- .setInput(plan.getRoot)
- .setWithinWatermark(withinWaterMark)
- if (columns.isDefined) {
- dropBuilder.addAllColumnNames(columns.get.asJava)
- } else {
- dropBuilder.setAllColumnsAsKeys(true)
- }
+ withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
+ builder =>
+ val dropBuilder = builder.getDeduplicateBuilder
+ .setInput(plan.getRoot)
+ .setWithinWatermark(withinWaterMark)
+ if (columns.isDefined) {
+ dropBuilder.addAllColumnNames(columns.get.asJava)
+ } else {
+ dropBuilder.setAllColumnsAsKeys(true)
+ }
}

/**
@@ -2630,9 +2634,9 @@
def filter(func: T => Boolean): Dataset[T] = {
val udf = ScalarUserDefinedFunction(
function = func,
- inputEncoders = encoder :: Nil,
+ inputEncoders = agnosticEncoder :: Nil,
outputEncoder = PrimitiveBooleanEncoder)
- sparkSession.newDataset[T](encoder) { builder =>
+ sparkSession.newDataset[T](agnosticEncoder) { builder =>
builder.getFilterBuilder
.setInput(plan.getRoot)
.setCondition(udf.apply(col("*")).expr)
@@ -2683,7 +2687,7 @@
val outputEncoder = encoderFor[U]
val udf = ScalarUserDefinedFunction(
function = func,
- inputEncoders = encoder :: Nil,
+ inputEncoders = agnosticEncoder :: Nil,
outputEncoder = outputEncoder)
sparkSession.newDataset(outputEncoder) { builder =>
builder.getMapPartitionsBuilder
@@ -2785,7 +2789,7 @@
* @since 3.4.0
*/
def tail(n: Int): Array[T] = {
- val lastN = sparkSession.newDataset(encoder) { builder =>
+ val lastN = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getTailBuilder
.setInput(plan.getRoot)
.setLimit(n)
@@ -2856,7 +2860,7 @@
}

private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = {
- sparkSession.newDataset(encoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getRepartitionBuilder
.setInput(plan.getRoot)
.setNumPartitions(numPartitions)
@@ -2866,11 +2870,12 @@

private def buildRepartitionByExpression(
numPartitions: Option[Int],
- partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
- val repartitionBuilder = builder.getRepartitionByExpressionBuilder
- .setInput(plan.getRoot)
- .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
- numPartitions.foreach(repartitionBuilder.setNumPartitions)
+ partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
+ builder =>
+ val repartitionBuilder = builder.getRepartitionByExpressionBuilder
+ .setInput(plan.getRoot)
+ .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+ numPartitions.foreach(repartitionBuilder.setNumPartitions)
}

/**
@@ -3183,7 +3188,7 @@
* @since 3.5.0
*/
def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = {
- sparkSession.newDataset(encoder) { builder =>
+ sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getWithWatermarkBuilder
.setInput(plan.getRoot)
.setEventTime(eventTime)
@@ -3251,7 +3256,7 @@
sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA)
}

- def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder)
+ def collectResult(): SparkResult[T] = sparkSession.execute(plan, agnosticEncoder)

private[sql] def withResult[E](f: SparkResult[T] => E): E = {
val result = collectResult()
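The Dataset changes above all follow one delegation pattern: the public constructor value keeps the sql/core-compatible `Encoder[T]` type, and the new private `agnosticEncoder` (resolved once via `encoderFor(encoder)`) is what the internal plan-building call sites use. Below is a self-contained sketch of that shape, with `InternalEncoder` and `resolveInternal` as stand-ins for `AgnosticEncoder` and `encoderFor`, not the client's actual API:

```scala
import org.apache.spark.sql.Encoder

// Stand-in for AgnosticEncoder[T]; the real trait (in org.apache.spark.sql.catalyst.encoders)
// carries far more information than this sketch needs.
trait InternalEncoder[T] extends Encoder[T]

class SketchDataset[T](
    val encoder: Encoder[T], // public surface, binary-compatible with sql/core
    resolveInternal: Encoder[T] => InternalEncoder[T]) {

  // Resolved once and reused by internal call sites (sort, join, limit, ...),
  // mirroring the new `agnosticEncoder` field in the diff above.
  private val internalEncoder: InternalEncoder[T] = resolveInternal(encoder)

  def planEncoderClass: Class[_] = internalEncoder.clsTag.runtimeClass
}
```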
@@ -988,15 +988,15 @@ private object KeyValueGroupedDatasetImpl {
groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = {
val gf = ScalarUserDefinedFunction(
function = groupingFunc,
- inputEncoders = ds.encoder :: Nil, // Using the original value and key encoders
+ inputEncoders = ds.agnosticEncoder :: Nil, // Using the original value and key encoders
outputEncoder = kEncoder)
new KeyValueGroupedDatasetImpl(
ds.sparkSession,
ds.plan,
kEncoder,
kEncoder,
- ds.encoder,
- ds.encoder,
+ ds.agnosticEncoder,
+ ds.agnosticEncoder,
Arrays.asList(gf.apply(col("*")).expr),
UdfUtils.identical(),
() => ds.map(groupingFunc)(kEncoder))
@@ -216,7 +216,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
* @since 3.5.0
*/
def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
- val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.encoder))
+ val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder))
val scalaWriterBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(serialized))
@@ -181,7 +181,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
@@ -334,8 +333,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.Dataset.plan"
), // developer API
- ProblemFilters.exclude[IncompatibleResultTypeProblem](
- "org.apache.spark.sql.Dataset.encoder"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.Dataset.collectResult"),

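With the result type of `Dataset.encoder` now matching sql/core, the two compatibility-check exclusions removed above are no longer needed. For context, a MiMa filter of this shape is what suppressed the mismatch while the client still returned `AgnosticEncoder[T]` (shown only as an illustration of why the exclusion existed):

```scala
import com.typesafe.tools.mima.core._

// Suppresses the report MiMa raises when the same method has different result
// types in the two artifacts being compared (here: the connect client vs sql/core).
val encoderResultTypeFilter = ProblemFilters.exclude[IncompatibleResultTypeProblem](
  "org.apache.spark.sql.Dataset.encoder")
```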
