KI-31 [tfjs] Max implementation for TFJS backend
cupertank committed Jul 12, 2023
1 parent b2aafad commit c5e0303
Showing 7 changed files with 159 additions and 8 deletions.
@@ -93,6 +93,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"LSTM" -> LSTM(name, version, attributes, inputs, outputs)
"MatMul" -> MatMul(name, version, attributes, inputs, outputs)
"MatMulInteger" -> MatMulInteger(name, version, attributes, inputs, outputs)
"Max" -> Max(name, version, attributes, inputs, outputs)
"Mul" -> Mul(name, version, attributes, inputs, outputs)
"NonZero" -> NonZero(name, version, attributes, inputs, outputs)
"Not" -> Not(name, version, attributes, inputs, outputs)
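For context, the factory excerpt above is the registration point: each ONNX op type string is matched in a when expression and routed to the corresponding operator constructor. A minimal self-contained sketch of that dispatch pattern follows; the Op, MaxOp, and MulOp types are simplified stand-ins, not the actual KInference classes.

// Illustrative sketch only; names and types are simplified stand-ins.
sealed interface Op
data class MaxOp(val inputs: List<String>) : Op
data class MulOp(val inputs: List<String>) : Op

fun createOperator(opType: String, inputs: List<String>): Op = when (opType) {
    "Max" -> MaxOp(inputs)
    "Mul" -> MulOp(inputs)
    else -> error("Unsupported operator: $opType")
}

fun main() {
    println(createOperator("Max", listOf("data_0", "data_1"))) // MaxOp(inputs=[data_0, data_1])
}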
@@ -0,0 +1,48 @@
package io.kinference.tfjs.operators.tensor

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.NumberNDArrayTFJS
import io.kinference.ndarray.extensions.max
import io.kinference.operator.*
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Max(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) :
Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 6)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) =
when (version ?: DEFAULT_VERSION.sinceVersion) {
in MaxVer6.VERSION.asRange() -> MaxVer6(name, attributes, inputs, outputs)
else -> error("Unsupported version of Max operator: $version")
}
}
}

class MaxVer6(name: String, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) :
Max(name, INFO, attributes, inputs, outputs) {

companion object {
private val ATTRIBUTES_INFO = emptyList<AttributeInfo>()

private val INPUTS_INFO = listOf(
VariadicIOInfo(0, NUMBER_DATA_TYPES, "data_0", minimumArity = 1)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, NUMBER_DATA_TYPES, "max", optional = false)
)

// Implements the latest version of the operator; earlier versions are backward compatible with it
internal val VERSION = VersionInfo(sinceVersion = 6)
private val INFO = OperatorInfo("Max", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val cleanInputs = inputs.filterNotNull().map { it.data as NumberNDArrayTFJS }
return listOf(cleanInputs.max().asTensor("Y"))
}
}
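A note on the companion invoke above: the requested opset version (or the default, 6) is checked against the version range declared by each concrete implementation, and the matching subclass is constructed. The following is a rough, self-contained sketch of that version-range resolution; VersionInfo and asRange here are simplified stand-ins and may differ from the real KInference types.

// Simplified sketch of version-range dispatch, mirroring Max.invoke above.
data class VersionInfo(val sinceVersion: Int, val untilVersion: Int = Int.MAX_VALUE) {
    fun asRange() = sinceVersion until untilVersion
}

fun resolveMaxVersion(version: Int?): String {
    val default = VersionInfo(sinceVersion = 6)
    return when (version ?: default.sinceVersion) {
        in VersionInfo(sinceVersion = 6).asRange() -> "MaxVer6"
        else -> error("Unsupported version of Max operator: $version")
    }
}

fun main() {
    println(resolveMaxVersion(null)) // MaxVer6 (falls back to the default opset 6)
    println(resolveMaxVersion(13))   // MaxVer6 (13 lies within the 6.. range)
}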
@@ -0,0 +1,79 @@
package io.kinference.tfjs.operators.tensor

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class MaxTest {
private fun getTargetPath(dirName: String) = "max/$dirName/"

@Test
fun test_max_example() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_example"))
}

@Test
fun test_max_float16() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_float16"))
}

@Test
fun test_max_float32() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_float32"))
}

@Test
fun test_max_float64() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_float64"))
}

@Test
fun test_max_int8() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int8"))
}

@Test
fun test_max_int16() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int16"))
}

@Test
fun test_max_int32() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int32"))
}

@Test
fun test_max_int64() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int64"))
}

@Test
fun test_max_one_input() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_one_input"))
}

@Test
fun test_max_two_inputs() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_two_inputs"))
}

@Test
fun test_max_uint8() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint8"))
}

@Test
fun test_max_uint16() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint16"))
}

@Test
fun test_max_uint32() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint32"))
}

@Test
fun test_max_uint64() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint64"))
}
}
@@ -217,3 +217,14 @@ suspend fun NumberNDArrayTFJS.isInf(detectNegative: Boolean = true, detectPositi
}

fun NumberNDArrayTFJS.isNaN() = BooleanNDArrayTFJS(tfjsArray.isNaN())

suspend fun List<NumberNDArrayTFJS>.max(): NumberNDArrayTFJS {
if (isEmpty()) error("Array for max operation must have at least one element")
if (size == 1) return single()

return tidyNDArray { reduce { acc, next -> max(acc, next) } }
}

suspend fun Array<out NumberNDArrayTFJS>.max() = toList().max()

suspend fun maxOf(vararg inputs: NumberNDArrayTFJS) = inputs.max()
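The new extension reduces the input list pairwise with the element-wise max from the TFJS bindings, inside tidyNDArray so intermediate tensors are released. The same reduce pattern is shown below on plain FloatArrays as a rough sketch under simplifying assumptions: same-shape inputs only, no broadcasting, and no tensor memory management.

// Sketch of the pairwise-reduce pattern; names are illustrative only.
fun elementwiseMax(a: FloatArray, b: FloatArray): FloatArray =
    FloatArray(a.size) { maxOf(a[it], b[it]) }

fun List<FloatArray>.elementwiseMax(): FloatArray {
    require(isNotEmpty()) { "max requires at least one input" }
    if (size == 1) return single()
    return reduce { acc, next -> elementwiseMax(acc, next) }
}

fun main() {
    val result = listOf(
        floatArrayOf(1f, 5f, 3f),
        floatArrayOf(4f, 2f, 6f),
        floatArrayOf(0f, 7f, 1f)
    ).elementwiseMax()
    println(result.toList()) // [4.0, 7.0, 6.0]
}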
@@ -31,15 +31,17 @@ object FlatTensorDecoder : TensorDecoder() {
return when (type) {
TensorProto.DataType.DOUBLE -> DoubleArray(size) { init(it) as Double }
TensorProto.DataType.FLOAT -> FloatArray(size) { init(it) as Float }
TensorProto.DataType.INT64 -> LongArray(size) { init(it) as Long }
TensorProto.DataType.INT32 -> IntArray(size) { init(it) as Int }
TensorProto.DataType.FLOAT16 -> FloatArray(size) { init(it) as Float }
TensorProto.DataType.BFLOAT16 -> FloatArray(size) { init(it) as Float }
TensorProto.DataType.INT8 -> ByteArray(size) { init(it) as Byte }
TensorProto.DataType.INT16 -> ShortArray(size) { init(it) as Short }
TensorProto.DataType.INT32 -> IntArray(size) { init(it) as Int }
TensorProto.DataType.INT64 -> LongArray(size) { init(it) as Long }
TensorProto.DataType.UINT8 -> UByteArray(size) { init(it) as UByte }
TensorProto.DataType.UINT16 -> UShortArray(size) { init(it) as UShort }
TensorProto.DataType.UINT32 -> UIntArray(size) { init(it) as UInt }
TensorProto.DataType.UINT64 -> ULongArray(size) { init(it) as ULong }
TensorProto.DataType.BOOL -> BooleanArray(size) { init(it) as Boolean }
TensorProto.DataType.INT8 -> ByteArray(size) { init(it) as Byte }
TensorProto.DataType.UINT8 -> UByteArray(size) { init(it) as UByte }
TensorProto.DataType.BFLOAT16 -> FloatArray(size) { init(it) as Float }
TensorProto.DataType.FLOAT16 -> FloatArray(size) { init(it) as Float }
else -> error("Unsupported data type: $type")
}
}
@@ -4,10 +4,11 @@ import io.kinference.*
import io.kinference.data.ONNXData
import io.kinference.data.ONNXDataType
import io.kinference.utils.*
import io.kinference.utils.Assertions.assertLessOrEquals
import okio.Path
import okio.Path.Companion.toPath
import kotlin.math.pow
import kotlin.test.assertEquals
import kotlin.test.*

class AccuracyRunner<T : ONNXData<*, *>>(private val testEngine: TestEngine<T>) {
private data class ONNXTestData<T : ONNXData<*, *>> (val name: String, val actual: Map<String, T>, val expected: Map<String, T>)
@@ -59,7 +60,7 @@ class AccuracyRunner<T : ONNXData<*, *>>(private val testEngine: TestEngine<T>)
val outputs = model.predict(inputs)
val memoryAfterTest = testEngine.allocatedMemory()
logger.info { "Memory after predict: $memoryAfterTest" }
assertEquals(expectedOutputs.size, memoryAfterTest - memoryBeforeTest, "Memory leak found")
assertLessOrEquals(expectedOutputs.size, memoryAfterTest - memoryBeforeTest, "Memory leak found")
outputs
} else {
model.predict(inputs)
@@ -0,0 +1,9 @@
package io.kinference.utils

import kotlin.test.assertTrue

object Assertions {
fun <T: Comparable<T>> assertLessOrEquals(expected: T, actual: T, message: String) {
assertTrue(actual <= expected, message)
}
}
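The helper relaxes the earlier exact-equality check in AccuracyRunner: the number of arrays still allocated after predict may be at most the number of expected outputs rather than exactly equal to it. A hypothetical call site is sketched below; checkMemoryBudget is an illustrative name, not part of the change.

// Hypothetical usage mirroring the AccuracyRunner change above: the allocation
// delta measured around predict() must not exceed the number of expected outputs.
import io.kinference.utils.Assertions.assertLessOrEquals

fun checkMemoryBudget(expectedOutputs: Int, memoryBefore: Int, memoryAfter: Int) {
    assertLessOrEquals(expectedOutputs, memoryAfter - memoryBefore, "Memory leak found")
}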
