diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt index a097d794e..e6becd8aa 100755 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/KIOperatorFactory.kt @@ -92,9 +92,10 @@ object KIOperatorFactory : OperatorFactory> { "Gelu" -> Gelu(name, version, attributes, inputs, outputs) "Gemm" -> Gemm(name, version, attributes, inputs, outputs) "Greater" -> Greater(name, version, attributes, inputs, outputs) + "GRU" -> GRU(name, version, attributes, inputs, outputs) "Hardmax" -> Hardmax(name, version, attributes, inputs, outputs) "IsInf" -> IsInf(name, version, attributes, inputs, outputs) - "GRU" -> GRU(name, version, attributes, inputs, outputs) + "IsNaN" -> IsNaN(name, version, attributes, inputs, outputs) "Identity" -> Identity(name, version, attributes, inputs, outputs) "If" -> If(name, version, attributes, inputs, outputs) "LayerNormalization" -> LayerNormalization(name, version, attributes, inputs, outputs) @@ -140,6 +141,7 @@ object KIOperatorFactory : OperatorFactory> { "TreeEnsembleRegressor" -> TreeEnsembleRegressor(name, version, attributes, inputs, outputs) "Unsqueeze" -> Unsqueeze(name, version, attributes, inputs, outputs) "Where" -> Where(name, version, attributes, inputs, outputs) + "Xor" -> Xor(name, version, attributes, inputs, outputs) "ZipMap" -> ZipMap(name, version, attributes, inputs, outputs) else -> error("Unsupported operator: $opType") } as Operator, KIONNXData<*>> diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt index 0b6eca49e..a4b8bf5d7 100644 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt @@ -13,7 +13,7 @@ import io.kinference.ndarray.arrays.tiled.FloatTiledArray import io.kinference.ndarray.extensions.allocateNDArray import io.kinference.ndarray.extensions.dotTransposedWithAlpha import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto import kotlinx.coroutines.coroutineScope @@ -265,12 +265,11 @@ class AttentionVer1(name: String, attributes: Map>, input override suspend fun > apply(contexts: Contexts, inputs: List): List { val input = inputs[0]!! val weights = inputs[1]!! - val preparedWeights = (contexts.graph!!.getOrNullValue(optName(weights.name)) - ?: AttentionContextRule.prepareWeights(weights, numHeads)) as KITensor + + val preparedWeights = weights.takeIf { isOpt(it.name) } ?: AttentionContextRule.prepareWeights(weights, numHeads) val bias = inputs[2]!! - val preparedBias = (contexts.graph!!.getOrNullValue(optName(bias.name)) - ?: AttentionContextRule.prepareBias(bias, numHeads)) as KITensor + val preparedBias = bias.takeIf { isOpt(it.name) } ?: AttentionContextRule.prepareBias(bias, numHeads) val maskIndices = inputs.elementAtOrNull(3)?.data as IntNDArray? val past = inputs.elementAtOrNull(4)?.data diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/QAttention.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/QAttention.kt index 3faf65a5a..98743395c 100644 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/QAttention.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/QAttention.kt @@ -10,7 +10,7 @@ import io.kinference.graph.Contexts import io.kinference.ndarray.arrays.* import io.kinference.ndarray.extensions.tryDequantize import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto @@ -106,14 +106,12 @@ class QAttentionVer1(name: String, attributes: Map>, inpu val weights = inputs[1]!! val weightsScale = inputs[4]!! val weightsZeroPoint = inputs.getOrNull(7) - - val preparedWeights = (contexts.graph!!.getOrNullValue(optName(weights.name)) - ?: QAttentionContextRule.prepareWeights(weights, weightsScale, weightsZeroPoint, numHeads)) as KITensor + val preparedWeights = weights.takeIf { isOpt(it.name) } + ?: QAttentionContextRule.prepareWeights(weights, weightsScale, weightsZeroPoint, numHeads) val bias = inputs[2]!! - - val preparedBias = (contexts.graph!!.getOrNullValue(optName(bias.name)) - ?: AttentionContextRule.prepareBias(bias, numHeads)) as KITensor + val preparedBias = bias.takeIf { isOpt(it.name) } + ?: AttentionContextRule.prepareBias(bias, numHeads) val maskIndices = inputs.getOrNull(5)?.data as IntNDArray? val past = inputs.getOrNull(8)?.data as NumberNDArrayCore? diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/gru/GRU.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/gru/GRU.kt index 01751d5ce..486e83f17 100644 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/gru/GRU.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/gru/GRU.kt @@ -10,7 +10,7 @@ import io.kinference.graph.Contexts import io.kinference.ndarray.arrays.IntNDArray import io.kinference.ndarray.arrays.NumberNDArrayCore import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto @@ -99,14 +99,16 @@ class GRUVer7( val input = inputs[0]!! val weights = inputs[1]!! - val preparedWeights = (contexts.graph!!.getOrNullValue(optName(weights.name)) ?: GRUContextRule.prepareWeights(weights)) + val preparedWeights = weights.takeIf { isOpt(it.name) } ?: GRUContextRule.prepareWeights(weights) val recurrentWeights = inputs[2]!! - val preparedRecurrentWeights = (contexts.graph!!.getOrNullValue(optName(recurrentWeights.name)) - ?: GRUContextRule.prepareWeights(recurrentWeights)) as KITensor + val preparedRecurrentWeights = recurrentWeights.takeIf { isOpt(it.name) } + ?: GRUContextRule.prepareWeights(recurrentWeights) val bias = inputs.getOrNull(3) - val preparedBias = bias?.let { contexts.graph!!.getOrNullValue(optName(it.name)) ?: GRUContextRule.prepareBias(it) } + val preparedBias = bias?.let { tensor -> + tensor.takeIf { isOpt(it.name) } ?: GRUContextRule.prepareBias(tensor) + } val sequenceLens = inputs.getOrNull(4) val initialHiddenState = inputs.getOrNull(5) diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/lstm/LSTM.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/lstm/LSTM.kt index c522f0aea..5a21d283f 100644 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/lstm/LSTM.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/recurrent/lstm/LSTM.kt @@ -10,7 +10,7 @@ import io.kinference.graph.Contexts import io.kinference.ndarray.arrays.IntNDArray import io.kinference.ndarray.arrays.NumberNDArrayCore import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto @@ -101,24 +101,23 @@ class LSTMVer7( val inputAsLSTMInput = DefaultLSTMInput(input.data as NumberNDArrayCore) val weights = inputs[1]!! - val preparedWeights = (contexts.graph!!.getOrNullValue(optName(weights.name)) - ?: LSTMContextRule.prepareWeights(weights)) as KITensor + val preparedWeights = weights.takeIf { isOpt(it.name) } ?: LSTMContextRule.prepareWeights(weights) val weightsAsLSTMWeights = DefaultLSTMWeights(preparedWeights.data as NumberNDArrayCore) val recurrentWeights = inputs[2]!! - val preparedRecurrentWeights = (contexts.graph!!.getOrNullValue(optName(recurrentWeights.name)) - ?: LSTMContextRule.prepareWeights(recurrentWeights)) as KITensor + val preparedRecurrentWeights = recurrentWeights.takeIf { isOpt(it.name) } + ?: LSTMContextRule.prepareWeights(recurrentWeights) val recurrentWeightsAsLSTMWeights = DefaultLSTMWeights(preparedRecurrentWeights.data as NumberNDArrayCore) val bias = inputs.getOrNull(3) - val preparedBias = bias?.let { - contexts.graph!!.getOrNullValue(optName(it.name)) ?: LSTMContextRule.prepareBias(it) - } as KITensor? + val preparedBias = bias?.let { tensor -> + tensor.takeIf { isOpt(it.name) } ?: LSTMContextRule.prepareBias(tensor) + } val peepholes = inputs.getOrNull(7) - val preparedPeepholes = peepholes?.let { - contexts.graph!!.getOrNullValue(optName(it.name)) ?: LSTMContextRule.preparePeepholes(it) - } as KITensor? + val preparedPeepholes = peepholes?.let { tensor -> + tensor.takeIf { isOpt(it.name) } ?: LSTMContextRule.preparePeepholes(tensor) + } val sequenceLens = inputs.getOrNull(4) val initialState = inputs.getOrNull(5) diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/logical/Xor.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/logical/Xor.kt new file mode 100644 index 000000000..c92425e95 --- /dev/null +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/logical/Xor.kt @@ -0,0 +1,58 @@ +package io.kinference.core.operators.logical + +import io.kinference.attribute.Attribute +import io.kinference.core.data.tensor.KITensor +import io.kinference.core.data.tensor.asTensor +import io.kinference.data.ONNXData +import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.BooleanNDArray +import io.kinference.operator.* +import io.kinference.protobuf.message.TensorProto + +sealed class Xor( + name: String, + info: OperatorInfo, + attributes: Map>, + inputs: List, + outputs: List +) : Operator(name, info, attributes, inputs, outputs) { + companion object { + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7) + + operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List): Xor { + return when (version ?: DEFAULT_VERSION.sinceVersion) { + in XorVer7.VERSION.asRange() -> XorVer7(name, attributes, inputs, outputs) + else -> error("Unsupported version of Xor operator: $version") + } + } + } +} + +class XorVer7( + name: String, + attributes: Map>, + inputs: List, + outputs: List +): Xor(name, INFO, attributes, inputs, outputs) { + companion object { + private val TYPE_CONSTRAINTS = setOf(TensorProto.DataType.BOOL) + + private val INPUTS_INFO = listOf( + IOInfo(0, TYPE_CONSTRAINTS, "A", optional = false), + IOInfo(1, TYPE_CONSTRAINTS, "B", optional = false) + ) + + private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "C", optional = false)) + + internal val VERSION = VersionInfo(sinceVersion = 7) + private val INFO = OperatorInfo("Xor", emptySet(), INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val left = inputs[0]!!.data as BooleanNDArray + val right = inputs[1]!!.data as BooleanNDArray + + val ans = left xor right + return listOf(ans.asTensor("C")) + } +} diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/math/MatMulInteger.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/math/MatMulInteger.kt index 6d3858c31..395cbdda8 100644 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/math/MatMulInteger.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/math/MatMulInteger.kt @@ -8,7 +8,7 @@ import io.kinference.data.ONNXData import io.kinference.graph.Contexts import io.kinference.ndarray.arrays.* import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.TensorProto sealed class MatMulInteger(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : Operator(name, info, attributes, inputs, outputs) { @@ -52,10 +52,10 @@ class MatMulIntegerVer10(name: String, attributes: Map>, val firstZero = inputs.getOrNull(2) val secondZero = inputs.getOrNull(3) - val firstPrepared = (contexts.graph!!.getOrNullValue(optName(first.name)) - ?: MatMulIntegerContextRule.prepareTensor(first, firstZero)) as KITensor - val secondPrepared = (contexts.graph!!.getOrNullValue(optName(second.name)) - ?: MatMulIntegerContextRule.prepareTensor(second, secondZero)) as KITensor + val firstPrepared = first.takeIf { isOpt(it.name) } + ?: MatMulIntegerContextRule.prepareTensor(first, firstZero) + val secondPrepared = second.takeIf { isOpt(it.name) } + ?: MatMulIntegerContextRule.prepareTensor(second, secondZero) val output = (firstPrepared.data as NumberNDArrayCore) .matmul(secondPrepared.data as NumberNDArrayCore) diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/quantization/lstm/DynamicQuantizeLSTM.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/quantization/lstm/DynamicQuantizeLSTM.kt index 1f5f169f1..6dbe0acc2 100644 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/quantization/lstm/DynamicQuantizeLSTM.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/quantization/lstm/DynamicQuantizeLSTM.kt @@ -10,7 +10,7 @@ import io.kinference.data.ONNXData import io.kinference.graph.Contexts import io.kinference.ndarray.arrays.* import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto @@ -89,22 +89,21 @@ class DynamicQuantizeLSTMVer1(name: String, attributes: Map + tensor.takeIf { isOpt(it.name) } ?: LSTMContextRule.prepareBias(tensor) + } val peepholes = inputs.getOrNull(7) - val preparedPeepholes = peepholes?.let { - contexts.graph!!.getOrNullValue(optName(it.name)) ?: LSTMContextRule.preparePeepholes(it) - } as KITensor? + val preparedPeepholes = peepholes?.let { tensor -> + tensor.takeIf { isOpt(it.name) } ?: LSTMContextRule.preparePeepholes(tensor) + } val sequenceLens = inputs.getOrNull(4) val initialState = inputs.getOrNull(5) diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/tensor/IsNaN.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/tensor/IsNaN.kt new file mode 100644 index 000000000..6082de1e8 --- /dev/null +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/tensor/IsNaN.kt @@ -0,0 +1,62 @@ +package io.kinference.core.operators.tensor + +import io.kinference.attribute.Attribute +import io.kinference.core.data.tensor.KITensor +import io.kinference.core.data.tensor.asTensor +import io.kinference.data.ONNXData +import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.* +import io.kinference.ndarray.extensions.isNaN.isNaN +import io.kinference.operator.* +import io.kinference.primitives.types.DataType + +sealed class IsNaN(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : + Operator(name, info, attributes, inputs, outputs) { + companion object { + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 9) + + operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List) = + when (version ?: DEFAULT_VERSION.sinceVersion) { + in IsNaNVer9.VERSION.asRange() -> IsNaNVer9(name, attributes, inputs, outputs) + else -> error("Unsupported version of IsNaN operator: $version") + } + } +} + + +class IsNaNVer9( + name: String, + attributes: Map>, + inputs: List, + outputs: List +) : IsNaN(name, INFO, attributes, inputs, outputs) { + companion object { + private val ATTRIBUTES_INFO = emptyList() + + private val INPUTS_INFO = listOf( + IOInfo(0, PRIMITIVE_DATA_TYPES, "X", differentiable = true, optional = false) + ) + + private val OUTPUTS_INFO = listOf( + IOInfo(0, PRIMITIVE_DATA_TYPES, "Y", differentiable = true, optional = false) + ) + + //Realized the latest version, but there is backward compatibility between operators + internal val VERSION = VersionInfo(sinceVersion = 9) + private val INFO = OperatorInfo("IsNaN", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val input = inputs[0]!!.data + + val output = when (input.type) { + DataType.FLOAT -> (input as FloatNDArray).isNaN() + DataType.DOUBLE -> (input as DoubleNDArray).isNaN() + else -> error("Unsupported type") + } + + return listOf(output.asTensor("Y")) + } +} + + diff --git a/inference/inference-core/src/commonTest/kotlin/io/kinference/operators/logical/XorTest.kt b/inference/inference-core/src/commonTest/kotlin/io/kinference/operators/logical/XorTest.kt new file mode 100644 index 000000000..f97bcb67f --- /dev/null +++ b/inference/inference-core/src/commonTest/kotlin/io/kinference/operators/logical/XorTest.kt @@ -0,0 +1,49 @@ +package io.kinference.operators.logical + +import io.kinference.KITestEngine.KIAccuracyRunner +import io.kinference.utils.TestRunner +import kotlin.test.Test + +class XorTest { + private fun getTargetPath(dirName: String) = "xor/$dirName/" + + @Test + fun test_xor_2d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor2d")) + } + + @Test + fun test_xor_3d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor3d")) + } + + @Test + fun test_xor_4d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor4d")) + } + + @Test + fun test_xor_broadcast_3v1d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast3v1d")) + } + + @Test + fun test_xor_broadcast_3v2d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast3v2d")) + } + + @Test + fun test_xor_broadcast_4v2d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v2d")) + } + + @Test + fun test_xor_broadcast_4v3d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v3d")) + } + + @Test + fun test_xor_broadcast_4v4d() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v4d")) + } +} diff --git a/inference/inference-core/src/commonTest/kotlin/io/kinference/operators/operations/IsNaNTest.kt b/inference/inference-core/src/commonTest/kotlin/io/kinference/operators/operations/IsNaNTest.kt new file mode 100644 index 000000000..ad4d2ae54 --- /dev/null +++ b/inference/inference-core/src/commonTest/kotlin/io/kinference/operators/operations/IsNaNTest.kt @@ -0,0 +1,14 @@ +package io.kinference.operators.operations + +import io.kinference.KITestEngine.KIAccuracyRunner +import io.kinference.utils.TestRunner +import kotlin.test.Test + +class IsNaNTest { + private fun getTargetPath(dirName: String) = "isnan/$dirName/" + + @Test + fun test_isnan() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_isnan")) + } +} diff --git a/inference/inference-ir/src/commonMain/kotlin/io/kinference/optimizer/GraphOptimizer.kt b/inference/inference-ir/src/commonMain/kotlin/io/kinference/optimizer/GraphOptimizer.kt index b61512028..3774697ad 100644 --- a/inference/inference-ir/src/commonMain/kotlin/io/kinference/optimizer/GraphOptimizer.kt +++ b/inference/inference-ir/src/commonMain/kotlin/io/kinference/optimizer/GraphOptimizer.kt @@ -93,6 +93,8 @@ class GraphOptimizer>(val graph: Graph) { companion object { private val logger = LoggerFactory.create("io.kinference.optimizer.GraphOptimizer") + fun isOpt(name: String?) = name?.startsWith(OptimizerRule.PREFIX) ?: false + fun optName(name: String?) = if (name!!.startsWith(OptimizerRule.PREFIX)) name else "${OptimizerRule.PREFIX}_${name}" } } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TFJSTensor.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TFJSTensor.kt index 1493949b4..e9b9f925c 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TFJSTensor.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TFJSTensor.kt @@ -2,11 +2,9 @@ package io.kinference.tfjs.data.tensors import io.kinference.data.ONNXTensor import io.kinference.ndarray.arrays.* -import io.kinference.ndarray.extensions.tensor import io.kinference.protobuf.FLOAT_TENSOR_TYPES import io.kinference.protobuf.message.TensorProto import io.kinference.protobuf.message.TensorProto.DataType -import io.kinference.protobuf.toIntArray import io.kinference.tfjs.TFJSBackend import io.kinference.types.ValueInfo import io.kinference.types.ValueTypeInfo @@ -24,44 +22,6 @@ class TFJSTensor(name: String?, override val data: NDArrayTFJS, val info: ValueT data.close() } - /*fun toNDArray(): NDArray { - val shapeIntArray = data.shape.toIntArray() - val strides = Strides(shapeIntArray) - val blockSize = blockSizeByStrides(strides) - val blocksCount = strides.linearSize / blockSize - - - return when(data.dtype) { - "float32" -> { - val array = data.dataFloat().unsafeCast() - val arrayBuffer = array.buffer - val blocks = Array(blocksCount) { blockNum -> - Float32Array(arrayBuffer, blockNum * blockSize * 4, blockSize).unsafeCast() - } - val tiledArray = FloatTiledArray(blocks) - FloatNDArray(tiledArray, Strides(shapeIntArray)) - } - - "int32" -> { - val array = data.dataFloat().unsafeCast() - val arrayBuffer = array.buffer - val blocks = Array(blocksCount) { blockNum -> - Int32Array(arrayBuffer, blockNum * blockSize * 4, blockSize).unsafeCast() - } - val tiledArray = IntTiledArray(blocks) - IntNDArray(tiledArray, strides) - } - - "bool" -> { - val array = data.dataBool() - val tiledArray = BooleanTiledArray(shapeIntArray) { array[it] } - BooleanNDArray(tiledArray, strides) - } - - else -> error("Unsupported type") - } - }*/ - companion object { //TODO: complex, uint32/64 tensors fun create(proto: TensorProto): TFJSTensor { @@ -71,30 +31,21 @@ class TFJSTensor(name: String?, override val data: NDArrayTFJS, val info: ValueT return TFJSTensor(array, type, proto.dims, proto.name) } - /*operator fun invoke(value: NDArray, name: String? = ""): TFJSTensor { - return when (val resolvedType = value.type.resolveProtoDataType()) { - DataType.FLOAT -> invoke((value as FloatNDArray).array.toArray(), resolvedType, value.shape, name) - DataType.INT32 -> invoke((value as IntNDArray).array.toArray(), resolvedType, value.shape, name) - DataType.UINT8 -> invoke((value as UByteNDArray).array.toArray(), resolvedType, value.shape, name) - DataType.INT64 -> invoke((value as LongNDArray).array.toArray(), resolvedType, value.shape, name) - DataType.BOOL -> invoke((value as BooleanNDArray).array.toArray(), resolvedType, value.shape, name) - else -> error("Unsupported type") - } - }*/ - - private fun UByteArray.toIntTypedArray() = Array(this.size) { this[it].toInt() } - private operator fun invoke(value: Any, type: DataType, dims: IntArray, name: String? = ""): TFJSTensor { val nameNotNull = name.orEmpty() val typedDims = dims.toTypedArray() return when (type) { - in FLOAT_TENSOR_TYPES -> NumberNDArrayTFJS(tensor(value as FloatArray, typedDims, "float32")) - DataType.DOUBLE -> NumberNDArrayTFJS(tensor((value as DoubleArray).toTypedArray(), typedDims, "float32")) - DataType.INT32 -> NumberNDArrayTFJS(tensor(value as IntArray, typedDims, "int32")) - DataType.UINT8 -> NumberNDArrayTFJS(tensor((value as UByteArray).toIntTypedArray(), typedDims, "int32")) - DataType.INT8 -> NumberNDArrayTFJS(tensor((value as ByteArray).toTypedArray(), typedDims, "int32")) - DataType.INT64 -> NumberNDArrayTFJS(tensor((value as LongArray).toIntArray(), typedDims, "int32")) - DataType.BOOL -> BooleanNDArrayTFJS(tensor((value as BooleanArray).toTypedArray(), typedDims, "bool")) + in FLOAT_TENSOR_TYPES -> NDArrayTFJS.float(value as FloatArray, typedDims) + DataType.DOUBLE -> NDArrayTFJS.float(value as DoubleArray, typedDims) + DataType.INT8 -> NDArrayTFJS.int(value as ByteArray, typedDims) + DataType.INT16 -> NDArrayTFJS.int(value as ShortArray, typedDims) + DataType.INT32 -> NDArrayTFJS.int(value as IntArray, typedDims) + DataType.INT64 -> NDArrayTFJS.int(value as LongArray, typedDims) + DataType.UINT8 -> NDArrayTFJS.int(value as UByteArray, typedDims) + DataType.UINT16 -> NDArrayTFJS.int(value as UShortArray, typedDims) + DataType.UINT32 -> NDArrayTFJS.int(value as UIntArray, typedDims) + DataType.UINT64 -> NDArrayTFJS.int(value as ULongArray, typedDims) + DataType.BOOL -> NDArrayTFJS.boolean(value as BooleanArray, typedDims) else -> error("Unsupported type: $type") }.asTensor(nameNotNull) } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TensorExtensions.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TensorExtensions.kt index d2b08a73b..0afb42404 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TensorExtensions.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/data/tensors/TensorExtensions.kt @@ -1,8 +1,7 @@ package io.kinference.tfjs.data.tensors -import io.kinference.ndarray.arrays.ArrayTFJS import io.kinference.ndarray.arrays.NDArrayTFJS -import io.kinference.ndarray.extensions.toNDArray +//import io.kinference.ndarray.extensions.toNDArray import io.kinference.protobuf.message.TensorProto import io.kinference.protobuf.resolveProtoDataType import io.kinference.types.* @@ -10,7 +9,6 @@ import io.kinference.types.* fun T.asTensor(name: String? = null) = TFJSTensor(this, ValueInfo(ValueTypeInfo.TensorTypeInfo(TensorShape(shape), type.resolveProtoDataType()), name ?: "")) -fun ArrayTFJS.asTensor(name: String? = null) = this.toNDArray().asTensor(name) fun String.tfTypeResolve(): TensorProto.DataType { return when (this) { diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt index 4eb1d0915..b60c2f2df 100755 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/TFJSOperatorFactory.kt @@ -78,9 +78,10 @@ object TFJSOperatorFactory : OperatorFactory> { "GatherND" -> GatherND(name, version, attributes, inputs, outputs) "Gemm" -> Gemm(name, version, attributes, inputs, outputs) "Greater" -> Greater(name, version, attributes, inputs, outputs) + "GRU" -> GRU(name, version, attributes, inputs, outputs) "Hardmax" -> Hardmax(name, version, attributes, inputs, outputs) "IsInf" -> IsInf(name, version, attributes, inputs, outputs) - "GRU" -> GRU(name, version, attributes, inputs, outputs) + "IsNaN" -> IsNaN(name, version, attributes, inputs, outputs) "Identity" -> Identity(name, version, attributes, inputs, outputs) "If" -> If(name, version, attributes, inputs, outputs) "LayerNormalization" -> LayerNormalization(name, version, attributes, inputs, outputs) @@ -121,6 +122,7 @@ object TFJSOperatorFactory : OperatorFactory> { "TreeEnsembleRegressor" -> TreeEnsembleRegressor(name, version, attributes, inputs, outputs) "Unsqueeze" -> Unsqueeze(name, version, attributes, inputs, outputs) "Where" -> Where(name, version, attributes, inputs, outputs) + "Xor" -> Xor(name, version, attributes, inputs, outputs) "ZipMap" -> ZipMap(name, version, attributes, inputs, outputs) else -> error("Unsupported operator: $opType") } as Operator, TFJSData<*>> diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/flow/Loop.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/flow/Loop.kt index aa784fa2d..83d893de7 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/flow/Loop.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/flow/Loop.kt @@ -89,7 +89,7 @@ class LoopVer1(name: String, attributes: Map>, inputs: Li require(body.inputs.size == inputs.size) { "Not enough inputs for Loop subgraph\nPresent: ${inputs.size}, Expected: ${body.inputs.size}" } val modified = inputs.drop(2).requireNoNulls().map { - (it.data.clone() as NDArrayTFJS).asTensor(it.name) + it.data.clone().asTensor(it.name) }.toMutableList() val modifiedCount = modified.size diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/gru/GRU.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/gru/GRU.kt index 2f2ac5f5c..2cf2bad02 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/gru/GRU.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/gru/GRU.kt @@ -6,7 +6,7 @@ import io.kinference.graph.Contexts import io.kinference.ndarray.arrays.NumberNDArrayTFJS import io.kinference.ndarray.extensions.* import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto import io.kinference.tfjs.data.tensors.TFJSTensor @@ -94,16 +94,15 @@ class GRUVer7(name: String, attributes: Map>, inputs: Lis val input = inputs[0]!!.data val weights = inputs[1]!! - val preparedWeights = (contexts.graph!!.getOrNullValue(optName(weights.name)) - ?: GRUContextRule.prepareWeights(weights)) + val preparedWeights = weights.takeIf { isOpt(it.name) } ?: GRUContextRule.prepareWeights(weights) val recurrentWeights = inputs[2]!! - val preparedRecurrentWeights = (contexts.graph!!.getOrNullValue(optName(recurrentWeights.name)) - ?: GRUContextRule.prepareWeights(recurrentWeights)) + val preparedRecurrentWeights = recurrentWeights.takeIf { isOpt(it.name) } + ?: GRUContextRule.prepareWeights(recurrentWeights) val bias = inputs.getOrNull(3) - val preparedBias = bias?.let { - contexts.graph!!.getOrNullValue(optName(it.name)) ?: GRUContextRule.prepareBias(it) + val preparedBias = bias?.let { tensor -> + tensor.takeIf { isOpt(it.name) } ?: GRUContextRule.prepareBias(tensor) } val sequenceLens = inputs.getOrNull(4)?.data diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/lstm/LSTM.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/lstm/LSTM.kt index a0eaa6a6f..71169ea4f 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/lstm/LSTM.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/layer/recurrent/lstm/LSTM.kt @@ -6,7 +6,7 @@ import io.kinference.graph.Contexts import io.kinference.ndarray.arrays.NumberNDArrayTFJS import io.kinference.ndarray.extensions.tidyNDArrays import io.kinference.operator.* -import io.kinference.optimizer.GraphOptimizer.Companion.optName +import io.kinference.optimizer.GraphOptimizer.Companion.isOpt import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto import io.kinference.tfjs.data.tensors.TFJSTensor @@ -104,22 +104,21 @@ class LSTMVer7( val input = inputs[0]!!.data as NumberNDArrayTFJS val weights = inputs[1]!! - val preparedWeights = (contexts.graph!!.getOrNullValue(optName(weights.name)) - ?: LSTMContextRule.prepareWeights(weights)) as TFJSTensor + val preparedWeights = weights.takeIf { isOpt(it.name) } ?: LSTMContextRule.prepareWeights(weights) val recurrentWeights = inputs[2]!! - val preparedRecurrentWeights = (contexts.graph!!.getOrNullValue(optName(recurrentWeights.name)) - ?: LSTMContextRule.prepareWeights(recurrentWeights)) as TFJSTensor + val preparedRecurrentWeights = recurrentWeights.takeIf { isOpt(it.name) } + ?: LSTMContextRule.prepareWeights(recurrentWeights) val bias = inputs.getOrNull(3) - val preparedBias = bias?.let { - contexts.graph!!.getOrNullValue(optName(it.name)) ?: LSTMContextRule.prepareBias(it) - } as TFJSTensor? + val preparedBias = bias?.let { tensor -> + tensor.takeIf { isOpt(it.name) } ?: LSTMContextRule.prepareBias(tensor) + } val peepholes = inputs.getOrNull(7) - val preparedPeepholes = peepholes?.let { - contexts.graph!!.getOrNullValue(optName(it.name)) ?: LSTMContextRule.preparePeepholes(it) - } as TFJSTensor? + val preparedPeepholes = peepholes?.let { tensor -> + tensor.takeIf { isOpt(it.name) } ?: LSTMContextRule.preparePeepholes(tensor) + } val sequenceLens = inputs.getOrNull(4) val initialState = inputs.getOrNull(5) diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/logical/Xor.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/logical/Xor.kt new file mode 100644 index 000000000..0f980744e --- /dev/null +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/logical/Xor.kt @@ -0,0 +1,58 @@ +package io.kinference.tfjs.operators.logical + +import io.kinference.attribute.Attribute +import io.kinference.data.ONNXData +import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.BooleanNDArrayTFJS +import io.kinference.operator.* +import io.kinference.protobuf.message.TensorProto +import io.kinference.tfjs.data.tensors.TFJSTensor +import io.kinference.tfjs.data.tensors.asTensor + +sealed class Xor( + name: String, + info: OperatorInfo, + attributes: Map>, + inputs: List, + outputs: List +) : Operator(name, info, attributes, inputs, outputs) { + companion object { + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 7) + + operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List): Xor { + return when (version ?: DEFAULT_VERSION.sinceVersion) { + in XorVer7.VERSION.asRange() -> XorVer7(name, attributes, inputs, outputs) + else -> error("Unsupported version of Xor operator: $version") + } + } + } +} + +class XorVer7( + name: String, + attributes: Map>, + inputs: List, + outputs: List +): Xor(name, INFO, attributes, inputs, outputs) { + companion object { + private val TYPE_CONSTRAINTS = setOf(TensorProto.DataType.BOOL) + + private val INPUTS_INFO = listOf( + IOInfo(0, TYPE_CONSTRAINTS, "A", optional = false), + IOInfo(1, TYPE_CONSTRAINTS, "B", optional = false) + ) + + private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "C", optional = false)) + + internal val VERSION = VersionInfo(sinceVersion = 7) + private val INFO = OperatorInfo("Xor", emptySet(), INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val left = inputs[0]!!.data as BooleanNDArrayTFJS + val right = inputs[1]!!.data as BooleanNDArrayTFJS + + val result = left xor right + return listOf(result.asTensor("C")) + } +} diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/math/Gemm.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/math/Gemm.kt index 80bb867d8..6d13562f9 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/math/Gemm.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/math/Gemm.kt @@ -36,8 +36,8 @@ class GemmVer11( outputs: List ) : Gemm(name, INFO, attributes, inputs, outputs) { - private val alpha: Double by attribute { it: Number -> it.toDouble() } - private val beta: Double by attribute { it: Number -> it.toDouble() } + private val alpha: Float by attribute { it: Number -> it.toFloat() } + private val beta: Float by attribute { it: Number -> it.toFloat() } private val transA: Boolean by attribute { it: Number -> it.toInt() != 0 } private val transB: Boolean by attribute { it: Number -> it.toInt() != 0 } @@ -79,13 +79,13 @@ class GemmVer11( val c = inputs.getOrNull(2)?.data as? NumberNDArrayTFJS val result = tidyNDArray { - val alphaScalar = NumberNDArrayTFJS(tensor(arrayOf(alpha), emptyArray(), a.dtype)) + val alphaScalar = NDArrayTFJS.floatScalar(alpha) val matMulResult = alphaScalar * a.matmul(b, transposeLeft = transA, transposeRight = transB) - if (c == null) { + return@tidyNDArray if (c == null) { matMulResult } else { - val betaScalar = NumberNDArrayTFJS(tensor(arrayOf(beta), emptyArray(), c.dtype)) + val betaScalar = NDArrayTFJS.floatScalar(beta) matMulResult + betaScalar * c } } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/TreeEnsembleClassifier.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/TreeEnsembleClassifier.kt index d656cb7ce..4e982a863 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/TreeEnsembleClassifier.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/TreeEnsembleClassifier.kt @@ -127,7 +127,7 @@ class TreeEnsembleClassifierVer1( private suspend fun labeledTopClasses(array: NumberNDArrayTFJS): NDArrayTFJS { val shape = arrayOf(array.shape[0]) - val labelsIndices = array.argmax(axis = -1).tfjsArray.dataInt() + val labelsIndices = array.argmax(axis = -1).dataInt() return writeLabels(labels.labelsDataType, shape) { labels.labels[labelsIndices[it]]!! } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/trees/TFJSTreeEnsemble.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/trees/TFJSTreeEnsemble.kt index 4a993884e..d7900b3d7 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/trees/TFJSTreeEnsemble.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/ml/trees/TFJSTreeEnsemble.kt @@ -4,7 +4,6 @@ import io.kinference.ndarray.arrays.* import io.kinference.ndarray.extensions.dataFloat import io.kinference.primitives.types.DataType import io.kinference.trees.* -import io.kinference.utils.LoggerFactory class TFJSTreeEnsemble( aggregator: Aggregator, @@ -25,7 +24,7 @@ class TFJSTreeEnsemble( override suspend fun execute(input: NumberNDArrayTFJS): NumberNDArrayTFJS { require(input.type == DataType.DOUBLE || input.type == DataType.FLOAT) { "Integer inputs are not supported yet" } - val inputArray = input.tfjsArray.dataFloat() + val inputArray = input.dataFloat() val n = if (input.rank == 1) 1 else input.shape[0] val outputShape = if (numTargets == 1) arrayOf(n) else arrayOf(n, numTargets) val outputArray = FloatArray(n * numTargets) diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/seq/SplitToSequence.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/seq/SplitToSequence.kt index 115cd0740..1776eea27 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/seq/SplitToSequence.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/seq/SplitToSequence.kt @@ -80,7 +80,7 @@ class SplitToSequenceVer11( segments.map { it.reshape(newShape) } } } else { - val partsArray = parts.tfjsArray.dataInt() + val partsArray = parts.dataInt() if (parts.isScalar()) input.split(partsArray[0], axis) else input.split(partsArray, axis) }.toTypedArray() }.map { it.asTensor() } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/ConstantOfShape.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/ConstantOfShape.kt index 6cc8c9898..e83468be4 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/ConstantOfShape.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/ConstantOfShape.kt @@ -3,9 +3,10 @@ package io.kinference.tfjs.operators.tensor import io.kinference.attribute.Attribute import io.kinference.data.ONNXData import io.kinference.graph.Contexts -import io.kinference.ndarray.arrays.tensor +import io.kinference.ndarray.arrays.NDArrayTFJS import io.kinference.ndarray.extensions.* import io.kinference.operator.* +import io.kinference.primitives.types.DataType import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto import io.kinference.tfjs.data.tensors.TFJSTensor @@ -30,8 +31,12 @@ class ConstantOfShapeVer9(name: String, attributes: Map>, private val TYPE_CONSTRAINTS = PRIMITIVE_DATA_TYPES private val ATTRIBUTES_INFO = listOf( - AttributeInfo("value", setOf(AttributeProto.AttributeType.TENSOR), - default = tensor(floatArrayOf(0f), arrayOf(1), "float32").asTensor("value"), required = false) + AttributeInfo( + name = "value", + types = setOf(AttributeProto.AttributeType.TENSOR), + default = NDArrayTFJS.float(floatArrayOf(0f), arrayOf(1)).asTensor("value"), + required = false + ) ) private val INPUTS_INFO = listOf(IOInfo(0, setOf(TensorProto.DataType.INT64), "input", optional = false)) @@ -40,6 +45,16 @@ class ConstantOfShapeVer9(name: String, attributes: Map>, internal val VERSION = VersionInfo(sinceVersion = 9) private val INFO = OperatorInfo("ConstantOfShape", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + + private fun empty(type: DataType, shape: Array): NDArrayTFJS { + return when (type) { + DataType.FLOAT -> NDArrayTFJS.float(floatArrayOf(), shape) + DataType.INT -> NDArrayTFJS.int(intArrayOf(), shape) + DataType.BOOLEAN -> NDArrayTFJS.boolean(booleanArrayOf(), shape) + DataType.ALL -> NDArrayTFJS.string(emptyArray(), shape) + else -> error("") + } + } } private val value: TFJSTensor by attribute() @@ -48,7 +63,7 @@ class ConstantOfShapeVer9(name: String, attributes: Map>, val output = tidyNDArray { val shape = inputs[0]!!.data.dataInt().toTypedArray() if (shape.contains(0)) { - return@tidyNDArray tensor(emptyArray(), shape, value.data.dtype).toNDArray() + return@tidyNDArray empty(value.data.type, shape) } return@tidyNDArray value.data.broadcastTo(shape) diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Gather.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Gather.kt index afc13903c..8ffed8b63 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Gather.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Gather.kt @@ -3,6 +3,7 @@ package io.kinference.tfjs.operators.tensor import io.kinference.attribute.Attribute import io.kinference.data.ONNXData import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.NDArrayTFJS import io.kinference.ndarray.arrays.indexAxis import io.kinference.ndarray.extensions.* import io.kinference.operator.* @@ -59,7 +60,7 @@ class GatherVer1(name: String, attributes: Map>, inputs: val value = indicesData[idx] if (value < 0) indicesData[idx] = value + dim } - val preparedIndices = tensor(indicesData, indices.shapeArray, indices.dtype).toNDArray() + val preparedIndices = NDArrayTFJS.int(indicesData, indices.shapeArray) return@tidyNDArray data.gather(preparedIndices, actualAxis) } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/IsNaN.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/IsNaN.kt new file mode 100644 index 000000000..318a0edc7 --- /dev/null +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/IsNaN.kt @@ -0,0 +1,47 @@ +package io.kinference.tfjs.operators.tensor + +import io.kinference.attribute.Attribute +import io.kinference.data.ONNXData +import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.NumberNDArrayTFJS +import io.kinference.ndarray.extensions.isNaN +import io.kinference.operator.* +import io.kinference.tfjs.data.tensors.TFJSTensor +import io.kinference.tfjs.data.tensors.asTensor + +sealed class IsNaN(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : + Operator(name, info, attributes, inputs, outputs) { + companion object { + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 9) + + operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List) = when (version ?: DEFAULT_VERSION.sinceVersion) { + in IsNaNVer9.VERSION.asRange() -> IsNaNVer9(name, attributes, inputs, outputs) + else -> error("Unsupported version of IsNaN operator: $version") + } + } +} + +class IsNaNVer9(name: String, attributes: Map>, inputs: List, outputs: List) : + Abs(name, INFO, attributes, inputs, outputs) { + + companion object { + private val ATTRIBUTES_INFO = emptyList() + + private val INPUTS_INFO = listOf( + IOInfo(0, PRIMITIVE_DATA_TYPES, "X", differentiable = true, optional = false) + ) + + private val OUTPUTS_INFO = listOf( + IOInfo(0, PRIMITIVE_DATA_TYPES, "Y", differentiable = true, optional = false) + ) + + //Realized the latest version, but there is backward compatibility between operators + internal val VERSION = VersionInfo(sinceVersion = 9) + private val INFO = OperatorInfo("IsNaN", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val input = inputs[0]!!.data as NumberNDArrayTFJS + return listOf(input.isNaN().asTensor("Y")) + } +} diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Range.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Range.kt index 001289a76..e0b5c682a 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Range.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Range.kt @@ -3,12 +3,14 @@ package io.kinference.tfjs.operators.tensor import io.kinference.attribute.Attribute import io.kinference.data.ONNXData import io.kinference.graph.Contexts -import io.kinference.ndarray.arrays.range +import io.kinference.ndarray.arrays.NDArrayTFJS +import io.kinference.ndarray.arrays.NumberNDArrayTFJS import io.kinference.ndarray.extensions.* import io.kinference.operator.* -import io.kinference.protobuf.message.TensorProto.DataType +import io.kinference.protobuf.message.TensorProto import io.kinference.tfjs.data.tensors.TFJSTensor import io.kinference.tfjs.data.tensors.asTensor +import io.kinference.primitives.types.DataType sealed class Range(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : Operator(name, info, attributes, inputs, outputs) { @@ -26,7 +28,13 @@ sealed class Range(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : Range(name, INFO, attributes, inputs, outputs) { companion object { - private val TYPE_CONSTRAINTS = setOf(DataType.FLOAT, DataType.DOUBLE, DataType.INT16, DataType.INT32, DataType.INT64) + private val TYPE_CONSTRAINTS = setOf( + TensorProto.DataType.FLOAT, + TensorProto.DataType.DOUBLE, + TensorProto.DataType.INT16, + TensorProto.DataType.INT32, + TensorProto.DataType.INT64 + ) private val ATTRIBUTES_INFO = emptyList() @@ -43,21 +51,26 @@ class RangeVer11(name: String, attributes: Map>, inputs: } override suspend fun > apply(contexts: Contexts, inputs: List): List { - val outputs = tidy { - val start = inputs[0]!!.data - val limit = inputs[1]!!.data - val delta = inputs[2]!!.data + val outputs = tidyNDArray { + val start = inputs[0]!!.data as NumberNDArrayTFJS + val limit = inputs[1]!!.data as NumberNDArrayTFJS + val delta = inputs[2]!!.data as NumberNDArrayTFJS + + require(start.type == limit.type && limit.type == delta.type) + { "Input tensors must have equal dtype, present: start: ${start.type}, limit: ${limit.type}, delta: ${delta.type}" } - require(start.dtype == limit.dtype && limit.dtype == delta.dtype) - { "Input tensors must have equal dtype, present: start: ${start.dtype}, limit: ${limit.dtype}, delta: ${delta.dtype}" } - val startNumber = if (start.dtype == "float32") start.dataFloat().first() else start.dataInt().first() - val limitNumber = if (limit.dtype == "float32") limit.dataFloat().first() else limit.dataInt().first() - val deltaNumber = if (delta.dtype == "float32") delta.dataFloat().first() else delta.dataInt().first() + val startNumber = start.singleValue() + val limitNumber = limit.singleValue() + val deltaNumber = delta.singleValue() - return@tidy arrayOf(range(startNumber, limitNumber, deltaNumber, start.dtype)) + return@tidyNDArray when (start.type) { + DataType.FLOAT -> NDArrayTFJS.floatRange(startNumber, limitNumber, deltaNumber) + DataType.INT -> NDArrayTFJS.intRange(startNumber, limitNumber, deltaNumber) + else -> error("Unsupported data type") + } } - return listOf(outputs[0].asTensor("output")) + return listOf(outputs.asTensor("output")) } } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/utils/ProfileInterfaces.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/utils/ProfileInterfaces.kt deleted file mode 100644 index 7d776f11d..000000000 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/utils/ProfileInterfaces.kt +++ /dev/null @@ -1,39 +0,0 @@ -package io.kinference.tfjs.utils - -import io.kinference.ndarray.arrays.ArrayTFJS -import kotlin.js.Promise - -interface ProfileInfo { - val newBytes: Int - val newTensors: Int - val peakBytes: Int - val kernels: Array - val result: Array - val kernelNames: Array -} - -interface KernelInfo { - val name: String - val bytesAdded: Int - val totalBytesSnapshot: Int - val tensorsAdded: Int - val totalTensorsSnapshot: Int - val inputShapes: Array> - val outputShapes: Array> - val kernelTimeMs: Int - val extraInfo: Promise -} - -interface TimingInfo { - val kernelMs: Int - val wallMs: Int - fun getExtraProfileInfo(): String? -} - -interface MemoryInfo { - val numTensors: Int - val numDataBuffers: Int - val numBytes: Int - val unreliable: Boolean? - val reasons: Array -} diff --git a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/logical/XorTest.kt b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/logical/XorTest.kt new file mode 100644 index 000000000..8bbbe59fd --- /dev/null +++ b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/logical/XorTest.kt @@ -0,0 +1,49 @@ +package io.kinference.tfjs.operators.logical + +import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner +import io.kinference.utils.TestRunner +import kotlin.test.Test + +class XorTest { + private fun getTargetPath(dirName: String) = "xor/$dirName/" + + @Test + fun test_xor_2d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor2d")) + } + + @Test + fun test_xor_3d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor3d")) + } + + @Test + fun test_xor_4d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor4d")) + } + + @Test + fun test_xor_broadcast_3v1d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast3v1d")) + } + + @Test + fun test_xor_broadcast_3v2d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast3v2d")) + } + + @Test + fun test_xor_broadcast_4v2d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v2d")) + } + + @Test + fun test_xor_broadcast_4v3d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v3d")) + } + + @Test + fun test_xor_broadcast_4v4d() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_xor_bcast4v4d")) + } +} diff --git a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/IsNaNTest.kt b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/IsNaNTest.kt new file mode 100644 index 000000000..f4805da53 --- /dev/null +++ b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/IsNaNTest.kt @@ -0,0 +1,14 @@ +package io.kinference.tfjs.operators.tensor + +import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner +import io.kinference.utils.TestRunner +import kotlin.test.Test + +class IsNaNTest { + private fun getTargetPath(dirName: String) = "isnan/$dirName/" + + @Test + fun test_isnan() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("test_isnan")) + } +} diff --git a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt index d3b4c8b60..680552305 100644 --- a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt +++ b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/utils/TFJSAssertions.kt @@ -3,6 +3,7 @@ package io.kinference.tfjs.utils import io.kinference.TestLoggerFactory import io.kinference.data.ONNXDataType import io.kinference.ndarray.extensions.* +import io.kinference.primitives.types.DataType import io.kinference.tfjs.TFJSData import io.kinference.tfjs.data.map.TFJSMap import io.kinference.tfjs.data.seq.TFJSSequence @@ -13,34 +14,35 @@ import kotlin.test.assertEquals object TFJSAssertions { val logger = TestLoggerFactory.create("Assertions") - @OptIn(ExperimentalUnsignedTypes::class) fun assertEquals(expected: TFJSTensor, actual: TFJSTensor, delta: Double) { - assertEquals(expected.data.dtype, actual.data.dtype, "Types of tensors ${expected.name} do not match") + assertEquals(expected.data.type, actual.data.type, "Types of tensors ${expected.name} do not match") ArrayAssertions.assertArrayEquals(expected.data.shapeArray, actual.data.shapeArray, "Shapes are incorrect") logger.info { "Errors for ${expected.name}:" } - when(expected.data.dtype) { - "float32" -> { + when(expected.data.type) { + DataType.FLOAT -> { val expectedArray = expected.data.dataFloat() val actualArray = actual.data.dataFloat() ArrayAssertions.assertEquals(expectedArray, actualArray, delta, expected.name.orEmpty()) } - "int32" -> { + DataType.INT -> { val expectedArray = expected.data.dataInt() val actualArray = actual.data.dataInt() ArrayAssertions.assertEquals(expectedArray, actualArray, delta, expected.name.orEmpty()) } - "bool" -> { + DataType.BOOLEAN -> { val expectedArray = expected.data.dataBool() val actualArray = actual.data.dataBool() ArrayAssertions.assertArrayEquals(expectedArray, actualArray, "Tensor ${expected.name} does not match") } + + else -> error("Unsupported data type of ${expected.name} tensor") } } diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt index a6bbbb3c5..3fa138c67 100644 --- a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt @@ -5,7 +5,7 @@ import io.kinference.ndarray.arrays.tiled.BooleanTiledArray import io.kinference.ndarray.arrays.tiled.LongTiledArray import io.kinference.ndarray.blockSizeByStrides import io.kinference.ndarray.broadcasting.Broadcasting -import io.kinference.ndarray.extensions.applyWithBroadcast +import io.kinference.ndarray.extensions.broadcasting.broadcastTwoTensorsBoolean import io.kinference.ndarray.extensions.isTransposeReshape import io.kinference.primitives.types.DataType import kotlin.jvm.JvmName @@ -139,66 +139,29 @@ open class BooleanNDArray(var array: BooleanTiledArray, strides: Strides) : NDAr } } - private fun orScalar(array: BooleanTiledArray, scalar: Boolean, destination: BooleanTiledArray) { - require(array.blocksNum == destination.blocksNum && array.blockSize == destination.blockSize) - - val arrayPointer = array.pointer() - val destPointer = destination.pointer() - - arrayPointer.mapTo(destPointer, destination.size) { it || scalar } - } - suspend fun or(other: BooleanNDArray, destination: MutableBooleanNDArray): BooleanNDArray { - when { - this.isScalar() && other.isScalar() -> destination.array.blocks[0][0] = this.array.blocks[0][0] or other.array.blocks[0][0] - this.isScalar() -> orScalar(other.array, this.array.blocks[0][0], destination.array) - other.isScalar() -> orScalar(this.array, other.array.blocks[0][0], destination.array) - else -> this.applyWithBroadcast(other, destination) { left, right, dest -> - left as BooleanNDArray; right as BooleanNDArray; dest as MutableBooleanNDArray - - val leftPointer = left.array.pointer() - val rightPointer = right.array.pointer() - val destPointer = dest.array.pointer() - - destPointer.acceptDouble(leftPointer, rightPointer, dest.array.size) { _, a, b -> a || b } - } + return broadcastTwoTensorsBoolean(this, other, destination) { + left: Boolean, right: Boolean -> left || right } - - return destination } suspend infix fun or(other: BooleanNDArray) = or(other, MutableBooleanNDArray(broadcastShape(listOf(this.shape, other.shape)))) - private fun andScalar(array: BooleanTiledArray, scalar: Boolean, destination: BooleanTiledArray) { - require(array.blocksNum == destination.blocksNum && array.blockSize == destination.blockSize) - - val arrayPointer = array.pointer() - val destPointer = destination.pointer() - - arrayPointer.mapTo(destPointer, destination.size) { it and scalar } - } - suspend fun and(other: BooleanNDArray, destination: MutableBooleanNDArray): BooleanNDArray { - when { - this.isScalar() && other.isScalar() -> destination.array.blocks[0][0] = this.array.blocks[0][0] and other.array.blocks[0][0] - this.isScalar() -> andScalar(other.array, this.array.blocks[0][0], destination.array) - other.isScalar() -> andScalar(this.array, other.array.blocks[0][0], destination.array) - else -> this.applyWithBroadcast(other, destination) { left, right, dest -> - left as BooleanNDArray; right as BooleanNDArray; dest as MutableBooleanNDArray - - val leftPointer = left.array.pointer() - val rightPointer = right.array.pointer() - val destPointer = dest.array.pointer() - - destPointer.acceptDouble(leftPointer, rightPointer, dest.array.size) { _, a, b -> a and b } - } + return broadcastTwoTensorsBoolean(this, other, destination) { + left: Boolean, right: Boolean -> left && right } - - return destination } suspend infix fun and(other: BooleanNDArray) = and(other, MutableBooleanNDArray(broadcastShape(listOf(this.shape, other.shape)))) + suspend fun xor(other: BooleanNDArray, destination: MutableBooleanNDArray): BooleanNDArray { + return broadcastTwoTensorsBoolean(this, other, destination) { + left: Boolean, right: Boolean -> left xor right + } + } + + suspend infix fun xor(other: BooleanNDArray) = xor(other, MutableBooleanNDArray(broadcastShape(listOf(this.shape, other.shape)))) override suspend fun concat(others: List, axis: Int): MutableBooleanNDArray { val actualAxis = indexAxis(axis) diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/extensions/isNaN/IsNaNUtils.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/extensions/isNaN/IsNaNUtils.kt new file mode 100644 index 000000000..2757f0bfd --- /dev/null +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/extensions/isNaN/IsNaNUtils.kt @@ -0,0 +1,5 @@ +package io.kinference.ndarray.extensions.isNaN + +import io.kinference.primitives.types.PrimitiveType + +fun PrimitiveType.isNaN(): Boolean = throw UnsupportedOperationException() diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/extensions/isNaN/PrimitiveIsNaN.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/extensions/isNaN/PrimitiveIsNaN.kt new file mode 100644 index 000000000..df137e0b0 --- /dev/null +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/extensions/isNaN/PrimitiveIsNaN.kt @@ -0,0 +1,29 @@ +@file:GeneratePrimitives( + DataType.DOUBLE, + DataType.FLOAT +) + +package io.kinference.ndarray.extensions.isNaN + +import io.kinference.ndarray.arrays.BooleanNDArray +import io.kinference.ndarray.arrays.PrimitiveNDArray +import io.kinference.primitives.annotations.GeneratePrimitives +import io.kinference.primitives.types.DataType + +fun PrimitiveNDArray.isNaN(): BooleanNDArray { + val output = BooleanNDArray(strides) + + val inputBlockIter = this.array.blocks.iterator() + val outputBlockIter = output.array.blocks.iterator() + + for (blockIdx in 0 until this.array.blocksNum) { + val inputBlock = inputBlockIter.next() + val outputBlock = outputBlockIter.next() + + for (idx in outputBlock.indices) { + outputBlock[idx] = inputBlock[idx].isNaN() + } + } + + return output +} diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/FunctionInterfaces.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/FunctionInterfaces.kt index dc993159b..32696e14a 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/FunctionInterfaces.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/FunctionInterfaces.kt @@ -4,7 +4,7 @@ import io.kinference.ndarray.arrays.ArrayTFJS import io.kinference.ndarray.arrays.NumberNDArrayTFJS import io.kinference.utils.Closeable -external interface MomentsOutputTFJS { +internal external interface MomentsOutputTFJS { val mean: ArrayTFJS val variance: ArrayTFJS } @@ -20,9 +20,9 @@ data class MomentsOutput( } } -fun MomentsOutputTFJS.toNDArray() = MomentsOutput(NumberNDArrayTFJS(mean), NumberNDArrayTFJS(variance)) +internal fun MomentsOutputTFJS.toNDArray() = MomentsOutput(NumberNDArrayTFJS(mean), NumberNDArrayTFJS(variance)) -data class QrDecompositionResultTFJS( +internal data class QrDecompositionResultTFJS( val q: ArrayTFJS, val r: ArrayTFJS ) @@ -37,4 +37,4 @@ data class QrDecompositionResult( } } -fun QrDecompositionResultTFJS.toNDArray() = QrDecompositionResult(NumberNDArrayTFJS(q), NumberNDArrayTFJS(r)) +internal fun QrDecompositionResultTFJS.toNDArray() = QrDecompositionResult(NumberNDArrayTFJS(q), NumberNDArrayTFJS(r)) diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/Utils.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/Utils.kt index de7056a69..7de08673a 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/Utils.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/Utils.kt @@ -30,6 +30,7 @@ fun String.resolveTFJSDataType(): DataType { "float32" -> DataType.FLOAT "int32" -> DataType.INT "bool" -> DataType.BOOLEAN + "string" -> DataType.ALL else -> error("Unsupported type: $this") } } @@ -38,7 +39,7 @@ inline fun T.applyIf(predicate: Boolean, func: (T) -> (T)): T { return if (predicate) func(this) else this } -fun makeNDArray(tfjsArray: ArrayTFJS, type: DataType): NDArrayTFJS { +internal fun makeNDArray(tfjsArray: ArrayTFJS, type: DataType): NDArrayTFJS { return when (type) { DataType.FLOAT, DataType.INT -> MutableNumberNDArrayTFJS(tfjsArray) DataType.BOOLEAN -> MutableBooleanNDArrayTFJS(tfjsArray) @@ -46,7 +47,7 @@ fun makeNDArray(tfjsArray: ArrayTFJS, type: DataType): NDArrayTFJS { } } -fun makeNDArray(tfjsArray: ArrayTFJS, type: String) = makeNDArray(tfjsArray, type.resolveTFJSDataType()) +internal fun makeNDArray(tfjsArray: ArrayTFJS, type: String) = makeNDArray(tfjsArray, type.resolveTFJSDataType()) internal fun activateCpuBackend() { versionCpu.length diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/BooleanNDArrayTFJS.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/BooleanNDArrayTFJS.kt index b247ce5ce..822d13cd1 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/BooleanNDArrayTFJS.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/BooleanNDArrayTFJS.kt @@ -3,7 +3,7 @@ package io.kinference.ndarray.arrays import io.kinference.ndarray.extensions.* import io.kinference.primitives.types.DataType -open class BooleanNDArrayTFJS(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray) { +open class BooleanNDArrayTFJS internal constructor(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray) { override val type: DataType = DataType.BOOLEAN override fun get(index: IntArray): Boolean { @@ -69,9 +69,13 @@ open class BooleanNDArrayTFJS(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray) { infix fun and(other: BooleanNDArrayTFJS): BooleanNDArrayTFJS { return BooleanNDArrayTFJS(tfjsArray.and(other.tfjsArray)) } + + infix fun xor(other: BooleanNDArrayTFJS): BooleanNDArrayTFJS { + return BooleanNDArrayTFJS(tfjsArray.xor(other.tfjsArray)) + } } -class MutableBooleanNDArrayTFJS(tfjsArray: ArrayTFJS) : BooleanNDArrayTFJS(tfjsArray), MutableNDArray { +class MutableBooleanNDArrayTFJS internal constructor(tfjsArray: ArrayTFJS) : BooleanNDArrayTFJS(tfjsArray), MutableNDArray { override fun clone(): MutableBooleanNDArrayTFJS { return MutableBooleanNDArrayTFJS(tfjsArray.clone()) } diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/MutableNumberNDArrayTFJS.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/MutableNumberNDArrayTFJS.kt index e4e9382f2..59f1ebe2a 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/MutableNumberNDArrayTFJS.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/MutableNumberNDArrayTFJS.kt @@ -2,7 +2,7 @@ package io.kinference.ndarray.arrays import io.kinference.ndarray.extensions.* -class MutableNumberNDArrayTFJS(tfjsArray: ArrayTFJS) : NumberNDArrayTFJS(tfjsArray), MutableNumberNDArray { +class MutableNumberNDArrayTFJS internal constructor(tfjsArray: ArrayTFJS) : NumberNDArrayTFJS(tfjsArray), MutableNumberNDArray { override fun set(index: IntArray, value: Any) { require(value is Number) tfjsArray.bufferSync().set(value, *index) diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NDArrayTFJS.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NDArrayTFJS.kt index 07c2000e3..e256cd06e 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NDArrayTFJS.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NDArrayTFJS.kt @@ -6,7 +6,7 @@ import io.kinference.ndarray.extensions.* import io.kinference.primitives.types.DataType import kotlinx.coroutines.await -abstract class NDArrayTFJS(tfjsArray: ArrayTFJS) : NDArray { +abstract class NDArrayTFJS internal constructor(internal var tfjsArray: ArrayTFJS) : NDArray { init { if (!isActivated) { activateDefaultBackend() @@ -14,9 +14,6 @@ abstract class NDArrayTFJS(tfjsArray: ArrayTFJS) : NDArray { } } - var tfjsArray = tfjsArray - protected set - override val strides get() = Strides(tfjsArray.shape.toIntArray()) @@ -158,9 +155,68 @@ abstract class NDArrayTFJS(tfjsArray: ArrayTFJS) : NDArray { return zero } + private fun LongArray.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun UByteArray.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun UShortArray.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun UIntArray.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun ULongArray.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun Array.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun Array.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun Array.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun Array.toFloatArray() = FloatArray(size) { this[it].toFloat() } + private fun Array.toFloatArray() = FloatArray(size) { this[it].toFloat() } + + private fun LongArray.toIntArray() = IntArray(size) { this[it].toInt() } + private fun UByteArray.toIntArray() = IntArray(size) { this[it].toInt() } + private fun UShortArray.toIntArray() = IntArray(size) { this[it].toInt() } + private fun UIntArray.toIntArray() = IntArray(size) { this[it].toInt() } + private fun ULongArray.toIntArray() = IntArray(size) { this[it].toInt() } + private fun Array.toIntArray() = IntArray(size) { this[it].toInt() } + private fun Array.toIntArray() = IntArray(size) { this[it].toInt() } + private fun Array.toIntArray() = IntArray(size) { this[it].toInt() } + private fun Array.toIntArray() = IntArray(size) { this[it].toInt() } + private fun Array.toIntArray() = IntArray(size) { this[it].toInt() } + fun float(values: FloatArray, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "float32")) + fun float(values: DoubleArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toTypedArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "float32")) + fun float(values: LongArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: IntArray, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "float32")) + fun float(values: ShortArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toTypedArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "float32")) + fun float(values: ByteArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toTypedArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "float32")) + fun float(values: UByteArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: UShortArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: UIntArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: ULongArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun float(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toFloatArray(), shape, "float32")) + fun int(values: IntArray, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "int32")) + fun int(values: ShortArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toTypedArray(), shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "int32")) + fun int(values: ByteArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toTypedArray(), shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values, shape, "int32")) + fun int(values: LongArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: UByteArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: UShortArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: UIntArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: ULongArray, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun int(values: Array, shape: Array) = NumberNDArrayTFJS(tensor(values.toIntArray(), shape, "int32")) + fun boolean(values: Array, shape: Array) = BooleanNDArrayTFJS(tensor(values, shape)) + fun boolean(values: BooleanArray, shape: Array) = BooleanNDArrayTFJS(tensor(values.toTypedArray(), shape)) fun string(values: Array, shape: Array) = StringNDArrayTFJS(tensor(values, shape)) fun float(shape: Array, init: (Int) -> Float) = NumberNDArrayTFJS(tensor(FloatArray(shape.times(), init), shape, "float32")) @@ -192,8 +248,8 @@ abstract class NDArrayTFJS(tfjsArray: ArrayTFJS) : NDArray { return StringNDArrayTFJS(tensor(array, shape)) } - fun floatScalar(value: Float) = NumberNDArrayTFJS(scalar(value, "float32")) - fun intScalar(value: Int) = NumberNDArrayTFJS(scalar(value, "int32")) + fun floatScalar(value: Number) = NumberNDArrayTFJS(scalar(value.toFloat(), "float32")) + fun intScalar(value: Number) = NumberNDArrayTFJS(scalar(value.toInt(), "int32")) fun booleanScalar(value: Boolean) = BooleanNDArrayTFJS(scalar(value)) fun stringScalar(value: String) = StringNDArrayTFJS(scalar(value)) @@ -205,21 +261,21 @@ abstract class NDArrayTFJS(tfjsArray: ArrayTFJS) : NDArray { fun intOnes(shape: Array) = NumberNDArrayTFJS(ones(shape, "int32")) fun booleanOnes(shape: Array) = BooleanNDArrayTFJS(ones(shape, "bool")) - fun floatRange(start: Float, stop: Float, step: Float) = NumberNDArrayTFJS(range(start, stop, step, "float32")) - fun intRange(start: Int, stop: Int, step: Int) = NumberNDArrayTFJS(range(start, stop, step, "int32")) + fun floatRange(start: Number, stop: Number, step: Number) = NumberNDArrayTFJS(range(start.toFloat(), stop.toFloat(), step.toFloat(), "float32")) + fun intRange(start: Number, stop: Number, step: Number) = NumberNDArrayTFJS(range(start.toInt(), stop.toInt(), step.toInt(), "int32")) - fun floatFill(shape: Array, value: Float) = NumberNDArrayTFJS(fill(shape, value, "float32")) - fun intFill(shape: Array, value: Int) = NumberNDArrayTFJS(fill(shape, value, "int32")) + fun floatFill(shape: Array, value: Number) = NumberNDArrayTFJS(fill(shape, value.toFloat(), "float32")) + fun intFill(shape: Array, value: Number) = NumberNDArrayTFJS(fill(shape, value.toFloat(), "int32")) fun stringFill(shape: Array, value: String) = StringNDArrayTFJS(fill(shape, value, "string")) fun onesLike(tensor: NumberNDArrayTFJS) = NumberNDArrayTFJS(onesLike(tensor.tfjsArray)) fun zerosLike(tensor: NumberNDArrayTFJS) = NumberNDArrayTFJS(zerosLike(tensor.tfjsArray)) - fun oneHotFloat(indices: NumberNDArrayTFJS, depth: Int, onValue: Float = 1f, offValue: Float = 0f) = - NumberNDArrayTFJS(oneHot(indices.tfjsArray, depth, onValue, offValue, "float32")) + fun oneHotFloat(indices: NumberNDArrayTFJS, depth: Int, onValue: Number = 1f, offValue: Number = 0f) = + NumberNDArrayTFJS(oneHot(indices.tfjsArray, depth, onValue.toFloat(), offValue.toFloat(), "float32")) - fun oneHotInt(indices: NumberNDArrayTFJS, depth: Int, onValue: Int = 1, offValue: Int = 0) = - NumberNDArrayTFJS(oneHot(indices.tfjsArray, depth, onValue, offValue, "int32")) + fun oneHotInt(indices: NumberNDArrayTFJS, depth: Int, onValue: Number = 1, offValue: Number = 0) = + NumberNDArrayTFJS(oneHot(indices.tfjsArray, depth, onValue.toInt(), offValue.toInt(), "int32")) fun oneHotBool(indices: NumberNDArrayTFJS, depth: Int, onValue: Boolean = true, offValue: Boolean = false) = NumberNDArrayTFJS(oneHot(indices.tfjsArray, depth, onValue.toInt(), offValue.toInt(), "bool")) diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NumberNDArrayTFJS.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NumberNDArrayTFJS.kt index f4643715d..d35369594 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NumberNDArrayTFJS.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/NumberNDArrayTFJS.kt @@ -2,7 +2,7 @@ package io.kinference.ndarray.arrays import io.kinference.ndarray.extensions.* -open class NumberNDArrayTFJS(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray), NumberNDArray { +open class NumberNDArrayTFJS internal constructor(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray), NumberNDArray { override fun get(index: IntArray): Number { return tfjsArray.bufferSync().get(*index) as Number } diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/StringNDArray.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/StringNDArray.kt index 47a865da3..84c886bbb 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/StringNDArray.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/StringNDArray.kt @@ -3,7 +3,7 @@ package io.kinference.ndarray.arrays import io.kinference.ndarray.extensions.* import io.kinference.primitives.types.DataType -open class StringNDArrayTFJS(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray) { +open class StringNDArrayTFJS internal constructor(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray) { override val type: DataType = DataType.ALL override fun get(index: IntArray): String { @@ -59,7 +59,7 @@ open class StringNDArrayTFJS(tfjsArray: ArrayTFJS) : NDArrayTFJS(tfjsArray) { } } -class MutableStringNDArrayTFJS(tfjsArray: ArrayTFJS) : StringNDArrayTFJS(tfjsArray), MutableNDArray { +class MutableStringNDArrayTFJS internal constructor(tfjsArray: ArrayTFJS) : StringNDArrayTFJS(tfjsArray), MutableNDArray { override fun clone(): MutableStringNDArrayTFJS { return MutableStringNDArrayTFJS(tfjsArray.clone()) } diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/Tensor.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/Tensor.kt index 740b6a921..bfa6f8f96 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/Tensor.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/arrays/Tensor.kt @@ -6,7 +6,7 @@ import org.khronos.webgl.* import kotlin.js.Promise @JsName("TensorBuffer") -external class MutableBuffer { +internal external class MutableBuffer { val size: Int val shape: Array val strides: Array @@ -18,11 +18,11 @@ external class MutableBuffer { } @JsName("Tensor") -open external class ArrayTFJS { +internal external class ArrayTFJS { val shape: Array val size: Int val dtype: String /* "float32" | "int32" | "bool" | "complex64" | "string" */ - open val rank: Int + val rank: Int internal fun data(): Promise internal fun dataSync(): dynamic @@ -34,33 +34,31 @@ open external class ArrayTFJS { } -external fun tensor(values: Float32Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Float32Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Int32Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Uint8Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Int32Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS +internal external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Uint8Array, shape: Array, dtype: String): ArrayTFJS +internal external fun range(start: Number, stop: Number, step: Number?, dtype: String?): ArrayTFJS -external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS -external fun tensor(values: Array, shape: Array, dtype: String): ArrayTFJS - -external fun range(start: Number, stop: Number, step: Number?, dtype: String?): ArrayTFJS - -external fun fill(shape: Array, value: Number, dtype: String): ArrayTFJS +internal external fun fill(shape: Array, value: Number, dtype: String): ArrayTFJS internal external fun fill(shape: Array, value: String, dtype: String): ArrayTFJS -external fun scalar(value: Number, dtype: String): ArrayTFJS +internal external fun scalar(value: Number, dtype: String): ArrayTFJS internal external fun scalar(value: Boolean, dtype: String): ArrayTFJS internal external fun scalar(value: String, dtype: String): ArrayTFJS -external fun zeros(shape: Array, dtype: String): ArrayTFJS -external fun zerosLike(x: ArrayTFJS): ArrayTFJS +internal external fun zeros(shape: Array, dtype: String): ArrayTFJS +internal external fun zerosLike(x: ArrayTFJS): ArrayTFJS -external fun ones(shape: Array, dtype: String): ArrayTFJS -external fun onesLike(x: ArrayTFJS): ArrayTFJS +internal external fun ones(shape: Array, dtype: String): ArrayTFJS +internal external fun onesLike(x: ArrayTFJS): ArrayTFJS -external fun oneHot(indices: ArrayTFJS, depth: Int, onValue: Number?, offValue: Number?, dtype: String?): ArrayTFJS +internal external fun oneHot(indices: ArrayTFJS, depth: Int, onValue: Number?, offValue: Number?, dtype: String?): ArrayTFJS diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Functions.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Functions.kt index d08c645bb..f7f2df4c7 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Functions.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Functions.kt @@ -123,6 +123,8 @@ internal external val logicalOr: (a: ArrayTFJS, b: ArrayTFJS) -> ArrayTFJS internal external val logicalAnd: (a: ArrayTFJS, b: ArrayTFJS) -> ArrayTFJS +internal external val logicalXor: (a: ArrayTFJS, b: ArrayTFJS) -> ArrayTFJS + internal external val pad: (x: ArrayTFJS, paddings: Array>, constantValue: dynamic) -> ArrayTFJS internal external val mirrorPad: (x: ArrayTFJS, paddings: Array>, mode: String) -> ArrayTFJS @@ -162,3 +164,5 @@ internal external val linalg: Linalg internal external val floor: (x: ArrayTFJS) -> ArrayTFJS internal external val isInf: (x: ArrayTFJS) -> ArrayTFJS + +internal external val isNaN: (x: ArrayTFJS) -> ArrayTFJS diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Memory.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Memory.kt index b3a796d7b..0aa4ada9b 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Memory.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/core/Memory.kt @@ -5,9 +5,9 @@ package io.kinference.ndarray.core import io.kinference.ndarray.arrays.ArrayTFJS -external fun tidy(nameOrFn: () -> Array, fn: (() -> Array)?): Array +internal external fun tidy(nameOrFn: () -> Array, fn: (() -> Array)?): Array -external fun tidy(nameOrFn: String, fn: (() -> Array)?): Array +internal external fun tidy(nameOrFn: String, fn: (() -> Array)?): Array @JsName("Engine") internal external class InternalTfjsEngine { diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/MemoryExtension.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/MemoryExtension.kt index 0a85bab39..cfdc24871 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/MemoryExtension.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/MemoryExtension.kt @@ -4,7 +4,7 @@ import io.kinference.ndarray.arrays.ArrayTFJS import io.kinference.ndarray.arrays.NDArrayTFJS import io.kinference.ndarray.core.* -suspend fun tidy(fn: suspend () -> Array): Array { +internal suspend fun tidy(fn: suspend () -> Array): Array { val engine = engine() lateinit var result: Array return scopedRun( @@ -35,14 +35,14 @@ suspend fun tidyNDArray(fn: suspend () -> T): T { return rawOutput.toNDArray() as T } -suspend fun scopedRun(start: () -> Unit, end: () -> Unit, fn: suspend () -> Array): Array { +internal suspend fun scopedRun(start: () -> Unit, end: () -> Unit, fn: suspend () -> Array): Array { start() try { val res = fn() - end(); - return res; + end() + return res } catch (e: Exception) { end() - throw e; + throw e } } diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/NDArrayExtension.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/NDArrayExtension.kt index a6faf7e0a..9abc98147 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/NDArrayExtension.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/NDArrayExtension.kt @@ -7,18 +7,18 @@ import io.kinference.ndarray.arrays.* import io.kinference.ndarray.core.* import io.kinference.primitives.types.DataType -val NDArrayTFJS.dtype: String +internal val NDArrayTFJS.dtype: String get() = tfjsArray.dtype val NDArrayTFJS.shapeArray: Array get() = tfjsArray.shape -fun ArrayTFJS.toNDArray() = makeNDArray(this, dtype) +internal fun ArrayTFJS.toNDArray() = makeNDArray(this, dtype) -fun Array.getArrays() = Array(this.size) { this[it].tfjsArray } -fun List.getArrays() = Array(this.size) { this[it].tfjsArray } +internal fun Array.getArrays() = Array(this.size) { this[it].tfjsArray } +internal fun List.getArrays() = Array(this.size) { this[it].tfjsArray } -fun Array.getNDArrays() = Array(this.size) { this[it].toNDArray() } +internal fun Array.getNDArrays() = Array(this.size) { this[it].toNDArray() } fun T.dataInt() = tfjsArray.dataInt() fun T.dataFloat() = tfjsArray.dataFloat() @@ -78,10 +78,6 @@ fun NumberNDArrayTFJS.add(tensors: Array) = NumberNDArrayTFJS fun NumberNDArrayTFJS.add(vararg tensors: NumberNDArrayTFJS) = add(tensors as Array) -fun NumberNDArrayTFJS.dot(other: NumberNDArrayTFJS) = NumberNDArrayTFJS(dot(tfjsArray, other.tfjsArray)) - -fun NumberNDArrayTFJS.softmax(axis: Int = -1) = NumberNDArrayTFJS(softmax(tfjsArray, axis)) - fun NumberNDArrayTFJS.min(axes: Array, keepDims: Boolean = false) = NumberNDArrayTFJS(min(tfjsArray, axes, keepDims)) fun NumberNDArrayTFJS.min(keepDims: Boolean = false) = NumberNDArrayTFJS(min(tfjsArray, null, keepDims)) @@ -219,3 +215,5 @@ suspend fun NumberNDArrayTFJS.isInf(detectNegative: Boolean = true, detectPositi else -> error("At least one of detectNegative or detectPositive must be true") } } + +fun NumberNDArrayTFJS.isNaN() = BooleanNDArrayTFJS(tfjsArray.isNaN()) diff --git a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/TensorExtenstion.kt b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/TensorExtenstion.kt index 510ecaafb..5c8337eff 100644 --- a/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/TensorExtenstion.kt +++ b/ndarray/ndarray-tfjs/src/jsMain/kotlin/io.kinference.ndarray/extensions/TensorExtenstion.kt @@ -40,250 +40,257 @@ import io.kinference.ndarray.core.unstack import io.kinference.ndarray.core.where import org.khronos.webgl.* -fun tensor(values: FloatArray, shape: Array, dtype: String): ArrayTFJS = tensor(values.unsafeCast(), shape, dtype) +internal fun tensor(values: FloatArray, shape: Array, dtype: String): ArrayTFJS = tensor(values.unsafeCast(), shape, dtype) -fun tensor(values: IntArray, shape: Array, dtype: String): ArrayTFJS = tensor(values.unsafeCast(), shape, dtype) +internal fun tensor(values: IntArray, shape: Array, dtype: String): ArrayTFJS = tensor(values.unsafeCast(), shape, dtype) -fun tensor(values: UByteArray, shape: Array, dtype: String): ArrayTFJS = tensor(values.unsafeCast(), shape, dtype) +internal fun tensor(values: UByteArray, shape: Array, dtype: String): ArrayTFJS = tensor(values.unsafeCast(), shape, dtype) -fun tensor(values: Array, shape: Array) = tensor(values, shape, "bool") +internal fun tensor(values: Array, shape: Array) = tensor(values, shape, "bool") -fun tensor(values: Array, shape: Array) = tensor(values, shape, "string") +internal fun tensor(values: Array, shape: Array) = tensor(values, shape, "string") -fun scalar(value: Boolean) = scalar(value, "bool") +internal fun scalar(value: Boolean) = scalar(value, "bool") -fun scalar(value: Float) = scalar(value, "float32") +internal fun scalar(value: Float) = scalar(value, "float32") -fun scalar(value: Int) = scalar(value, "int32") +internal fun scalar(value: Int) = scalar(value, "int32") -fun scalar(value: String) = scalar(value, "string") +internal fun scalar(value: String) = scalar(value, "string") -fun fill(shape: Array, value: String) = fill(shape, value, "string") +internal fun fill(shape: Array, value: String) = fill(shape, value, "string") -fun ArrayTFJS.dataInt() = dataSync().unsafeCast().unsafeCast() +internal fun ArrayTFJS.dataInt() = dataSync().unsafeCast().unsafeCast() -fun ArrayTFJS.dataFloat() = dataSync().unsafeCast().unsafeCast() +internal fun ArrayTFJS.dataFloat() = dataSync().unsafeCast().unsafeCast() -fun ArrayTFJS.dataBool() = dataSync().unsafeCast>() +internal fun ArrayTFJS.dataBool() = dataSync().unsafeCast>() -fun ArrayTFJS.dataString() = dataSync().unsafeCast>() +internal fun ArrayTFJS.dataString() = dataSync().unsafeCast>() -operator fun ArrayTFJS.plus(other: ArrayTFJS) = io.kinference.ndarray.core.add(this, other) +internal operator fun ArrayTFJS.plus(other: ArrayTFJS) = io.kinference.ndarray.core.add(this, other) -operator fun ArrayTFJS.minus(other: ArrayTFJS) = sub(this, other) +internal operator fun ArrayTFJS.minus(other: ArrayTFJS) = sub(this, other) -operator fun ArrayTFJS.div(other: ArrayTFJS) = div(this, other) +internal operator fun ArrayTFJS.div(other: ArrayTFJS) = div(this, other) -operator fun ArrayTFJS.times(other: ArrayTFJS) = mul(this, other) +internal operator fun ArrayTFJS.times(other: ArrayTFJS) = mul(this, other) -fun ArrayTFJS.broadcastTo(shape: Array) = broadcastTo(this, shape) +internal fun ArrayTFJS.broadcastTo(shape: Array) = broadcastTo(this, shape) -fun ArrayTFJS.cast(dtype: String) = cast(this, dtype) +internal fun ArrayTFJS.cast(dtype: String) = cast(this, dtype) -fun ArrayTFJS.reshape(shape: Array) = reshape(this, shape) -fun ArrayTFJS.reshape(shape: IntArray) = reshape(this, shape.toTypedArray()) +internal fun ArrayTFJS.reshape(shape: Array) = reshape(this, shape) +internal fun ArrayTFJS.reshape(shape: IntArray) = reshape(this, shape.toTypedArray()) -fun ArrayTFJS.gather(indices: ArrayTFJS, axis: Int = 0, batchDims: Int = 0) = gather(this, indices, axis, batchDims) +internal fun ArrayTFJS.gather(indices: ArrayTFJS, axis: Int = 0, batchDims: Int = 0) = gather(this, indices, axis, batchDims) -fun ArrayTFJS.moments(axis: Int, keepDims: Boolean = false) = moments(this, arrayOf(axis), keepDims) +internal fun ArrayTFJS.moments(axis: Int, keepDims: Boolean = false) = moments(this, arrayOf(axis), keepDims) -fun ArrayTFJS.moments(axes: Array, keepDims: Boolean = false) = moments(this, axes, keepDims) +internal fun ArrayTFJS.moments(axes: Array, keepDims: Boolean = false) = moments(this, axes, keepDims) -fun ArrayTFJS.sum(axis: Int, keepDims: Boolean = false) = sum(this, arrayOf(axis), keepDims) +internal fun ArrayTFJS.sum(axis: Int, keepDims: Boolean = false) = sum(this, arrayOf(axis), keepDims) -fun ArrayTFJS.sum(axes: Array, keepDims: Boolean = false) = sum(this, axes, keepDims) +internal fun ArrayTFJS.sum(axes: Array, keepDims: Boolean = false) = sum(this, axes, keepDims) -fun ArrayTFJS.sum(keepDims: Boolean = false) = sum(this, null, keepDims) +internal fun ArrayTFJS.sum(keepDims: Boolean = false) = sum(this, null, keepDims) -fun Array.sum() = addN(this) +internal fun Array.sum() = addN(this) -fun ArrayTFJS.add(tensors: Array) = addN(arrayOf(this, *tensors)) +internal fun ArrayTFJS.add(tensors: Array) = addN(arrayOf(this, *tensors)) -fun ArrayTFJS.add(vararg tensors: ArrayTFJS) = addN(arrayOf(this, *tensors)) +internal fun ArrayTFJS.add(vararg tensors: ArrayTFJS) = addN(arrayOf(this, *tensors)) -fun ArrayTFJS.transpose() = transpose(this, null) +internal fun ArrayTFJS.transpose() = transpose(this, null) -fun ArrayTFJS.transpose(permutation: Array? = null) = transpose(this, permutation) +internal fun ArrayTFJS.transpose(permutation: Array? = null) = transpose(this, permutation) -fun ArrayTFJS.unstack(axis: Int = 0) = unstack(this, axis) +internal fun ArrayTFJS.unstack(axis: Int = 0) = unstack(this, axis) -fun Array.stack(axis: Int = 0) = stack(this, axis) +internal fun Array.stack(axis: Int = 0) = stack(this, axis) -fun Collection.stack(axis: Int = 0) = this.toTypedArray().stack(axis) +internal fun Collection.stack(axis: Int = 0) = this.toTypedArray().stack(axis) -fun ArrayTFJS.stack(vararg tensors: ArrayTFJS, axis: Int = 0) = stack(arrayOf(this, *tensors), axis) +internal fun ArrayTFJS.stack(vararg tensors: ArrayTFJS, axis: Int = 0) = stack(arrayOf(this, *tensors), axis) -fun ArrayTFJS.dot(other: ArrayTFJS) = dot(this, other) +internal fun ArrayTFJS.dot(other: ArrayTFJS) = dot(this, other) -fun Array.concat(axis: Int = 0) = concat(this, axis) +internal fun Array.concat(axis: Int = 0) = concat(this, axis) -fun ArrayTFJS.concat(vararg tensors: ArrayTFJS, axis: Int = 0) = concat(arrayOf(this, *tensors), axis) +internal fun ArrayTFJS.concat(vararg tensors: ArrayTFJS, axis: Int = 0) = concat(arrayOf(this, *tensors), axis) -fun ArrayTFJS.split(split: Array, axis: Int) = split(this, split, axis) +internal fun ArrayTFJS.split(split: Array, axis: Int) = split(this, split, axis) -fun ArrayTFJS.split(splitSize: Int, axis: Int) = split(this, splitSize, axis) +internal fun ArrayTFJS.split(splitSize: Int, axis: Int) = split(this, splitSize, axis) -fun ArrayTFJS.matMul(other: ArrayTFJS, transposeLeft: Boolean = false, transposeRight: Boolean = false) = matMul(this, other, transposeLeft, transposeRight) +internal fun ArrayTFJS.matMul(other: ArrayTFJS, transposeLeft: Boolean = false, transposeRight: Boolean = false) = matMul(this, other, transposeLeft, transposeRight) -fun ArrayTFJS.softmax(axis: Int = -1) = softmax(this, axis) +internal fun ArrayTFJS.softmax(axis: Int = -1) = softmax(this, axis) -fun ArrayTFJS.logSoftmax(axis: Int = -1) = io.kinference.ndarray.core.logSoftmax(this , axis) +internal fun ArrayTFJS.logSoftmax(axis: Int = -1) = logSoftmax(this , axis) -fun ArrayTFJS.log() = log(this) +internal fun ArrayTFJS.log() = log(this) -fun ArrayTFJS.erf() = erf(this) +internal fun ArrayTFJS.erf() = erf(this) -fun ArrayTFJS.flatten() = reshape(this, arrayOf(this.size)) +internal fun ArrayTFJS.flatten() = reshape(this, arrayOf(this.size)) -fun ArrayTFJS.isScalar() = shape.isEmpty() +internal fun ArrayTFJS.isScalar() = shape.isEmpty() -fun ArrayTFJS.computeBlockSize(fromDim: Int = 0, toDim: Int = this.shape.size): Int { +internal fun ArrayTFJS.computeBlockSize(fromDim: Int = 0, toDim: Int = this.shape.size): Int { return this.shape.sliceArray(fromDim until toDim).fold(1, Int::times) } -fun ArrayTFJS.indexAxis(axis: Int) = if (axis < 0) rank + axis else axis +internal fun ArrayTFJS.indexAxis(axis: Int) = if (axis < 0) rank + axis else axis -fun ArrayTFJS.min(axis: Int = 0, keepDims: Boolean = false) = min(this, arrayOf(axis), keepDims) +internal fun ArrayTFJS.min(axis: Int = 0, keepDims: Boolean = false) = min(this, arrayOf(axis), keepDims) -fun ArrayTFJS.min(axes: Array, keepDims: Boolean = false) = min(this, axes, keepDims) +internal fun ArrayTFJS.min(axes: Array, keepDims: Boolean = false) = min(this, axes, keepDims) -fun ArrayTFJS.min(keepDims: Boolean = false) = min(this, null, keepDims) +internal fun ArrayTFJS.min(keepDims: Boolean = false) = min(this, null, keepDims) -fun ArrayTFJS.min() = min(this, null, null) +internal fun ArrayTFJS.min() = min(this, null, null) -fun ArrayTFJS.max(axis: Int, keepDims: Boolean = false) = max(this, arrayOf(axis), keepDims) +internal fun ArrayTFJS.max(axis: Int, keepDims: Boolean = false) = max(this, arrayOf(axis), keepDims) -fun ArrayTFJS.max(axes: Array, keepDims: Boolean = false) = max(this, axes, keepDims) +internal fun ArrayTFJS.max(axes: Array, keepDims: Boolean = false) = max(this, axes, keepDims) -fun ArrayTFJS.max(keepDims: Boolean = false) = max(this, null, keepDims) +internal fun ArrayTFJS.max(keepDims: Boolean = false) = max(this, null, keepDims) -fun ArrayTFJS.max() = max(this, null, null) +internal fun ArrayTFJS.max() = max(this, null, null) -fun ArrayTFJS.round() = round(this) +internal fun ArrayTFJS.round() = round(this) -fun ArrayTFJS.clip(minValue: Number, maxValue: Number) = clipByValue(this, minValue, maxValue) +internal fun ArrayTFJS.clip(minValue: Number, maxValue: Number) = clipByValue(this, minValue, maxValue) -operator fun ArrayTFJS.unaryMinus() = neg(this) +internal operator fun ArrayTFJS.unaryMinus() = neg(this) -fun min(a: ArrayTFJS, b: ArrayTFJS) = minimum(a, b) +internal fun min(a: ArrayTFJS, b: ArrayTFJS) = minimum(a, b) -fun max(a: ArrayTFJS, b: ArrayTFJS) = maximum(a, b) +internal fun max(a: ArrayTFJS, b: ArrayTFJS) = maximum(a, b) -fun ArrayTFJS.sqrt() = sqrt(this) +internal fun ArrayTFJS.sqrt() = sqrt(this) -fun sqrt(value: ArrayTFJS) = value.sqrt() +internal fun sqrt(value: ArrayTFJS) = value.sqrt() -fun ArrayTFJS.tanh() = tanh(this) +internal fun ArrayTFJS.tanh() = tanh(this) -fun tanh(x: ArrayTFJS) = x.tanh() +internal fun tanh(x: ArrayTFJS) = x.tanh() -fun ArrayTFJS.slice(begin: Array, end: Array) = slice(this, begin, end) +internal fun ArrayTFJS.slice(begin: Array, end: Array) = slice(this, begin, end) -fun ArrayTFJS.slice(begin: Array) = slice(this, begin, null) +internal fun ArrayTFJS.slice(begin: Array) = slice(this, begin, null) -fun ArrayTFJS.reverse(axes: Array) = reverse(this, axes) +internal fun ArrayTFJS.reverse(axes: Array) = reverse(this, axes) -fun ArrayTFJS.reverse(axis: Int) = reverse(this, arrayOf(axis)) +internal fun ArrayTFJS.reverse(axis: Int) = reverse(this, arrayOf(axis)) -fun ArrayTFJS.reverse() = reverse(this, null) +internal fun ArrayTFJS.reverse() = reverse(this, null) -fun ArrayTFJS.slice(start: Array, end: Array, step: Array) = stridedSlice(this, start, end, step, 0, 0, 0, 0, 0) +internal fun ArrayTFJS.slice(start: Array, end: Array, step: Array) = stridedSlice(this, start, end, step, 0, 0, 0, 0, 0) -fun ArrayTFJS.squeeze(axes: Array? = null) = squeeze(this, axes) +internal fun ArrayTFJS.squeeze(axes: Array? = null) = squeeze(this, axes) -fun ArrayTFJS.argmax(axis: Int = 0) = argMax(this, axis) +internal fun ArrayTFJS.argmax(axis: Int = 0) = argMax(this, axis) -fun ArrayTFJS.argmin(axis: Int = 0) = argMin(this, axis) +internal fun ArrayTFJS.argmin(axis: Int = 0) = argMin(this, axis) -fun ArrayTFJS.tile(repeats: Array) = tile(this, repeats) +internal fun ArrayTFJS.tile(repeats: Array) = tile(this, repeats) -fun ArrayTFJS.less(other: ArrayTFJS) = less(this, other) +internal fun ArrayTFJS.less(other: ArrayTFJS) = less(this, other) -fun ArrayTFJS.greater(other: ArrayTFJS) = greater(this, other) +internal fun ArrayTFJS.greater(other: ArrayTFJS) = greater(this, other) -fun ArrayTFJS.greaterEqual(other: ArrayTFJS) = greaterEqual(this, other) +internal fun ArrayTFJS.greaterEqual(other: ArrayTFJS) = greaterEqual(this, other) -fun ArrayTFJS.equal(other: ArrayTFJS) = equal(this, other) +internal fun ArrayTFJS.equal(other: ArrayTFJS) = equal(this, other) -fun ArrayTFJS.notEqual(other: ArrayTFJS) = notEqual(this, other) +internal fun ArrayTFJS.notEqual(other: ArrayTFJS) = notEqual(this, other) -fun ArrayTFJS.where(condition: ArrayTFJS, other: ArrayTFJS) = where(condition, this, other) +internal fun ArrayTFJS.where(condition: ArrayTFJS, other: ArrayTFJS) = where(condition, this, other) -fun ArrayTFJS.clone() = clone(this) +internal fun ArrayTFJS.clone() = clone(this) -fun ArrayTFJS.not(): ArrayTFJS { +internal fun ArrayTFJS.not(): ArrayTFJS { require(this.dtype == "bool") { "Only bool type is accepted" } return logicalNot(this) } -fun ArrayTFJS.or(other: ArrayTFJS): ArrayTFJS { +internal fun ArrayTFJS.or(other: ArrayTFJS): ArrayTFJS { require(this.dtype == "bool" && other.dtype == "bool") { "Only boolean arrays are accepted" } return logicalOr(this, other) } -fun ArrayTFJS.and(other: ArrayTFJS): ArrayTFJS { +internal fun ArrayTFJS.and(other: ArrayTFJS): ArrayTFJS { require(this.dtype == "bool" && other.dtype == "bool") { "Only boolean arrays are accepted" } return logicalAnd(this, other) } -fun ArrayTFJS.pad(paddings: Array>, constantValue: Any) = pad(this, paddings, constantValue) +internal fun ArrayTFJS.xor(other: ArrayTFJS): ArrayTFJS { + require(this.dtype == "bool" && other.dtype == "bool") { "Only boolean arrays are accepted" } + return logicalXor(this, other) +} + +internal fun ArrayTFJS.pad(paddings: Array>, constantValue: Any) = pad(this, paddings, constantValue) internal fun ArrayTFJS.mirrorPad(paddings: Array>, mode: String) = mirrorPad(this, paddings, mode) -fun ArrayTFJS.reflectPad(paddings: Array>) = mirrorPad(this, paddings, "reflect") +internal fun ArrayTFJS.reflectPad(paddings: Array>) = mirrorPad(this, paddings, "reflect") -fun ArrayTFJS.symmetricPad(paddings: Array>) = mirrorPad(this, paddings, "symmetric") +internal fun ArrayTFJS.symmetricPad(paddings: Array>) = mirrorPad(this, paddings, "symmetric") -fun ArrayTFJS.gatherNd(indices: ArrayTFJS) = gatherND(this, indices) +internal fun ArrayTFJS.gatherNd(indices: ArrayTFJS) = gatherND(this, indices) -fun ArrayTFJS.leakyRelu(alpha: Number) = leakyRelu(this, alpha) +internal fun ArrayTFJS.leakyRelu(alpha: Number) = leakyRelu(this, alpha) -fun ArrayTFJS.relu() = relu(this) +internal fun ArrayTFJS.relu() = relu(this) -fun ArrayTFJS.cumsum(axis: Int = 0, exclusive: Boolean = false, reverse: Boolean = false) = +internal fun ArrayTFJS.cumsum(axis: Int = 0, exclusive: Boolean = false, reverse: Boolean = false) = cumsum(this, axis, exclusive, reverse) -fun ArrayTFJS.topk(k: Int, sorted: Boolean = false) = topk(this, k, sorted) +internal fun ArrayTFJS.topk(k: Int, sorted: Boolean = false) = topk(this, k, sorted) -fun ArrayTFJS.abs() = abs(this) +internal fun ArrayTFJS.abs() = abs(this) -fun ArrayTFJS.acos() = acos(this) +internal fun ArrayTFJS.acos() = acos(this) -fun ArrayTFJS.acosh() = acosh(this) +internal fun ArrayTFJS.acosh() = acosh(this) -fun ArrayTFJS.asin() = asin(this) +internal fun ArrayTFJS.asin() = asin(this) -fun ArrayTFJS.asinh() = asinh(this) +internal fun ArrayTFJS.asinh() = asinh(this) -fun ArrayTFJS.atan() = atan(this) +internal fun ArrayTFJS.atan() = atan(this) -fun ArrayTFJS.atanh() = atanh(this) +internal fun ArrayTFJS.atanh() = atanh(this) -fun ArrayTFJS.tensorScatterUpdate(indices: ArrayTFJS, updates: ArrayTFJS) = tensorScatterUpdate(this, indices, updates) +internal fun ArrayTFJS.tensorScatterUpdate(indices: ArrayTFJS, updates: ArrayTFJS) = tensorScatterUpdate(this, indices, updates) -fun ArrayTFJS.ceil() = ceil(this) +internal fun ArrayTFJS.ceil() = ceil(this) -fun ArrayTFJS.exp() = exp(this) +internal fun ArrayTFJS.exp() = exp(this) -fun ArrayTFJS.expm1() = expm1(this) +internal fun ArrayTFJS.expm1() = expm1(this) -fun ArrayTFJS.elu() = elu(this) +internal fun ArrayTFJS.elu() = elu(this) -fun ArrayTFJS.prelu(alpha: ArrayTFJS) = prelu(this, alpha) +internal fun ArrayTFJS.prelu(alpha: ArrayTFJS) = prelu(this, alpha) -fun ArrayTFJS.cos() = cos(this) +internal fun ArrayTFJS.cos() = cos(this) -fun ArrayTFJS.cosh() = cosh(this) +internal fun ArrayTFJS.cosh() = cosh(this) -fun ArrayTFJS.qrDecomposition(fullMatrices: Boolean = false): QrDecompositionResultTFJS { +internal fun ArrayTFJS.qrDecomposition(fullMatrices: Boolean = false): QrDecompositionResultTFJS { val result = linalg.qr(this, fullMatrices) return QrDecompositionResultTFJS(result[0], result[1]) } -fun ArrayTFJS.prod(axis: Int, keepDims: Boolean = false) = prod(this, arrayOf(axis), keepDims) +internal fun ArrayTFJS.prod(axis: Int, keepDims: Boolean = false) = prod(this, arrayOf(axis), keepDims) + +internal fun ArrayTFJS.prod(axes: Array, keepDims: Boolean = false) = prod(this, axes, keepDims) -fun ArrayTFJS.prod(axes: Array, keepDims: Boolean = false) = prod(this, axes, keepDims) +internal fun ArrayTFJS.floor() = floor(this) -fun ArrayTFJS.floor() = floor(this) +internal fun ArrayTFJS.isInf() = isInf(this) -fun ArrayTFJS.isInf() = isInf(this) +internal fun ArrayTFJS.isNaN() = isNaN(this) diff --git a/ndarray/ndarray-tfjs/src/jsTest/kotlin/FloatCreateTest.kt b/ndarray/ndarray-tfjs/src/jsTest/kotlin/FloatCreateTest.kt new file mode 100644 index 000000000..2b48f9077 --- /dev/null +++ b/ndarray/ndarray-tfjs/src/jsTest/kotlin/FloatCreateTest.kt @@ -0,0 +1,206 @@ +import io.kinference.ndarray.arrays.NDArrayTFJS +import io.kinference.ndarray.extensions.dataFloat +import kotlin.test.Test +import kotlin.test.assertContentEquals + +class FloatCreateTest { + @Test + fun createFromFloatArray() { + val tensor = NDArrayTFJS.float(floatArrayOf(1f, 2f, 3f, 4f), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from float array") + + tensor.close() + } + + @Test + fun createFromFloatArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1f, 2f, 3f, 4f), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from float array") + + tensor.close() + } + + @Test + fun createFromDoubleArray() { + val tensor = NDArrayTFJS.float(doubleArrayOf(1.0, 2.0, 3.0, 4.0), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from double array") + + tensor.close() + } + + @Test + fun createFromDoubleArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1.0, 2.0, 3.0, 4.0), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from double array") + + tensor.close() + } + + @Test + fun createFromIntArray() { + val tensor = NDArrayTFJS.float(intArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from int array") + + tensor.close() + } + + @Test + fun createFromIntArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from int array") + + tensor.close() + } + + @Test + fun createFromByteArray() { + val tensor = NDArrayTFJS.float(byteArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from byte array") + + tensor.close() + } + + @Test + fun createFromByteArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from byte array") + + tensor.close() + } + + @Test + fun createFromShortArray() { + val tensor = NDArrayTFJS.float(shortArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from short array") + + tensor.close() + } + + @Test + fun createFromShortArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from short array") + + tensor.close() + } + + @Test + fun createFromLongArray() { + val tensor = NDArrayTFJS.float(longArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from long array") + + tensor.close() + } + + @Test + fun createFromLongArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from long array") + + tensor.close() + } + + @Test + fun createFromUByteArray() { + val tensor = NDArrayTFJS.float(ubyteArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from ubyte array") + + tensor.close() + } + + @Test + fun createFromUByteArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from ubyte array") + + tensor.close() + } + + @Test + fun createFromUShortArray() { + val tensor = NDArrayTFJS.float(ushortArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from ushort array") + + tensor.close() + } + + @Test + fun createFromUShortArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from ushort array") + + tensor.close() + } + + @Test + fun createFromUIntArray() { + val tensor = NDArrayTFJS.float(uintArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from uint array") + + tensor.close() + } + + @Test + fun createFromUIntArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from uint array") + + tensor.close() + } + + @Test + fun createFromULongArray() { + val tensor = NDArrayTFJS.float(ulongArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from ulong array") + + tensor.close() + } + + @Test + fun createFromULongArrayTyped() { + val tensor = NDArrayTFJS.float(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataFloat() + assertContentEquals(floatArrayOf(1f, 2f, 3f, 4f), data, "Problem with creating float tensor from ulong array") + + tensor.close() + } +} diff --git a/ndarray/ndarray-tfjs/src/jsTest/kotlin/IntCreateTest.kt b/ndarray/ndarray-tfjs/src/jsTest/kotlin/IntCreateTest.kt new file mode 100644 index 000000000..97ee6be2c --- /dev/null +++ b/ndarray/ndarray-tfjs/src/jsTest/kotlin/IntCreateTest.kt @@ -0,0 +1,166 @@ +import io.kinference.ndarray.arrays.NDArrayTFJS +import io.kinference.ndarray.extensions.dataInt +import kotlin.test.Test +import kotlin.test.assertContentEquals + +class IntCreateTest { + @Test + fun createFromIntArray() { + val tensor = NDArrayTFJS.int(intArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from int array") + + tensor.close() + } + + @Test + fun createFromIntArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from int array") + + tensor.close() + } + + @Test + fun createFromByteArray() { + val tensor = NDArrayTFJS.int(byteArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from byte array") + + tensor.close() + } + + @Test + fun createFromByteArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from byte array") + + tensor.close() + } + + @Test + fun createFromShortArray() { + val tensor = NDArrayTFJS.int(shortArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from short array") + + tensor.close() + } + + @Test + fun createFromShortArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from short array") + + tensor.close() + } + + @Test + fun createFromLongArray() { + val tensor = NDArrayTFJS.int(longArrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from long array") + + tensor.close() + } + + @Test + fun createFromLongArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1, 2, 3, 4), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from long array") + + tensor.close() + } + + @Test + fun createFromUIntArray() { + val tensor = NDArrayTFJS.int(uintArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from uint array") + + tensor.close() + } + + @Test + fun createFromUIntArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from uint array") + + tensor.close() + } + + @Test + fun createFromUByteArray() { + val tensor = NDArrayTFJS.int(ubyteArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from ubyte array") + + tensor.close() + } + + @Test + fun createFromUByteArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from ubyte array") + + tensor.close() + } + + @Test + fun createFromUShortArray() { + val tensor = NDArrayTFJS.int(ushortArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from ushort array") + + tensor.close() + } + + @Test + fun createFromUShortArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from ushort array") + + tensor.close() + } + + @Test + fun createFromULongArray() { + val tensor = NDArrayTFJS.int(ulongArrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from ulong array") + + tensor.close() + } + + @Test + fun createFromULongArrayTyped() { + val tensor = NDArrayTFJS.int(arrayOf(1u, 2u, 3u, 4u), arrayOf(2, 2)) + + val data = tensor.dataInt() + assertContentEquals(intArrayOf(1, 2, 3, 4), data, "Problem with creating int tensor from ulong array") + + tensor.close() + } +} diff --git a/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/descriptor.txt b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/descriptor.txt new file mode 100644 index 000000000..7967d1896 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/descriptor.txt @@ -0,0 +1,2 @@ +test_data_set_0/input_0.pb +test_data_set_0/output_0.pb \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/model.onnx b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/model.onnx new file mode 100644 index 000000000..f629dbffd --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/model.onnx @@ -0,0 +1,12 @@ + backend-test:= + +xy"IsNaN +test_isnanZ +x + + +b +y + +  +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/test_data_set_0/input_0.pb new file mode 100644 index 000000000..14d9d76db Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/test_data_set_0/output_0.pb new file mode 100644 index 000000000..1867429a1 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/isnan/test_isnan/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/model.onnx new file mode 100644 index 000000000..e17f63dfb --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/model.onnx @@ -0,0 +1,17 @@ + backend-test:_ + +x +yxor"Xor +test_xor2dZ +x +   + +Z +y +   + +b +xor +   + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..d9541774b Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..2f37772a3 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..2ede81629 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor2d/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/model.onnx new file mode 100644 index 000000000..ce915d7b4 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/model.onnx @@ -0,0 +1,20 @@ + backend-test:k + +x +yxor"Xor +test_xor3dZ +x +  + + +Z +y +  + + +b +xor +  + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..402cb458f Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..825b4f452 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..5aaed1272 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor3d/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/model.onnx new file mode 100644 index 000000000..c4ab09c45 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/model.onnx @@ -0,0 +1,23 @@ + backend-test:w + +x +yxor"Xor +test_xor4dZ +x +  + + + +Z +y +  + + + +b +xor +  + + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..3c2fb94d8 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..925bab696 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..4b38cb454 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor4d/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/model.onnx new file mode 100644 index 000000000..57b7a689c --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/model.onnx @@ -0,0 +1,18 @@ + backend-test:k + +x +yxor"Xortest_xor_bcast3v1dZ +x +  + + +Z +y + +  +b +xor +  + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..5cc32fe10 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..1500686f9 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..191b7b0ef Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v1d/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/model.onnx new file mode 100644 index 000000000..f7fc3df32 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/model.onnx @@ -0,0 +1,18 @@ + backend-test:o + +x +yxor"Xortest_xor_bcast3v2dZ +x +  + + +Z +y +   + +b +xor +  + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..d784193ad Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..8bab0d309 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..dadeaaada Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast3v2d/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/model.onnx new file mode 100644 index 000000000..f9c7ccf72 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/model.onnx @@ -0,0 +1,20 @@ + backend-test:w + +x +yxor"Xortest_xor_bcast4v2dZ +x +  + + + +Z +y +   + +b +xor +  + + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..22944e324 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..d7cba2c1a Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..926665380 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v2d/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/model.onnx new file mode 100644 index 000000000..77dbcb6b5 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/model.onnx @@ -0,0 +1,21 @@ + backend-test:{ + +x +yxor"Xortest_xor_bcast4v3dZ +x +  + + + +Z +y +  + + +b +xor +  + + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..70f33b925 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..7e83c141e Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..854b39a48 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v3d/test_data_set_0/output_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/descriptor.txt b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/model.onnx b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/model.onnx new file mode 100644 index 000000000..4e30fd1e5 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/model.onnx @@ -0,0 +1,22 @@ + backend-test: + +x +yxor"Xortest_xor_bcast4v4dZ +x +  + + + +Z +y +  + + + +b +xor +  + + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/input_0.pb new file mode 100644 index 000000000..024a3481c Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/input_0.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/input_1.pb new file mode 100644 index 000000000..8e18502fa Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/output_0.pb new file mode 100644 index 000000000..bb78baf43 Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/xor/test_xor_bcast4v4d/test_data_set_0/output_0.pb differ