
Commit

added tests and fixed name hack
Jolanrensen committed Mar 17, 2024
1 parent 3e9261f · commit 0d061b6
Showing 2 changed files with 52 additions and 12 deletions.

File 1 of 2:

@@ -112,7 +112,7 @@ object KotlinTypeInference {
     // TODO this hack is a WIP and can give errors
     // TODO it's to make data classes get column names like "age" with functions like "getAge"
     // TODO instead of column names like "getAge"
-    var DO_NAME_HACK = false
+    var DO_NAME_HACK = true

     /**
      * @param kClass the class for which to infer the encoder.
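
A quick aside on what the TODOs above are getting at (not part of the commit; Person is a hypothetical example class): Kotlin data class properties are exposed to Java-style reflection as bean getters, so an encoder that derives field names from those methods ends up with columns like "getAge" instead of "age".

import kotlin.reflect.full.memberProperties

// Hypothetical data class used only to illustrate the naming problem.
data class Person(val age: Int, val name: String)

fun main() {
    // What Java reflection exposes: bean getters such as getAge/getName.
    // Deriving column names from these gives "getAge"/"getName".
    val getterNames = Person::class.java.methods
        .map { it.name }
        .filter { it.startsWith("get") && it != "getClass" }
    println(getterNames.sorted())       // [getAge, getName]

    // What we actually want as column names: the Kotlin property names.
    val propertyNames = Person::class.memberProperties.map { it.name }
    println(propertyNames.sorted())     // [age, name]
}
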
@@ -151,6 +151,7 @@ object KotlinTypeInference {
             currentType = kType,
             seenTypeSet = emptySet(),
             typeVariables = emptyMap(),
+            isTopLevel = true,
         ) as AgnosticEncoder<T>


@@ -217,6 +218,7 @@

         // how the generic types of the data class (like T, S) are filled in for this instance of the class
         typeVariables: Map<String, KType>,
+        isTopLevel: Boolean = false,
     ): AgnosticEncoder<*> {
         val kClass =
             currentType.classifier as? KClass<*> ?: throw IllegalArgumentException("Unsupported type $currentType")
@@ -488,6 +490,7 @@

             DirtyProductEncoderField(
                 doNameHack = DO_NAME_HACK,
+                isTopLevel = isTopLevel,
                 columnName = paramName,
                 readMethodName = readMethodName,
                 writeMethodName = writeMethodName,
@@ -545,6 +548,7 @@ internal open class DirtyProductEncoderField(
     private val readMethodName: String, // the name of the method used to read the value
     private val writeMethodName: String?,
     private val doNameHack: Boolean,
+    private val isTopLevel: Boolean,
     encoder: AgnosticEncoder<*>,
     nullable: Boolean,
     metadata: Metadata = Metadata.empty(),
@@ -557,20 +561,21 @@
     /* writeMethod = */ writeMethodName.toOption(),
 ), Serializable {

-    private var i = 0
+    private var isFirstNameCall = true

     /**
      * This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder]
-     * creates an [Invoke] using [columnName] first and then calls [columnName] again to retrieve
+     * creates an [Invoke] using [name] first and then calls [name] again to retrieve
      * the name of the column. This way, we can alternate between the two names.
      */
     override fun name(): String =
-        when (doNameHack) {
-            true -> if (i++ % 2 == 0) readMethodName else columnName
-            false -> readMethodName
+        if (doNameHack && !isFirstNameCall) {
+            columnName
+        } else {
+            isFirstNameCall = false
+            readMethodName
         }
-

     override fun canEqual(that: Any?): Boolean = that is AgnosticEncoders.EncoderField

     override fun productElement(n: Int): Any =
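
The KDoc changed above describes the call-order trick that name() relies on: the first call happens while the Invoke expression is built (where the read-method name is needed), and every later call asks for the column name shown in the schema. A minimal standalone sketch of that behaviour, assuming hypothetical names like "first"/"getFirst" (this is not the real encoder field, just the same logic in isolation):

// Standalone sketch of the name-hack logic from DirtyProductEncoderField.name().
class NameHackField(
    private val columnName: String,      // e.g. "first"   -- the data class property name
    private val readMethodName: String,  // e.g. "getFirst" -- the getter to invoke
    private val doNameHack: Boolean,
) {
    private var isFirstNameCall = true

    // First call returns the read-method name (used to build the Invoke expression);
    // every later call returns the column name (used in the schema).
    fun name(): String =
        if (doNameHack && !isFirstNameCall) {
            columnName
        } else {
            isFirstNameCall = false
            readMethodName
        }
}

fun main() {
    val field = NameHackField(columnName = "first", readMethodName = "getFirst", doNameHack = true)
    println(field.name())  // getFirst
    println(field.name())  // first
    println(field.name())  // first
}

With doNameHack = false every call keeps returning the read-method name, which reproduces the old "getAge"-style column names the TODO in the first hunk complains about.
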
File 2 of 2:

@@ -22,6 +22,8 @@ package org.jetbrains.kotlinx.spark.api
 import ch.tutteli.atrium.api.fluent.en_GB.*
 import ch.tutteli.atrium.api.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
+import io.kotest.matchers.collections.shouldContain
+import io.kotest.matchers.collections.shouldContainExactly
 import io.kotest.matchers.shouldBe
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.Decimal
@@ -208,6 +210,39 @@ class EncodingTest : ShouldSpec({
     context("schema") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {

+            context("Give proper names to columns of data classes") {
+                val old = KotlinTypeInference.DO_NAME_HACK
+                KotlinTypeInference.DO_NAME_HACK = true
+
+                should("Be able to serialize pairs") {
+                    val pairs = listOf(
+                        1 to "1",
+                        2 to "2",
+                        3 to "3",
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.collectAsList() shouldBe pairs
+                    dataset.columns().shouldContainExactly("first", "second")
+                }
+
+                should("Be able to serialize pairs of pairs") {
+                    val pairs = listOf(
+                        1 to (1 to "1"),
+                        2 to (2 to "2"),
+                        3 to (3 to "3"),
+                    )
+                    val dataset = pairs.toDS()
+                    dataset.show()
+                    dataset.printSchema()
+                    dataset.columns().shouldContainExactly("first", "second")
+                    dataset.select("second.*").columns().shouldContainExactly("first", "second")
+                    dataset.collectAsList() shouldBe pairs
+                }
+
+                KotlinTypeInference.DO_NAME_HACK = old
+            }
+
             should("handle Scala Case class datasets") {
                 val caseClasses = listOf(
                     tupleOf(1, "1"),
@@ -253,14 +288,14 @@
             }


-            xshould("handle Scala Option datasets") {
+            should("handle Scala Option datasets") {
                 val caseClasses = listOf(Some(1), Some(2), Some(3))
                 val dataset = caseClasses.toDS()
                 dataset.show()
                 dataset.collectAsList() shouldBe caseClasses
             }

-            xshould("handle Scala Option Option datasets") {
+            should("handle Scala Option Option datasets") {
                 val caseClasses = listOf(
                     Some(Some(1)),
                     Some(Some(2)),
@@ -270,7 +305,7 @@
                 dataset.collectAsList() shouldBe caseClasses
             }

-            xshould("handle data class Scala Option datasets") {
+            should("handle data class Scala Option datasets") {
                 val caseClasses = listOf(
                     Some(1) to Some(2),
                     Some(3) to Some(4),
@@ -280,7 +315,7 @@
                 dataset.collectAsList() shouldBe caseClasses
             }

-            xshould("handle Scala Option data class datasets") {
+            should("handle Scala Option data class datasets") {
                 val caseClasses = listOf(
                     Some(1 to 2),
                     Some(3 to 4),
@@ -501,7 +536,7 @@ class EncodingTest : ShouldSpec({
                 expect(result).toContain.inOrder.only.values(5.1 to 6)
             }

-            should("!handle primitive arrays") {
+            should("handle boxed arrays") {
                 val result = listOf(arrayOf(1, 2, 3, 4))
                     .toDS()
                     .map { it.map { ai -> ai + 1 } }
