From 0d061b65754d9fc80e4a6a66bfabe4316b1677a3 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Sun, 17 Mar 2024 17:14:30 +0100 Subject: [PATCH] added tests and fixed name hack --- .../jetbrains/kotlinx/spark/api/Encoding.kt | 19 +++++--- .../kotlinx/spark/api/EncodingTest.kt | 45 ++++++++++++++++--- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt index 987e83ac..56230dff 100644 --- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt +++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt @@ -112,7 +112,7 @@ object KotlinTypeInference { // TODO this hack is a WIP and can give errors // TODO it's to make data classes get column names like "age" with functions like "getAge" // TODO instead of column names like "getAge" - var DO_NAME_HACK = false + var DO_NAME_HACK = true /** * @param kClass the class for which to infer the encoder. @@ -151,6 +151,7 @@ object KotlinTypeInference { currentType = kType, seenTypeSet = emptySet(), typeVariables = emptyMap(), + isTopLevel = true, ) as AgnosticEncoder @@ -217,6 +218,7 @@ object KotlinTypeInference { // how the generic types of the data class (like T, S) are filled in for this instance of the class typeVariables: Map, + isTopLevel: Boolean = false, ): AgnosticEncoder<*> { val kClass = currentType.classifier as? 
KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType") @@ -488,6 +490,7 @@ object KotlinTypeInference { DirtyProductEncoderField( doNameHack = DO_NAME_HACK, + isTopLevel = isTopLevel, columnName = paramName, readMethodName = readMethodName, writeMethodName = writeMethodName, @@ -545,6 +548,7 @@ internal open class DirtyProductEncoderField( private val readMethodName: String, // the name of the method used to read the value private val writeMethodName: String?, private val doNameHack: Boolean, + private val isTopLevel: Boolean, encoder: AgnosticEncoder<*>, nullable: Boolean, metadata: Metadata = Metadata.empty(), @@ -557,20 +561,21 @@ internal open class DirtyProductEncoderField( /* writeMethod = */ writeMethodName.toOption(), ), Serializable { - private var i = 0 + private var isFirstNameCall = true /** * This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder] - * creates an [Invoke] using [columnName] first and then calls [columnName] again to retrieve + * creates an [Invoke] using [name] first and then calls [name] again to retrieve * the name of the column. This way, we can alternate between the two names. 
*/ override fun name(): String = - when (doNameHack) { - true -> if (i++ % 2 == 0) readMethodName else columnName - false -> readMethodName + if (doNameHack && !isFirstNameCall) { + columnName + } else { + isFirstNameCall = false + readMethodName } - override fun canEqual(that: Any?): Boolean = that is AgnosticEncoders.EncoderField override fun productElement(n: Int): Any = diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt index 83f74ea1..151bca14 100644 --- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt +++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt @@ -22,6 +22,8 @@ package org.jetbrains.kotlinx.spark.api import ch.tutteli.atrium.api.fluent.en_GB.* import ch.tutteli.atrium.api.verbs.expect import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainExactly import io.kotest.matchers.shouldBe import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.Decimal @@ -208,6 +210,39 @@ class EncodingTest : ShouldSpec({ context("schema") { withSpark(props = mapOf("spark.sql.codegen.comments" to true)) { + context("Give proper names to columns of data classes") { + val old = KotlinTypeInference.DO_NAME_HACK + KotlinTypeInference.DO_NAME_HACK = true + + should("Be able to serialize pairs") { + val pairs = listOf( + 1 to "1", + 2 to "2", + 3 to "3", + ) + val dataset = pairs.toDS() + dataset.show() + dataset.collectAsList() shouldBe pairs + dataset.columns().shouldContainExactly("first", "second") + } + + should("Be able to serialize pairs of pairs") { + val pairs = listOf( + 1 to (1 to "1"), + 2 to (2 to "2"), + 3 to (3 to "3"), + ) + val dataset = pairs.toDS() + dataset.show() + dataset.printSchema() + dataset.columns().shouldContainExactly("first", "second") + 
dataset.select("second.*").columns().shouldContainExactly("first", "second") + dataset.collectAsList() shouldBe pairs + } + + KotlinTypeInference.DO_NAME_HACK = old + } + should("handle Scala Case class datasets") { val caseClasses = listOf( tupleOf(1, "1"), @@ -253,14 +288,14 @@ class EncodingTest : ShouldSpec({ } - xshould("handle Scala Option datasets") { + should("handle Scala Option datasets") { val caseClasses = listOf(Some(1), Some(2), Some(3)) val dataset = caseClasses.toDS() dataset.show() dataset.collectAsList() shouldBe caseClasses } - xshould("handle Scala Option Option datasets") { + should("handle Scala Option Option datasets") { val caseClasses = listOf( Some(Some(1)), Some(Some(2)), @@ -270,7 +305,7 @@ class EncodingTest : ShouldSpec({ dataset.collectAsList() shouldBe caseClasses } - xshould("handle data class Scala Option datasets") { + should("handle data class Scala Option datasets") { val caseClasses = listOf( Some(1) to Some(2), Some(3) to Some(4), @@ -280,7 +315,7 @@ class EncodingTest : ShouldSpec({ dataset.collectAsList() shouldBe caseClasses } - xshould("handle Scala Option data class datasets") { + should("handle Scala Option data class datasets") { val caseClasses = listOf( Some(1 to 2), Some(3 to 4), @@ -501,7 +536,7 @@ class EncodingTest : ShouldSpec({ expect(result).toContain.inOrder.only.values(5.1 to 6) } - should("!handle primitive arrays") { + should("handle boxed arrays") { val result = listOf(arrayOf(1, 2, 3, 4)) .toDS() .map { it.map { ai -> ai + 1 } }