Skip to content

Commit

Permalink
JBAI-4017 [ndarray, core] Refactored array markers from ArrayUsageMar…
Browse files Browse the repository at this point in the history
…ker to boolean reducing enormous amount of time for refreshing arrays status
  • Loading branch information
dmitriyb committed May 13, 2024
1 parent 51b4231 commit f742926
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,
val vMarker = v.array.marker

val resultBlocks: Array<FloatArray>
val resultMarker: Array<StateMarker>
val resultMarker: Array<() -> Unit>

if (past == null || past.linearSize == 0) {
resultBlocks = kBlocks.plus(vBlocks)
Expand All @@ -84,7 +84,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,

val rowsSize = batchSize * numHeads
val futureRes = arrayOfNulls<FloatArray>(2 * batchSize * numHeads * presentDims[3] * blocksInRow)
val futureResMarker = arrayOfNulls<StateMarker>(2 * batchSize * numHeads * presentDims[3] * blocksInRow)
val futureResMarker = arrayOfNulls<() -> Unit>(2 * batchSize * numHeads * presentDims[3] * blocksInRow)

var resBlockIdx = 0
var pastBlocIdx = 0
Expand All @@ -109,7 +109,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,
}
}
resultBlocks = futureRes as Array<FloatArray>
resultMarker = futureResMarker as Array<StateMarker>
resultMarker = futureResMarker as Array<() -> Unit>
}

return FloatNDArray(FloatTiledArray(resultBlocks, resultMarker), Strides(presentDims))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
package io.kinference.ndarray.arrays

typealias StateMarker = (ArrayUsageMarker) -> Unit

const val NO_MODEL_CONTEXT = "NoContext"
const val NO_INFERENCE_CONTEXT = "NoInferenceContext"

enum class ArrayUsageMarker {
Unused,
Used,
Output,
}

enum class ArrayTypes(val index: Int) {
ByteArray(0),
UByteArray(1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ open class BooleanNDArray(var array: BooleanTiledArray, strides: Strides) : NDAr
}

override fun markOutput() {
array.marker.forEach { it.invoke(ArrayUsageMarker.Output) }
array.marker.forEach { it.invoke() }
}

override suspend fun toMutable(): MutableBooleanNDArray {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ internal open class PrimitiveNDArray(array: PrimitiveTiledArray, strides: Stride
}

override fun markOutput() {
array.marker.forEach { it.invoke(ArrayUsageMarker.Output) }
array.marker.forEach { it.invoke() }
}

override suspend fun clone(): PrimitiveNDArray {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ data class AllocatorContext(val modelName: String, val cycleId: Long) : Coroutin

fun closeAllocated() {
usedContainers.forEach {
if (it.marker != ArrayUsageMarker.Output) {
it.marker = ArrayUsageMarker.Unused
if (!it.isOutput) {
unusedContainers[it.arrayTypeIndex, it.arraySizeIndex].addLast(it)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import io.kinference.ndarray.arrays.*

internal sealed class ArrayContainer(
val arrayTypeIndex: Int,
val arraySizeIndex: Int,
var marker: ArrayUsageMarker = ArrayUsageMarker.Used,
val arraySizeIndex: Int
) {
val markAsOutput: StateMarker = {
marker = it
var isOutput: Boolean = false
private set

val markAsOutput = {
isOutput = true
}

var next: ArrayContainer? = null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.kinference.ndarray.arrays.memory

import io.kinference.ndarray.arrays.ArrayTypes
import io.kinference.ndarray.arrays.ArrayUsageMarker

internal class ArrayStorage(typeLength: Int, sizeLength: Int) {
/**
Expand Down Expand Up @@ -29,7 +28,6 @@ internal class ArrayStorage(typeLength: Int, sizeLength: Int) {
val idx = if (sIndex != -1) {
val array = storage[tIndex][sIndex].removeFirstOrNull()
array?.let {
it.marker = ArrayUsageMarker.Used
ArrayContainer.resetArray(it)
return it
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import kotlin.math.min

@GenerateNameFromPrimitives
@MakePublic
internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>, val marker: Array<StateMarker> = emptyMarker) {
internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>, val marker: Array<()->Unit> = emptyMarker) {
val size: Int
val blockSize: Int = if (blocks.isEmpty()) 0 else blocks.first().size
val blocksNum: Int = blocks.size
Expand All @@ -28,7 +28,7 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>, val marker

companion object {
val type: ArrayTypes = ArrayTypes.valueOf(PrimitiveArray::class.simpleName!!)
private val emptyMarker: Array<StateMarker> = arrayOf()
private val emptyMarker: Array<()->Unit> = arrayOf()

suspend operator fun invoke(strides: Strides): PrimitiveTiledArray {
val blockSize = blockSizeByStrides(strides)
Expand Down

0 comments on commit f742926

Please sign in to comment.