-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
KI-31 [tfjs] Max implementation for TFJS backend
- Loading branch information
Showing
7 changed files
with
159 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
48 changes: 48 additions & 0 deletions
48
inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Max.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package io.kinference.tfjs.operators.tensor | ||
|
||
import io.kinference.attribute.Attribute | ||
import io.kinference.data.ONNXData | ||
import io.kinference.graph.Contexts | ||
import io.kinference.ndarray.arrays.NumberNDArrayTFJS | ||
import io.kinference.ndarray.extensions.max | ||
import io.kinference.operator.* | ||
import io.kinference.tfjs.data.tensors.TFJSTensor | ||
import io.kinference.tfjs.data.tensors.asTensor | ||
|
||
sealed class Max(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : | ||
Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) { | ||
companion object { | ||
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 6) | ||
|
||
operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) = | ||
when (version ?: DEFAULT_VERSION.sinceVersion) { | ||
in MaxVer6.VERSION.asRange() -> MaxVer6(name, attributes, inputs, outputs) | ||
else -> error("Unsupported version of Max operator: $version") | ||
} | ||
} | ||
} | ||
|
||
class MaxVer6(name: String, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : | ||
Max(name, INFO, attributes, inputs, outputs) { | ||
|
||
companion object { | ||
private val ATTRIBUTES_INFO = emptyList<AttributeInfo>() | ||
|
||
private val INPUTS_INFO = listOf( | ||
VariadicIOInfo(0, NUMBER_DATA_TYPES, "data_0", minimumArity = 1) | ||
) | ||
|
||
private val OUTPUTS_INFO = listOf( | ||
IOInfo(0, NUMBER_DATA_TYPES, "max", optional = false) | ||
) | ||
|
||
//Realized the latest version, but there is backward compatibility between operators | ||
internal val VERSION = VersionInfo(sinceVersion = 6) | ||
private val INFO = OperatorInfo("Max", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) | ||
} | ||
|
||
override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> { | ||
val cleanInputs = inputs.filterNotNull().map { it.data as NumberNDArrayTFJS } | ||
return listOf(cleanInputs.max().asTensor("Y")) | ||
} | ||
} |
79 changes: 79 additions & 0 deletions
79
inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/MaxTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package io.kinference.tfjs.operators.tensor | ||
|
||
import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner | ||
import io.kinference.utils.TestRunner | ||
import kotlin.test.Test | ||
|
||
class MaxTest { | ||
private fun getTargetPath(dirName: String) = "max/$dirName/" | ||
|
||
@Test | ||
fun test_max_example() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_example")) | ||
} | ||
|
||
@Test | ||
fun test_max_float16() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_float16")) | ||
} | ||
|
||
@Test | ||
fun test_max_float32() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_float32")) | ||
} | ||
|
||
@Test | ||
fun test_max_float64() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_float64")) | ||
} | ||
|
||
@Test | ||
fun test_max_int8() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int8")) | ||
} | ||
|
||
@Test | ||
fun test_max_int16() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int16")) | ||
} | ||
|
||
@Test | ||
fun test_max_int32() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int32")) | ||
} | ||
|
||
@Test | ||
fun test_max_int64() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_int64")) | ||
} | ||
|
||
@Test | ||
fun test_max_one_input() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_one_input")) | ||
} | ||
|
||
@Test | ||
fun test_max_two_inputs() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_two_inputs")) | ||
} | ||
|
||
@Test | ||
fun test_max_uint8() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint8")) | ||
} | ||
|
||
@Test | ||
fun test_max_uint16() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint16")) | ||
} | ||
|
||
@Test | ||
fun test_max_uint32() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint32")) | ||
} | ||
|
||
@Test | ||
fun test_max_uint64() = TestRunner.runTest { | ||
TFJSAccuracyRunner.runFromResources(getTargetPath("test_max_uint64")) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
utils/utils-testing/src/commonMain/kotlin/io.kinference/utils/Assertions.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
package io.kinference.utils | ||
|
||
import kotlin.test.assertTrue | ||
|
||
object Assertions { | ||
fun <T: Comparable<T>> assertLessOrEquals(expected: T, actual: T, message: String) { | ||
assertTrue(actual <= expected, message) | ||
} | ||
} |