-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'refactor-tfjs-api' into max-operator
- Loading branch information
Showing
93 changed files
with
1,357 additions
and
431 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/logical/Xor.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package io.kinference.core.operators.logical | ||
|
||
import io.kinference.attribute.Attribute | ||
import io.kinference.core.data.tensor.KITensor | ||
import io.kinference.core.data.tensor.asTensor | ||
import io.kinference.data.ONNXData | ||
import io.kinference.graph.Contexts | ||
import io.kinference.ndarray.arrays.BooleanNDArray | ||
import io.kinference.operator.* | ||
import io.kinference.protobuf.message.TensorProto | ||
|
||
sealed class Xor( | ||
name: String, | ||
info: OperatorInfo, | ||
attributes: Map<String, Attribute<Any>>, | ||
inputs: List<String>, | ||
outputs: List<String> | ||
) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) { | ||
companion object { | ||
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7) | ||
|
||
operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Xor { | ||
return when (version ?: DEFAULT_VERSION.sinceVersion) { | ||
in XorVer7.VERSION.asRange() -> XorVer7(name, attributes, inputs, outputs) | ||
else -> error("Unsupported version of Xor operator: $version") | ||
} | ||
} | ||
} | ||
} | ||
|
||
class XorVer7( | ||
name: String, | ||
attributes: Map<String, Attribute<Any>>, | ||
inputs: List<String>, | ||
outputs: List<String> | ||
): Xor(name, INFO, attributes, inputs, outputs) { | ||
companion object { | ||
private val TYPE_CONSTRAINTS = setOf(TensorProto.DataType.BOOL) | ||
|
||
private val INPUTS_INFO = listOf( | ||
IOInfo(0, TYPE_CONSTRAINTS, "A", optional = false), | ||
IOInfo(1, TYPE_CONSTRAINTS, "B", optional = false) | ||
) | ||
|
||
private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "C", optional = false)) | ||
|
||
internal val VERSION = VersionInfo(sinceVersion = 7) | ||
private val INFO = OperatorInfo("Xor", emptySet(), INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) | ||
} | ||
|
||
override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> { | ||
val left = inputs[0]!!.data as BooleanNDArray | ||
val right = inputs[1]!!.data as BooleanNDArray | ||
|
||
val ans = left xor right | ||
return listOf(ans.asTensor("C")) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/tensor/IsNaN.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package io.kinference.core.operators.tensor | ||
|
||
import io.kinference.attribute.Attribute | ||
import io.kinference.core.data.tensor.KITensor | ||
import io.kinference.core.data.tensor.asTensor | ||
import io.kinference.data.ONNXData | ||
import io.kinference.graph.Contexts | ||
import io.kinference.ndarray.arrays.* | ||
import io.kinference.ndarray.extensions.isNaN.isNaN | ||
import io.kinference.operator.* | ||
import io.kinference.primitives.types.DataType | ||
|
||
sealed class IsNaN(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : | ||
Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) { | ||
companion object { | ||
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 9) | ||
|
||
operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) = | ||
when (version ?: DEFAULT_VERSION.sinceVersion) { | ||
in IsNaNVer9.VERSION.asRange() -> IsNaNVer9(name, attributes, inputs, outputs) | ||
else -> error("Unsupported version of IsNaN operator: $version") | ||
} | ||
} | ||
} | ||
|
||
|
||
class IsNaNVer9( | ||
name: String, | ||
attributes: Map<String, Attribute<Any>>, | ||
inputs: List<String>, | ||
outputs: List<String> | ||
) : IsNaN(name, INFO, attributes, inputs, outputs) { | ||
companion object { | ||
private val ATTRIBUTES_INFO = emptyList<AttributeInfo>() | ||
|
||
private val INPUTS_INFO = listOf( | ||
IOInfo(0, PRIMITIVE_DATA_TYPES, "X", differentiable = true, optional = false) | ||
) | ||
|
||
private val OUTPUTS_INFO = listOf( | ||
IOInfo(0, PRIMITIVE_DATA_TYPES, "Y", differentiable = true, optional = false) | ||
) | ||
|
||
//Realized the latest version, but there is backward compatibility between operators | ||
internal val VERSION = VersionInfo(sinceVersion = 9) | ||
private val INFO = OperatorInfo("IsNaN", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) | ||
} | ||
|
||
override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> { | ||
val input = inputs[0]!!.data | ||
|
||
val output = when (input.type) { | ||
DataType.FLOAT -> (input as FloatNDArray).isNaN() | ||
DataType.DOUBLE -> (input as DoubleNDArray).isNaN() | ||
else -> error("Unsupported type") | ||
} | ||
|
||
return listOf(output.asTensor("Y")) | ||
} | ||
} | ||
|
||
|
49 changes: 49 additions & 0 deletions
49
inference/inference-core/src/commonTest/kotlin/io/kinference/operators/logical/XorTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package io.kinference.operators.logical | ||
|
||
import io.kinference.KITestEngine.KIAccuracyRunner | ||
import io.kinference.utils.TestRunner | ||
import kotlin.test.Test | ||
|
||
class XorTest { | ||
private fun getTargetPath(dirName: String) = "xor/$dirName/" | ||
|
||
@Test | ||
fun test_xor_2d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor2d")) | ||
} | ||
|
||
@Test | ||
fun test_xor_3d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor3d")) | ||
} | ||
|
||
@Test | ||
fun test_xor_4d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor4d")) | ||
} | ||
|
||
@Test | ||
fun test_xor_broadcast_3v1d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast3v1d")) | ||
} | ||
|
||
@Test | ||
fun test_xor_broadcast_3v2d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast3v2d")) | ||
} | ||
|
||
@Test | ||
fun test_xor_broadcast_4v2d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v2d")) | ||
} | ||
|
||
@Test | ||
fun test_xor_broadcast_4v3d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v3d")) | ||
} | ||
|
||
@Test | ||
fun test_xor_broadcast_4v4d() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v4d")) | ||
} | ||
} |
14 changes: 14 additions & 0 deletions
14
...ence/inference-core/src/commonTest/kotlin/io/kinference/operators/operations/IsNaNTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package io.kinference.operators.operations | ||
|
||
import io.kinference.KITestEngine.KIAccuracyRunner | ||
import io.kinference.utils.TestRunner | ||
import kotlin.test.Test | ||
|
||
class IsNaNTest { | ||
private fun getTargetPath(dirName: String) = "isnan/$dirName/" | ||
|
||
@Test | ||
fun test_isnan() = TestRunner.runTest { | ||
KIAccuracyRunner.runFromResources(getTargetPath("test_isnan")) | ||
} | ||
} |
Oops, something went wrong.