Skip to content

Commit

Permalink
KI-32 [tfjs] fix Assertions module
Browse files Browse the repository at this point in the history
  • Loading branch information
cupertank committed Jul 11, 2023
1 parent 1b7ba94 commit 252a1c3
Showing 1 changed file with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.kinference.tfjs.utils
import io.kinference.TestLoggerFactory
import io.kinference.data.ONNXDataType
import io.kinference.ndarray.extensions.*
import io.kinference.primitives.types.DataType
import io.kinference.tfjs.TFJSData
import io.kinference.tfjs.data.map.TFJSMap
import io.kinference.tfjs.data.seq.TFJSSequence
Expand All @@ -13,34 +14,35 @@ import kotlin.test.assertEquals
object TFJSAssertions {
val logger = TestLoggerFactory.create("Assertions")

@OptIn(ExperimentalUnsignedTypes::class)
fun assertEquals(expected: TFJSTensor, actual: TFJSTensor, delta: Double) {
assertEquals(expected.data.dtype, actual.data.dtype, "Types of tensors ${expected.name} do not match")
assertEquals(expected.data.type, actual.data.type, "Types of tensors ${expected.name} do not match")
ArrayAssertions.assertArrayEquals(expected.data.shapeArray, actual.data.shapeArray, "Shapes are incorrect")


logger.info { "Errors for ${expected.name}:" }
when(expected.data.dtype) {
"float32" -> {
when(expected.data.type) {
DataType.FLOAT -> {
val expectedArray = expected.data.dataFloat()
val actualArray = actual.data.dataFloat()

ArrayAssertions.assertEquals(expectedArray, actualArray, delta, expected.name.orEmpty())
}

"int32" -> {
DataType.INT -> {
val expectedArray = expected.data.dataInt()
val actualArray = actual.data.dataInt()

ArrayAssertions.assertEquals(expectedArray, actualArray, delta, expected.name.orEmpty())
}

"bool" -> {
DataType.BOOLEAN -> {
val expectedArray = expected.data.dataBool()
val actualArray = actual.data.dataBool()

ArrayAssertions.assertArrayEquals(expectedArray, actualArray, "Tensor ${expected.name} does not match")
}

else -> error("Unsupported data type of ${expected.name} tensor")
}
}

Expand Down

0 comments on commit 252a1c3

Please sign in to comment.