Skip to content

Commit

Permalink
KI-33 [tfjs] Add Trilu operator
Browse files Browse the repository at this point in the history
  • Loading branch information
AnastasiaTuchina committed Jul 12, 2023
1 parent 29a21b3 commit ee5c36d
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ object TFJSOperatorFactory : OperatorFactory<TFJSData<*>> {
"Transpose" -> Transpose(name, version, attributes, inputs, outputs)
"TreeEnsembleClassifier" -> TreeEnsembleClassifier(name, version, attributes, inputs, outputs)
"TreeEnsembleRegressor" -> TreeEnsembleRegressor(name, version, attributes, inputs, outputs)
"Trilu" -> Trilu(name, version, attributes, inputs, outputs)
"Unsqueeze" -> Unsqueeze(name, version, attributes, inputs, outputs)
"Where" -> Where(name, version, attributes, inputs, outputs)
"ZipMap" -> ZipMap(name, version, attributes, inputs, outputs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package io.kinference.tfjs.operators.tensor

import io.kinference.attribute.Attribute
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.extensions.trilu
import io.kinference.operator.*
import io.kinference.protobuf.message.AttributeProto
import io.kinference.protobuf.message.TensorProto
import io.kinference.tfjs.data.tensors.TFJSTensor
import io.kinference.tfjs.data.tensors.asTensor

sealed class Trilu(
name: String,
info: OperatorInfo,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Operator<TFJSTensor, TFJSTensor>(name, info, attributes, inputs, outputs) {
companion object {
private val DEFAULT_VERSION = VersionInfo(sinceVersion = 14)

operator fun invoke(name: String, version: Int?, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>): Trilu {
return when(version ?: DEFAULT_VERSION.sinceVersion) {
in TriluVer14.VERSION.asRange() -> TriluVer14(name, attributes, inputs, outputs)
else -> error("Unsupported version of Trilu operator: $version")
}
}
}
}


class TriluVer14(
name: String,
attributes: Map<String, Attribute<Any>>,
inputs: List<String>,
outputs: List<String>
) : Trilu(name, INFO, attributes, inputs, outputs) {
companion object {
private val TYPE_CONSTRAINTS = ALL_DATA_TYPES

private val INPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "input", optional = false, differentiable = true),
IOInfo(1, setOf(TensorProto.DataType.INT64), "k", optional = true, differentiable = false)
)

private val OUTPUTS_INFO = listOf(
IOInfo(0, TYPE_CONSTRAINTS, "output", optional = false, differentiable = true)
)

private val ATTRIBUTES_INFO = listOf(
AttributeInfo("upper", setOf(AttributeProto.AttributeType.INT), required = false, default = 1L)
)

internal val VERSION = VersionInfo(sinceVersion = 14)
private val INFO = OperatorInfo("Trilu", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN)
}

private val upper: Boolean by attribute { it: Number -> it != 0L }

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<TFJSTensor?>): List<TFJSTensor?> {
val input = inputs[0]!!.data
val kTensor = inputs.getOrNull(1)?.data as? NumberNDArrayTFJS

val k = kTensor?.singleValue()?.toInt() ?: 0
val output = input.trilu(k, upper)

return listOf(output.asTensor("output"))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package io.kinference.tfjs.operators.tensor

import io.kinference.tfjs.runners.TFJSTestEngine.TFJSAccuracyRunner
import io.kinference.utils.TestRunner
import kotlin.test.Test

class TriluTest {
private fun getTargetPath(dirName: String) = "trilu/$dirName/"

@Test
fun test_tril() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril"))
}

@Test
fun test_tril_neg() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_neg"))
}

@Test
fun test_tril_one_row_neg() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_one_row_neg"))
}

@Test
fun test_tril_out_neg() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_out_neg"))
}

@Test
fun test_tril_out_pos() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_out_pos"))
}

@Test
fun test_tril_pos() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_pos"))
}

@Test
fun test_tril_square() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_square"))
}

@Test
fun test_tril_square_neg() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_square_neg"))
}

@Test
fun test_tril_zero() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_tril_zero"))
}

@Test
fun test_triu() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu"))
}

@Test
fun test_triu_neg() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_neg"))
}

@Test
fun test_triu_one_row() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_one_row"))
}

@Test
fun test_triu_out_neg_out() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_out_neg_out"))
}

@Test
fun test_triu_out_pos() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_out_pos"))
}

@Test
fun test_triu_pos() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_pos"))
}

@Test
fun test_triu_square() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_square"))
}

@Test
fun test_triu_square_neg() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_square_neg"))
}

@Test
fun test_triu_zero() = TestRunner.runTest {
TFJSAccuracyRunner.runFromResources(getTargetPath("test_triu_zero"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ abstract class NDArrayTFJS(tfjsArray: ArrayTFJS) : NDArray {
return zero
}

internal fun createZeros(shape: Array<Int>, dtype: String): NDArrayTFJS {
return when (dtype) {
"int32" -> intZeros(shape)
"float32" -> floatZeros(shape)
"bool" -> booleanZeros(shape)
"string" -> stringFill(shape, "")
else -> error("Unsupported data type: $dtype")
}
}

fun float(values: FloatArray, shape: Array<Int>) = NumberNDArrayTFJS(tensor(values, shape, "float32"))
fun int(values: IntArray, shape: Array<Int>) = NumberNDArrayTFJS(tensor(values, shape, "int32"))
fun boolean(values: Array<Boolean>, shape: Array<Int>) = BooleanNDArrayTFJS(tensor(values, shape))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ import io.kinference.ndarray.arrays.ArrayTFJS

internal external interface Linalg {
val qr: (x: ArrayTFJS, fullMatrices: Boolean) -> Array<ArrayTFJS>

val bandPart: (a: ArrayTFJS, numLower: Int, numUpper: Int) -> ArrayTFJS
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,40 @@ suspend fun NumberNDArrayTFJS.isInf(detectNegative: Boolean = true, detectPositi
}

fun NumberNDArrayTFJS.isNaN() = BooleanNDArrayTFJS(tfjsArray.isNaN())


suspend fun <T : NDArrayTFJS> T.trilu(k: Int, upper: Boolean): T {
if (this.linearSize == 0) return this.clone() as T

return tidyNDArray {
if (upper) this.triluUpper(k) else this.triluLower(k)
} as T
}

private fun NDArrayTFJS.triluUpper(k: Int): NDArrayTFJS {
val (height, width) = shape.takeLast(2)

if (k == 0) return tfjsArray.bandPart(numUpper = -1).toNDArray()
if (k > 0 && k - 1 > width) return NDArrayTFJS.createZeros(shapeArray, dtype)
if (k < 0 && -k > height) return this.clone()

return if (k > 0) {
tfjsArray - tfjsArray.bandPart(numLower = -1, numUpper = k - 1)
} else {
tfjsArray.bandPart(numLower = -k, numUpper = -1)
}.toNDArray()
}

private fun NDArrayTFJS.triluLower(k: Int): NDArrayTFJS {
val (height, width) = shape.takeLast(2)

if (k == 0) return tfjsArray.bandPart(numLower = -1).toNDArray()
if (k < 0 && -k - 1 > height) return NDArrayTFJS.createZeros(shapeArray, dtype)
if (k > 0 && k > width) return this.clone()

return if (k > 0) {
tfjsArray.bandPart(numLower = -1, numUpper = k)
} else {
tfjsArray - tfjsArray.bandPart(numLower = -k - 1, numUpper = -1)
}.toNDArray()
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,5 @@ fun ArrayTFJS.floor() = floor(this)
fun ArrayTFJS.isInf() = isInf(this)

fun ArrayTFJS.isNaN() = isNaN(this)

fun ArrayTFJS.bandPart(numLower: Int = 0, numUpper: Int = 0) = linalg.bandPart(this, numLower, numUpper)

0 comments on commit ee5c36d

Please sign in to comment.