From 0bc38acc615ad411a97779c6a1ff43d4391c0c3d Mon Sep 17 00:00:00 2001
From: Nikola Mandic
Date: Fri, 21 Jun 2024 22:45:47 +0800
Subject: [PATCH] [SPARK-48675][SQL] Fix cache table with collated column

### What changes were proposed in this pull request?

The following sequence of queries produces an error:
```
> cache lazy table t as select col from values ('a' collate utf8_lcase) as (col);
> select col from t;
org.apache.spark.SparkException: not support type: org.apache.spark.sql.types.StringType1.
  at org.apache.spark.sql.errors.QueryExecutionErrors$.notSupportTypeError(QueryExecutionErrors.scala:1069)
  at org.apache.spark.sql.execution.columnar.ColumnBuilder$.apply(ColumnBuilder.scala:200)
  at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.$anonfun$next$1(InMemoryRelation.scala:85)
  at scala.collection.immutable.List.map(List.scala:247)
  at scala.collection.immutable.List.map(List.scala:79)
  at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:84)
  at org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:82)
  at org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:296)
  at org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:293)
  ...
```
The same problem occurs with non-lazy cached tables. The error surfaces during execution of `InMemoryTableScanExec`; fixing it requires updating `ColumnAccessor`, `ColumnBuilder`, `ColumnType`, and `ColumnStats` to be collation-aware.

### Why are the changes needed?

To fix the described error.

### Does this PR introduce _any_ user-facing change?

Yes, after these changes the described sequence of queries produces valid results instead of throwing an error.

### How was this patch tested?

Added checks for the mentioned classes to the columnar suites, and an integration test to `CollationSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47045 from nikolamand-db/SPARK-48675.
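For illustration only (not part of the patch): the core of the fix is that the columnar string machinery was hard-coded to binary byte ordering, while collated columns need the collation's comparator. Below is a minimal, self-contained Scala sketch of that pattern; the case-insensitive `lcase` comparator is a hypothetical stand-in for the `UTF8String.semanticCompare(other, collationId)` call that the real `StringColumnStats` now uses.
```
// Sketch only: simplified stand-ins, not the actual Spark classes.
object CollatedStatsSketch {
  // Analogous to StringColumnStats: min/max bounds are tracked with a
  // collation-supplied comparator instead of byte-wise comparison.
  final class StringStats(compare: (String, String) => Int) {
    var lower: String = _
    var upper: String = _
    def gather(value: String): Unit = {
      if (upper == null || compare(value, upper) > 0) upper = value
      if (lower == null || compare(value, lower) < 0) lower = value
    }
  }

  def main(args: Array[String]): Unit = {
    // UTF8_LCASE-like ordering: compare lowercased values.
    val lcase: (String, String) => Int = (a, b) => a.toLowerCase.compareTo(b.toLowerCase)
    val stats = new StringStats(lcase)
    Seq("b", "a", "C", "A").foreach(stats.gather)
    // Binary ordering would report upper bound "b"; UTF8_LCASE reports "C",
    // which is what the new collation-defined ordering test asserts.
    println(s"lower=${stats.lower}, upper=${stats.upper}") // lower=a, upper=C
  }
}
```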
Authored-by: Nikola Mandic
Signed-off-by: Wenchen Fan
---
 .../execution/columnar/ColumnAccessor.scala   |  6 +-
 .../execution/columnar/ColumnBuilder.scala    |  5 +-
 .../sql/execution/columnar/ColumnStats.scala  | 10 ++--
 .../sql/execution/columnar/ColumnType.scala   | 12 +++-
 .../columnar/GenerateColumnAccessor.scala     |  4 +-
 .../compression/compressionSchemes.scala      |  4 +-
 .../org/apache/spark/sql/CollationSuite.scala | 34 +++++++++++
 .../execution/columnar/ColumnStatsSuite.scala | 59 ++++++++++++++++++-
 .../execution/columnar/ColumnTypeSuite.scala  | 33 ++++++++---
 .../columnar/ColumnarTestUtils.scala          |  2 +-
 .../NullableColumnAccessorSuite.scala         | 23 ++++++--
 .../columnar/NullableColumnBuilderSuite.scala | 23 ++++++--
 .../CompressionSchemeBenchmark.scala          |  5 +-
 .../compression/DictionaryEncodingSuite.scala | 14 ++++-
 .../compression/RunLengthEncodingSuite.scala  | 14 ++++-
 15 files changed, 205 insertions(+), 43 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 4a922dcb062e5..9652a48e5270e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -100,8 +100,8 @@ private[columnar] class FloatColumnAccessor(buffer: ByteBuffer)
 private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, DOUBLE)
 
-private[columnar] class StringColumnAccessor(buffer: ByteBuffer)
-  extends NativeColumnAccessor(buffer, STRING)
+private[columnar] class StringColumnAccessor(buffer: ByteBuffer, dataType: StringType)
+  extends NativeColumnAccessor(buffer, STRING(dataType))
 
 private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
   extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
@@ -147,7 +147,7 @@ private[sql] object ColumnAccessor {
         new LongColumnAccessor(buf)
       case FloatType => new FloatColumnAccessor(buf)
       case DoubleType => new DoubleColumnAccessor(buf)
-      case StringType => new StringColumnAccessor(buf)
+      case s: StringType => new StringColumnAccessor(buf, s)
       case BinaryType => new BinaryColumnAccessor(buf)
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
         new CompactDecimalColumnAccessor(buf, dt)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index 367547155beef..9fafdb7948416 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -122,7 +122,8 @@ private[columnar]
 class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
 
 private[columnar]
-class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+class StringColumnBuilder(dataType: StringType)
+  extends NativeColumnBuilder(new StringColumnStats(dataType), STRING(dataType))
 
 private[columnar]
 class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
@@ -185,7 +186,7 @@ private[columnar] object ColumnBuilder {
         new LongColumnBuilder
       case FloatType => new FloatColumnBuilder
       case DoubleType => new DoubleColumnBuilder
-      case StringType => new StringColumnBuilder
+      case s: StringType => new StringColumnBuilder(s)
       case BinaryType => new BinaryColumnBuilder
       case CalendarIntervalType => new IntervalColumnBuilder
      case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 18ef84262aad3..45f489cb13c2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -255,14 +255,16 @@ private[columnar] final class DoubleColumnStats extends ColumnStats {
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] final class StringColumnStats extends ColumnStats {
+private[columnar] final class StringColumnStats(collationId: Int) extends ColumnStats {
+  def this(dt: StringType) = this(dt.collationId)
+
   protected var upper: UTF8String = null
   protected var lower: UTF8String = null
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getUTF8String(ordinal)
-      val size = STRING.actualSize(row, ordinal)
+      val size = STRING(collationId).actualSize(row, ordinal)
       gatherValueStats(value, size)
     } else {
       gatherNullStats()
@@ -270,8 +272,8 @@ private[columnar] final class StringColumnStats {
   }
 
   def gatherValueStats(value: UTF8String, size: Int): Unit = {
-    if (upper == null || value.binaryCompare(upper) > 0) upper = value.clone()
-    if (lower == null || value.binaryCompare(lower) < 0) lower = value.clone()
+    if (upper == null || value.semanticCompare(upper, collationId) > 0) upper = value.clone()
+    if (lower == null || value.semanticCompare(lower, collationId) < 0) lower = value.clone()
     sizeInBytes += size
     count += 1
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index ee1f9b4133026..b8e63294f3cdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -491,8 +491,8 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType
   }
 }
 
-private[columnar] object STRING
-  extends NativeColumnType(PhysicalStringType(StringType.collationId), 8)
+private[columnar] case class STRING(collationId: Int)
+  extends NativeColumnType(PhysicalStringType(collationId), 8)
   with DirectCopyColumnType[UTF8String] {
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
@@ -532,6 +532,12 @@ private[columnar] object STRING
   override def clone(v: UTF8String): UTF8String = v.clone()
 }
 
+private[columnar] object STRING {
+  def apply(dt: StringType): STRING = {
+    STRING(dt.collationId)
+  }
+}
+
 private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int)
   extends NativeColumnType(PhysicalDecimalType(precision, scale), 8) {
 
@@ -821,7 +827,7 @@ private[columnar] object ColumnType {
       case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => LONG
       case FloatType => FLOAT
       case DoubleType => DOUBLE
-      case StringType => STRING
+      case s: StringType => STRING(s)
      case BinaryType => BINARY
      case i: CalendarIntervalType => CALENDAR_INTERVAL
      case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 5eadc7d47c92e..75416b8789142 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -86,7 +86,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
         classOf[LongColumnAccessor].getName
       case FloatType => classOf[FloatColumnAccessor].getName
       case DoubleType => classOf[DoubleColumnAccessor].getName
-      case StringType => classOf[StringColumnAccessor].getName
+      case _: StringType => classOf[StringColumnAccessor].getName
       case BinaryType => classOf[BinaryColumnAccessor].getName
       case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
@@ -101,7 +101,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       val createCode = dt match {
         case t if CodeGenerator.isPrimitiveType(dt) =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
-        case NullType | StringType | BinaryType | CalendarIntervalType =>
+        case NullType | BinaryType | CalendarIntervalType =>
           s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
         case other =>
           s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
index 46044f6919d17..86d76856e12bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
@@ -176,7 +176,7 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme {
   }
 
   override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+    case INT | LONG | SHORT | BYTE | _: STRING | BOOLEAN => true
     case _ => false
   }
 
@@ -373,7 +373,7 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme {
   }
 
   override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | STRING => true
+    case INT | LONG | _: STRING => true
     case _ => false
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index e2a11fc137c3d..c4eaedfb215e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
@@ -1431,4 +1432,37 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     })
   }
 
+  test("cache table with collated columns") {
+    val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI")
+    val lazyOptions = Seq(false, true)
+
+    for (
+      collation <- collations;
+      lazyTable <- lazyOptions
+    ) {
+      val lazyStr = if (lazyTable) "LAZY" else ""
+
+      def checkCacheTable(values: String): Unit = {
+        sql(s"CACHE $lazyStr TABLE tbl AS SELECT col FROM VALUES ($values) AS (col)")
+        // Checks in-memory fetching code path.
+        val all = sql("SELECT col FROM tbl")
+        assert(all.queryExecution.executedPlan.collectFirst {
+          case _: InMemoryTableScanExec => true
+        }.nonEmpty)
+        checkAnswer(all, Row("a"))
+        // Checks column stats code path.
+        checkAnswer(sql("SELECT col FROM tbl WHERE col = 'a'"), Row("a"))
+        checkAnswer(sql("SELECT col FROM tbl WHERE col = 'b'"), Seq.empty)
+      }
+
+      withTable("tbl") {
+        checkCacheTable(s"'a' COLLATE $collation")
+      }
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
+        withTable("tbl") {
+          checkCacheTable("'a'")
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
index f39057013e64b..bdb118b91fa28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.types.StringType
 
 class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
@@ -28,9 +29,9 @@ class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
   testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
   testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
   testDecimalColumnStats(Array(null, null, 0))
   testIntervalColumnStats(Array(null, null, 0))
+  testStringColumnStats(Array(null, null, 0))
 
   def testColumnStats[T <: PhysicalDataType, U <: ColumnStats](
       columnStatsClass: Class[U],
@@ -141,4 +142,60 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  def testStringColumnStats[T <: PhysicalDataType, U <: ColumnStats](
+      initialStatistics: Array[Any]): Unit = {
+
+    Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI").foreach(collation => {
+      val columnType = STRING(StringType(collation))
+
+      test(s"STRING($collation): empty") {
+        val columnStats = new StringColumnStats(StringType(collation).collationId)
+        columnStats.collectedStatistics.zip(initialStatistics).foreach {
+          case (actual, expected) => assert(actual === expected)
+        }
+      }
+
+      test(s"STRING($collation): non-empty") {
+        import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+
+        val columnStats = new StringColumnStats(StringType(collation).collationId)
+        val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+        rows.foreach(columnStats.gatherStats(_, 0))
+
+        val values = rows.take(10).map(_.get(0,
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
+        val ordering = PhysicalDataType.ordering(
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
+        val stats = columnStats.collectedStatistics
+
+        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+        assertResult(10, "Wrong null count")(stats(2))
+        assertResult(20, "Wrong row count")(stats(3))
+        assertResult(stats(4), "Wrong size in bytes") {
+          rows.map { row =>
+            if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+          }.sum
+        }
+      }
+    })
+
+    test("STRING(UTF8_LCASE): collation-defined ordering") {
+      import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+      import org.apache.spark.unsafe.types.UTF8String
+
+      val columnStats = new StringColumnStats(StringType("UTF8_LCASE").collationId)
+      val rows = Seq("b", "a", "C", "A").map(str => {
+        val row = new GenericInternalRow(1)
+        row(0) = UTF8String.fromString(str)
+        row
+      })
+      rows.foreach(columnStats.gatherStats(_, 0))
+
+      val stats = columnStats.collectedStatistics
+      assertResult(UTF8String.fromString("a"), "Wrong lower bound")(stats(0))
+      assertResult(UTF8String.fromString("C"), "Wrong upper bound")(stats(1))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
index d79ac8dc35459..a95bda9bf71df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalDataType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -40,7 +41,9 @@ class ColumnTypeSuite extends SparkFunSuite {
   val checks = Map(
     NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
     FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
-    STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
+    STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8,
+    STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) -> 8,
+    BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
     CALENDAR_INTERVAL -> 16)
 
   checks.foreach { case (columnType, expectedSize) =>
@@ -73,7 +76,12 @@ class ColumnTypeSuite extends SparkFunSuite {
   checkActualSize(LONG, Long.MaxValue, 8)
   checkActualSize(FLOAT, Float.MaxValue, 4)
   checkActualSize(DOUBLE, Double.MaxValue, 8)
-  checkActualSize(STRING, "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+  Seq(
+    "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+  ).foreach(collation => {
+    checkActualSize(STRING(StringType(collation)),
+      "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+  })
   checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
   checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
   checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
@@ -93,7 +101,10 @@ class ColumnTypeSuite extends SparkFunSuite {
   testNativeColumnType(FLOAT)
   testNativeColumnType(DOUBLE)
   testNativeColumnType(COMPACT_DECIMAL(15, 10))
-  testNativeColumnType(STRING)
+  testNativeColumnType(STRING(StringType)) // UTF8_BINARY
+  testNativeColumnType(STRING(StringType("UTF8_LCASE")))
+  testNativeColumnType(STRING(StringType("UNICODE")))
+  testNativeColumnType(STRING(StringType("UNICODE_CI")))
 
   testColumnType(NULL)
   testColumnType(BINARY)
@@ -104,11 +115,18 @@ class ColumnTypeSuite extends SparkFunSuite {
   testColumnType(CALENDAR_INTERVAL)
 
   def testNativeColumnType[T <: PhysicalDataType](columnType: NativeColumnType[T]): Unit = {
-    testColumnType[T#InternalType](columnType)
+    val typeName = columnType match {
+      case s: STRING =>
+        val collation = CollationFactory.fetchCollation(s.collationId).collationName
+        Some(if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)")
+      case _ => None
+    }
+    testColumnType[T#InternalType](columnType, typeName)
   }
 
-  def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
-
+  def testColumnType[JvmType](
+      columnType: ColumnType[JvmType],
+      typeName: Option[String] = None): Unit = {
     val proj = UnsafeProjection.create(
       Array[DataType](ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
     val converter = CatalystTypeConverters.createToScalaConverter(
@@ -116,8 +134,9 @@ class ColumnTypeSuite extends SparkFunSuite {
     val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())
     val totalSize = seq.map(_.getSizeInBytes).sum
     val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize)
+    val testName = typeName.getOrElse(columnType.toString)
 
-    test(s"$columnType append/extract") {
+    test(s"$testName append/extract") {
       val buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder())
 
       seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
index e7b509c087b79..d08c34056f565 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
@@ -50,7 +50,7 @@ object ColumnarTestUtils {
     case LONG => Random.nextLong()
     case FLOAT => Random.nextFloat()
     case DOUBLE => Random.nextDouble()
-    case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
+    case _: STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
     case BINARY => randomBytes(Random.nextInt(32))
     case CALENDAR_INTERVAL =>
       new CalendarInterval(Random.nextInt(), Random.nextInt(), Random.nextLong())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
index 169d9356c00cc..ee622793ee0a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.types._
 
 class TestNullableColumnAccessor[JvmType](
@@ -41,21 +42,33 @@ object TestNullableColumnAccessor {
 class NullableColumnAccessorSuite extends SparkFunSuite {
   import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 
-  Seq(
+  val stringTypes = Seq(
+    STRING(StringType), // UTF8_BINARY
+    STRING(StringType("UTF8_LCASE")),
+    STRING(StringType("UNICODE")),
+    STRING(StringType("UNICODE_CI")))
+  val otherTypes = Seq(
     NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
-    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+    BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
     STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
     ARRAY(PhysicalArrayType(IntegerType, true)),
     MAP(PhysicalMapType(IntegerType, StringType, true)),
     CALENDAR_INTERVAL)
-    .foreach {
+
+  stringTypes.foreach(s => {
+    val collation = CollationFactory.fetchCollation(s.collationId).collationName
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)"
+    testNullableColumnAccessor(s, Some(typeName))
+  })
+  otherTypes.foreach {
     testNullableColumnAccessor(_)
   }
 
   def testNullableColumnAccessor[JvmType](
-      columnType: ColumnType[JvmType]): Unit = {
+      columnType: ColumnType[JvmType],
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
     val nullRow = makeNullRow(1)
 
     test(s"Nullable $typeName column accessor: empty column") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
index 22f557e49ded5..609212c95e987 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.types._
 
 class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
@@ -39,21 +40,33 @@ object TestNullableColumnBuilder {
 class NullableColumnBuilderSuite extends SparkFunSuite {
   import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 
-  Seq(
+  val stringTypes = Seq(
+    STRING(StringType), // UTF8_BINARY
+    STRING(StringType("UTF8_LCASE")),
+    STRING(StringType("UNICODE")),
+    STRING(StringType("UNICODE_CI")))
+  val otherTypes = Seq(
     BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
-    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+    BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
     STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
     ARRAY(PhysicalArrayType(IntegerType, true)),
     MAP(PhysicalMapType(IntegerType, StringType, true)),
     CALENDAR_INTERVAL)
-    .foreach {
+
+  stringTypes.foreach(s => {
+    val collation = CollationFactory.fetchCollation(s.collationId).collationName
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)"
+    testNullableColumnBuilder(s, Some(typeName))
+  })
+  otherTypes.foreach {
     testNullableColumnBuilder(_)
   }
 
   def testNullableColumnBuilder[JvmType](
-      columnType: ColumnType[JvmType]): Unit = {
+      columnType: ColumnType[JvmType],
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
     val dataType = columnType.dataType
     val proj = UnsafeProjection.create(Array[DataType](
       ColumnarDataTypeUtils.toLogicalDataType(dataType)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
index 2da0adf439dae..05ae575305299 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
@@ -27,6 +27,7 @@ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
 import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, NativeColumnType, SHORT, STRING}
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.Utils._
 
 /**
@@ -231,8 +232,8 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem
     }
     testData.rewind()
 
-    runEncodeBenchmark("STRING Encode", iters, count, STRING, testData)
-    runDecodeBenchmark("STRING Decode", iters, count, STRING, testData)
+    runEncodeBenchmark("STRING Encode", iters, count, STRING(StringType), testData)
+    runDecodeBenchmark("STRING Decode", iters, count, STRING(StringType), testData)
   }
 
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
index 10d5e8a0eb9a3..2b2bc7e761368 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
@@ -25,19 +25,27 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.types.StringType
 
 class DictionaryEncodingSuite extends SparkFunSuite {
   val nullValue = -1
   testDictionaryEncoding(new IntColumnStats, INT)
   testDictionaryEncoding(new LongColumnStats, LONG)
-  testDictionaryEncoding(new StringColumnStats, STRING, false)
+  Seq(
+    "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+  ).foreach(collation => {
+    val dt = StringType(collation)
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)"
+    testDictionaryEncoding(new StringColumnStats(dt), STRING(dt), false, Some(typeName))
+  })
 
   def testDictionaryEncoding[T <: PhysicalDataType](
       columnStats: ColumnStats,
       columnType: NativeColumnType[T],
-      testDecompress: Boolean = true): Unit = {
+      testDecompress: Boolean = true,
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
 
     def buildDictionary(buffer: ByteBuffer) = {
       (0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
index 00f242a6b9c4b..9b0067fd29832 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.types.StringType
 
 class RunLengthEncodingSuite extends SparkFunSuite {
   val nullValue = -1
@@ -31,14 +32,21 @@ class RunLengthEncodingSuite extends SparkFunSuite {
   testRunLengthEncoding(new ShortColumnStats, SHORT)
   testRunLengthEncoding(new IntColumnStats, INT)
   testRunLengthEncoding(new LongColumnStats, LONG)
-  testRunLengthEncoding(new StringColumnStats, STRING, false)
+  Seq(
+    "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+  ).foreach(collation => {
+    val dt = StringType(collation)
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else s"STRING($collation)"
+    testRunLengthEncoding(new StringColumnStats(dt), STRING(dt), false, Some(typeName))
+  })
 
   def testRunLengthEncoding[T <: PhysicalDataType](
       columnStats: ColumnStats,
       columnType: NativeColumnType[T],
-      testDecompress: Boolean = true): Unit = {
+      testDecompress: Boolean = true,
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
 
     def skeleton(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]): Unit = {
      // -------------
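As a usage note (appended here for illustration, not part of the patch): with the fix in place, the repro from the description should run cleanly. A hedged sketch against a `SparkSession` named `spark` (an assumption), mirroring the assertions of the new `CollationSuite` test:
```
// Assumes a SparkSession `spark` built with this patch applied.
spark.sql("CACHE LAZY TABLE t AS SELECT col FROM VALUES ('a' COLLATE UTF8_LCASE) AS (col)")
spark.sql("SELECT col FROM t").show()                 // previously threw SparkException; now returns 'a'
spark.sql("SELECT col FROM t WHERE col = 'a'").show() // exercises the collation-aware column stats
spark.sql("SELECT col FROM t WHERE col = 'b'").show() // empty result, as the new test expects
```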