diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5f263903c8bbc..2d72ea6bda8f3 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -128,11 +128,13 @@ import org.apache.spark.util.SparkClassUtils
 class Dataset[T] private[sql] (
     val sparkSession: SparkSession,
     @DeveloperApi val plan: proto.Plan,
-    val encoder: AgnosticEncoder[T])
+    val encoder: Encoder[T])
     extends Serializable {
   // Make sure we don't forget to set plan id.
   assert(plan.getRoot.getCommon.hasPlanId)
 
+  private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder)
+
   override def toString: String = {
     try {
       val builder = new mutable.StringBuilder
@@ -828,7 +830,7 @@ class Dataset[T] private[sql] (
   }
 
   private def buildSort(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getSortBuilder
         .setInput(plan.getRoot)
         .setIsGlobal(global)
@@ -878,8 +880,8 @@ class Dataset[T] private[sql] (
       ProductEncoder[(T, U)](
         ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
         Seq(
-          EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty),
-          EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty)))
+          EncoderField(s"_1", this.agnosticEncoder, leftNullable, Metadata.empty),
+          EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)))
 
     sparkSession.newDataset(tupleEncoder) { builder =>
       val joinBuilder = builder.getJoinBuilder
@@ -889,8 +891,8 @@ class Dataset[T] private[sql] (
         .setJoinType(joinTypeValue)
         .setJoinCondition(condition.expr)
         .setJoinDataType(joinBuilder.getJoinDataTypeBuilder
-          .setIsLeftStruct(this.encoder.isStruct)
-          .setIsRightStruct(other.encoder.isStruct))
+          .setIsLeftStruct(this.agnosticEncoder.isStruct)
+          .setIsRightStruct(other.agnosticEncoder.isStruct))
     }
   }
 
@@ -1010,13 +1012,13 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
    */
   @scala.annotation.varargs
-  def hint(name: String, parameters: Any*): Dataset[T] = sparkSession.newDataset(encoder) {
-    builder =>
+  def hint(name: String, parameters: Any*): Dataset[T] =
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getHintBuilder
         .setInput(plan.getRoot)
         .setName(name)
         .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
-  }
+    }
 
   private def getPlanId: Option[Long] =
     if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
@@ -1056,7 +1058,7 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
    */
-  def as(alias: String): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+  def as(alias: String): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
     builder.getSubqueryAliasBuilder
       .setInput(plan.getRoot)
       .setAlias(alias)
@@ -1238,8 +1240,9 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
   */
-  def filter(condition: Column): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
-    builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+  def filter(condition: Column): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
+    builder =>
+      builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
   }
 
   /**
@@ -1355,12 +1358,12 @@ class Dataset[T] private[sql] (
   def reduce(func: (T, T) => T): T = {
     val udf = ScalarUserDefinedFunction(
       function = func,
-      inputEncoders = encoder :: encoder :: Nil,
-      outputEncoder = encoder)
+      inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
+      outputEncoder = agnosticEncoder)
     val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
 
     val result = sparkSession
-      .newDataset(encoder) { builder =>
+      .newDataset(agnosticEncoder) { builder =>
         builder.getAggregateBuilder
           .setInput(plan.getRoot)
           .addAggregateExpressions(reduceExpr)
@@ -1718,7 +1721,7 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
   */
-  def limit(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+  def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
     builder.getLimitBuilder
       .setInput(plan.getRoot)
       .setLimit(n)
@@ -1730,7 +1733,7 @@ class Dataset[T] private[sql] (
    * @group typedrel
    * @since 3.4.0
   */
-  def offset(n: Int): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
+  def offset(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
     builder.getOffsetBuilder
       .setInput(plan.getRoot)
       .setOffset(n)
@@ -1739,7 +1742,7 @@ class Dataset[T] private[sql] (
   private def buildSetOp(right: Dataset[T], setOpType: proto.SetOperation.SetOpType)(
       f: proto.SetOperation.Builder => Unit): Dataset[T] = {
     checkSameSparkSession(right)
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       f(
         builder.getSetOpBuilder
           .setSetOpType(setOpType)
@@ -2012,7 +2015,7 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
   */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getSampleBuilder
         .setInput(plan.getRoot)
         .setWithReplacement(withReplacement)
@@ -2080,7 +2083,7 @@ class Dataset[T] private[sql] (
     normalizedCumWeights
       .sliding(2)
       .map { case Array(low, high) =>
-        sparkSession.newDataset(encoder) { builder =>
+        sparkSession.newDataset(agnosticEncoder) { builder =>
           builder.getSampleBuilder
             .setInput(sortedInput)
             .setWithReplacement(false)
@@ -2401,15 +2404,16 @@ class Dataset[T] private[sql] (
 
   private def buildDropDuplicates(
       columns: Option[Seq[String]],
-      withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
-    val dropBuilder = builder.getDeduplicateBuilder
-      .setInput(plan.getRoot)
-      .setWithinWatermark(withinWaterMark)
-    if (columns.isDefined) {
-      dropBuilder.addAllColumnNames(columns.get.asJava)
-    } else {
-      dropBuilder.setAllColumnsAsKeys(true)
-    }
+      withinWaterMark: Boolean): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
+    builder =>
+      val dropBuilder = builder.getDeduplicateBuilder
+        .setInput(plan.getRoot)
+        .setWithinWatermark(withinWaterMark)
+      if (columns.isDefined) {
+        dropBuilder.addAllColumnNames(columns.get.asJava)
+      } else {
+        dropBuilder.setAllColumnsAsKeys(true)
+      }
   }
 
   /**
@@ -2630,9 +2634,9 @@ class Dataset[T] private[sql] (
   def filter(func: T => Boolean): Dataset[T] = {
     val udf = ScalarUserDefinedFunction(
       function = func,
-      inputEncoders = encoder :: Nil,
+      inputEncoders = agnosticEncoder :: Nil,
       outputEncoder = PrimitiveBooleanEncoder)
-    sparkSession.newDataset[T](encoder) { builder =>
+    sparkSession.newDataset[T](agnosticEncoder) { builder =>
       builder.getFilterBuilder
         .setInput(plan.getRoot)
         .setCondition(udf.apply(col("*")).expr)
@@ -2683,7 +2687,7 @@ class Dataset[T] private[sql] (
     val outputEncoder = encoderFor[U]
     val udf = ScalarUserDefinedFunction(
       function = func,
-      inputEncoders = encoder :: Nil,
+      inputEncoders = agnosticEncoder :: Nil,
       outputEncoder = outputEncoder)
     sparkSession.newDataset(outputEncoder) { builder =>
       builder.getMapPartitionsBuilder
@@ -2785,7 +2789,7 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
   */
   def tail(n: Int): Array[T] = {
-    val lastN = sparkSession.newDataset(encoder) { builder =>
+    val lastN = sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getTailBuilder
         .setInput(plan.getRoot)
         .setLimit(n)
@@ -2856,7 +2860,7 @@ class Dataset[T] private[sql] (
   }
 
   private def buildRepartition(numPartitions: Int, shuffle: Boolean): Dataset[T] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getRepartitionBuilder
         .setInput(plan.getRoot)
         .setNumPartitions(numPartitions)
@@ -2866,11 +2870,12 @@ class Dataset[T] private[sql] (
 
   private def buildRepartitionByExpression(
       numPartitions: Option[Int],
-      partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(encoder) { builder =>
-    val repartitionBuilder = builder.getRepartitionByExpressionBuilder
-      .setInput(plan.getRoot)
-      .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
-    numPartitions.foreach(repartitionBuilder.setNumPartitions)
+      partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
+    builder =>
+      val repartitionBuilder = builder.getRepartitionByExpressionBuilder
+        .setInput(plan.getRoot)
+        .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+      numPartitions.foreach(repartitionBuilder.setNumPartitions)
   }
 
   /**
@@ -3183,7 +3188,7 @@ class Dataset[T] private[sql] (
    * @since 3.5.0
   */
   def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = {
-    sparkSession.newDataset(encoder) { builder =>
+    sparkSession.newDataset(agnosticEncoder) { builder =>
       builder.getWithWatermarkBuilder
         .setInput(plan.getRoot)
         .setEventTime(eventTime)
@@ -3251,7 +3256,7 @@ class Dataset[T] private[sql] (
     sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA)
   }
 
-  def collectResult(): SparkResult[T] = sparkSession.execute(plan, encoder)
+  def collectResult(): SparkResult[T] = sparkSession.execute(plan, agnosticEncoder)
 
   private[sql] def withResult[E](f: SparkResult[T] => E): E = {
     val result = collectResult()
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index e67ef1c0fa7e2..202891c66d748 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -988,15 +988,15 @@ private object KeyValueGroupedDatasetImpl {
       groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = {
     val gf = ScalarUserDefinedFunction(
       function = groupingFunc,
-      inputEncoders = ds.encoder :: Nil, // Using the original value and key encoders
+      inputEncoders = ds.agnosticEncoder :: Nil, // Using the original value and key encoders
       outputEncoder = kEncoder)
     new KeyValueGroupedDatasetImpl(
       ds.sparkSession,
       ds.plan,
       kEncoder,
       kEncoder,
-      ds.encoder,
-      ds.encoder,
+      ds.agnosticEncoder,
+      ds.agnosticEncoder,
       Arrays.asList(gf.apply(col("*")).expr),
       UdfUtils.identical(),
       () => ds.map(groupingFunc)(kEncoder))
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index b395a2d073d6d..b9aa1f5bc5838 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -216,7 +216,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
    * @since 3.5.0
   */
  def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
-    val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.encoder))
+    val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder))
     val scalaWriterBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(serialized))
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index d380a1bbb653e..4439a5f3e2adb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -181,7 +181,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.encoder"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
@@ -334,8 +333,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[DirectMissingMethodProblem](
         "org.apache.spark.sql.Dataset.plan"
       ), // developer API
-      ProblemFilters.exclude[IncompatibleResultTypeProblem](
-        "org.apache.spark.sql.Dataset.encoder"),
       ProblemFilters.exclude[DirectMissingMethodProblem](
         "org.apache.spark.sql.Dataset.collectResult"),
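For context, a minimal sketch (not part of the patch) of what the signature change means for client code; it assumes an existing Spark Connect SparkSession named `spark`:

    import org.apache.spark.sql.{Dataset, Encoder}

    // Dataset.encoder is now typed as the public Encoder[T] rather than the internal
    // AgnosticEncoder[T], so callers can only rely on the general Encoder interface.
    val ds: Dataset[java.lang.Long] = spark.range(10)
    val enc: Encoder[java.lang.Long] = ds.encoder // previously typed as AgnosticEncoder[java.lang.Long]

    // Inside the client, plan construction still needs the concrete encoder, which is why
    // the patch adds: private[sql] val agnosticEncoder = encoderFor(encoder).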