diff --git a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt index f6477867c..d3a29f9f0 100644 --- a/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt +++ b/inference/inference-core/src/commonMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt @@ -65,7 +65,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map - val resultMarker: Array + val resultMarker: Array<() -> Unit> if (past == null || past.linearSize == 0) { resultBlocks = kBlocks.plus(vBlocks) @@ -84,7 +84,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map(2 * batchSize * numHeads * presentDims[3] * blocksInRow) - val futureResMarker = arrayOfNulls(2 * batchSize * numHeads * presentDims[3] * blocksInRow) + val futureResMarker = arrayOfNulls<() -> Unit>(2 * batchSize * numHeads * presentDims[3] * blocksInRow) var resBlockIdx = 0 var pastBlocIdx = 0 @@ -109,7 +109,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map - resultMarker = futureResMarker as Array + resultMarker = futureResMarker as Array<() -> Unit> } return FloatNDArray(FloatTiledArray(resultBlocks, resultMarker), Strides(presentDims)) diff --git a/ndarray/ndarray-api/src/commonMain/kotlin/io/kinference/ndarray/arrays/ArrayDispatcherUtils.kt b/ndarray/ndarray-api/src/commonMain/kotlin/io/kinference/ndarray/arrays/ArrayDispatcherUtils.kt index b58874007..af6751243 100644 --- a/ndarray/ndarray-api/src/commonMain/kotlin/io/kinference/ndarray/arrays/ArrayDispatcherUtils.kt +++ b/ndarray/ndarray-api/src/commonMain/kotlin/io/kinference/ndarray/arrays/ArrayDispatcherUtils.kt @@ -1,16 +1,5 @@ package io.kinference.ndarray.arrays -typealias StateMarker = (ArrayUsageMarker) -> Unit - -const val NO_MODEL_CONTEXT = "NoContext" -const val NO_INFERENCE_CONTEXT = "NoInferenceContext" - -enum class ArrayUsageMarker { - Unused, - Used, - Output, -} - enum class ArrayTypes(val index: Int) { ByteArray(0), UByteArray(1), diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt index f5f392dfd..0fb962b38 100644 --- a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/BooleanNDArray.kt @@ -76,7 +76,7 @@ open class BooleanNDArray(var array: BooleanTiledArray, strides: Strides) : NDAr } override fun markOutput() { - array.marker.forEach { it.invoke(ArrayUsageMarker.Output) } + array.marker.forEach { it.invoke() } } override suspend fun toMutable(): MutableBooleanNDArray { diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt index ddb378fd8..ce7e80a74 100644 --- a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/PrimitiveNDArray.kt @@ -86,7 +86,7 @@ internal open class PrimitiveNDArray(array: PrimitiveTiledArray, strides: Stride } override fun markOutput() { - array.marker.forEach { it.invoke(ArrayUsageMarker.Output) } + array.marker.forEach { it.invoke() } } override suspend fun clone(): PrimitiveNDArray { diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/AllocatorContext.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/AllocatorContext.kt index 05041788d..4fff74816 100644 --- a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/AllocatorContext.kt +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/AllocatorContext.kt @@ -19,8 +19,7 @@ data class AllocatorContext(val modelName: String, val cycleId: Long) : Coroutin fun closeAllocated() { usedContainers.forEach { - if (it.marker != ArrayUsageMarker.Output) { - it.marker = ArrayUsageMarker.Unused + if (!it.isOutput) { unusedContainers[it.arrayTypeIndex, it.arraySizeIndex].addLast(it) } } diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayContainer.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayContainer.kt index 3f066990e..0d4c6ca71 100644 --- a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayContainer.kt +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayContainer.kt @@ -4,11 +4,13 @@ import io.kinference.ndarray.arrays.* internal sealed class ArrayContainer( val arrayTypeIndex: Int, - val arraySizeIndex: Int, - var marker: ArrayUsageMarker = ArrayUsageMarker.Used, + val arraySizeIndex: Int ) { - val markAsOutput: StateMarker = { - marker = it + var isOutput: Boolean = false + private set + + val markAsOutput = { + isOutput = true } var next: ArrayContainer? = null diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayStorage.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayStorage.kt index cc88fa820..2831f3a66 100644 --- a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayStorage.kt +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/memory/ArrayStorage.kt @@ -1,7 +1,6 @@ package io.kinference.ndarray.arrays.memory import io.kinference.ndarray.arrays.ArrayTypes -import io.kinference.ndarray.arrays.ArrayUsageMarker internal class ArrayStorage(typeLength: Int, sizeLength: Int) { /** @@ -29,7 +28,6 @@ internal class ArrayStorage(typeLength: Int, sizeLength: Int) { val idx = if (sIndex != -1) { val array = storage[tIndex][sIndex].removeFirstOrNull() array?.let { - it.marker = ArrayUsageMarker.Used ArrayContainer.resetArray(it) return it } diff --git a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt index eda58c092..f42176cbe 100644 --- a/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt +++ b/ndarray/ndarray-core/src/commonMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt @@ -17,7 +17,7 @@ import kotlin.math.min @GenerateNameFromPrimitives @MakePublic -internal class PrimitiveTiledArray(val blocks: Array, val marker: Array = emptyMarker) { +internal class PrimitiveTiledArray(val blocks: Array, val marker: Array<()->Unit> = emptyMarker) { val size: Int val blockSize: Int = if (blocks.isEmpty()) 0 else blocks.first().size val blocksNum: Int = blocks.size @@ -28,7 +28,7 @@ internal class PrimitiveTiledArray(val blocks: Array, val marker companion object { val type: ArrayTypes = ArrayTypes.valueOf(PrimitiveArray::class.simpleName!!) - private val emptyMarker: Array = arrayOf() + private val emptyMarker: Array<()->Unit> = arrayOf() suspend operator fun invoke(strides: Strides): PrimitiveTiledArray { val blockSize = blockSizeByStrides(strides)