From 252a1c34e2a971f8bc9c4a7189ce39797e894f88 Mon Sep 17 00:00:00 2001 From: Ilya Vologin Date: Tue, 11 Jul 2023 19:50:39 +0200 Subject: [PATCH] KI-32 [tfjs] fix Assertions module --- .../io/kinference/tfjs/utils/TFJSAssertions.kt | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt index d3b4c8b60..680552305 100644 --- a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt +++ b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt @@ -3,6 +3,7 @@ package io.kinference.tfjs.utils import io.kinference.TestLoggerFactory import io.kinference.data.ONNXDataType import io.kinference.ndarray.extensions.* +import io.kinference.primitives.types.DataType import io.kinference.tfjs.TFJSData import io.kinference.tfjs.data.map.TFJSMap import io.kinference.tfjs.data.seq.TFJSSequence @@ -13,34 +14,35 @@ import kotlin.test.assertEquals object TFJSAssertions { val logger = TestLoggerFactory.create("Assertions") - @OptIn(ExperimentalUnsignedTypes::class) fun assertEquals(expected: TFJSTensor, actual: TFJSTensor, delta: Double) { - assertEquals(expected.data.dtype, actual.data.dtype, "Types of tensors ${expected.name} do not match") + assertEquals(expected.data.type, actual.data.type, "Types of tensors ${expected.name} do not match") ArrayAssertions.assertArrayEquals(expected.data.shapeArray, actual.data.shapeArray, "Shapes are incorrect") logger.info { "Errors for ${expected.name}:" } - when(expected.data.dtype) { - "float32" -> { + when(expected.data.type) { + DataType.FLOAT -> { val expectedArray = expected.data.dataFloat() val actualArray = actual.data.dataFloat() ArrayAssertions.assertEquals(expectedArray, actualArray, delta, expected.name.orEmpty()) } - "int32" -> { + DataType.INT -> { val expectedArray = expected.data.dataInt() val actualArray = actual.data.dataInt() ArrayAssertions.assertEquals(expectedArray, actualArray, delta, expected.name.orEmpty()) } - "bool" -> { + DataType.BOOLEAN -> { val expectedArray = expected.data.dataBool() val actualArray = actual.data.dataBool() ArrayAssertions.assertArrayEquals(expectedArray, actualArray, "Tensor ${expected.name} does not match") } + + else -> error("Unsupported data type of ${expected.name} tensor") } }