Skip to content

Commit

Permalink
JBAI-4017 [ndarray, core, utils] Typealias to marker arrays as StateM…
Browse files Browse the repository at this point in the history
…arker, changed PlatformQueue to ConcurrentQueue, commented code cleaned up
  • Loading branch information
dmitriyb committed May 15, 2024
1 parent f742926 commit d2c9d97
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,
val vMarker = v.array.marker

val resultBlocks: Array<FloatArray>
val resultMarker: Array<() -> Unit>
val resultMarker: Array<StateMarker>

if (past == null || past.linearSize == 0) {
resultBlocks = kBlocks.plus(vBlocks)
Expand All @@ -84,7 +84,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,

val rowsSize = batchSize * numHeads
val futureRes = arrayOfNulls<FloatArray>(2 * batchSize * numHeads * presentDims[3] * blocksInRow)
val futureResMarker = arrayOfNulls<() -> Unit>(2 * batchSize * numHeads * presentDims[3] * blocksInRow)
val futureResMarker = arrayOfNulls<StateMarker>(2 * batchSize * numHeads * presentDims[3] * blocksInRow)

var resBlockIdx = 0
var pastBlocIdx = 0
Expand All @@ -109,7 +109,7 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,
}
}
resultBlocks = futureRes as Array<FloatArray>
resultMarker = futureResMarker as Array<() -> Unit>
resultMarker = futureResMarker as Array<StateMarker>
}

return FloatNDArray(FloatTiledArray(resultBlocks, resultMarker), Strides(presentDims))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.kinference.ndarray.arrays

typealias StateMarker = () -> Unit

enum class ArrayTypes(val index: Int) {
ByteArray(0),
UByteArray(1),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package io.kinference.ndarray.arrays.memory

import io.kinference.ndarray.arrays.ArrayTypes
import io.kinference.utils.PlatformQueue
import io.kinference.utils.ConcurrentQueue

internal object ArrayDispatcher {
private const val INIT_SIZE_VALUE: Int = 2
private val typeSize: Int = ArrayTypes.entries.size

private val unusedArrays: PlatformQueue<ArrayStorage> = PlatformQueue()
private val unusedArrays: ConcurrentQueue<ArrayStorage> = ConcurrentQueue()

fun getStorage(): ArrayStorage {
return unusedArrays.removeFirstOrNull() ?: ArrayStorage(typeSize, INIT_SIZE_VALUE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import kotlin.math.min

@GenerateNameFromPrimitives
@MakePublic
internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>, val marker: Array<()->Unit> = emptyMarker) {
internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>, val marker: Array<StateMarker> = emptyMarker) {
val size: Int
val blockSize: Int = if (blocks.isEmpty()) 0 else blocks.first().size
val blocksNum: Int = blocks.size
Expand All @@ -28,7 +28,7 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>, val marker

companion object {
val type: ArrayTypes = ArrayTypes.valueOf(PrimitiveArray::class.simpleName!!)
private val emptyMarker: Array<()->Unit> = arrayOf()
private val emptyMarker: Array<StateMarker> = arrayOf()

suspend operator fun invoke(strides: Strides): PrimitiveTiledArray {
val blockSize = blockSizeByStrides(strides)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,10 @@ internal suspend fun computeGeluPrimitive(input: PrimitiveNDArray, bias: Primiti

val blockSize = input.array.blockSize

// // This approach when arrays acquired before parallelizeByBlocks() is faster
// val coroutineContext = coroutineContext[ModelContext.Key]!!
// val modelName = coroutineContext.modelName
// val inferenceCycle = coroutineContext.cycleId
//
// val coroutineCount = countCoroutinesByData(blockSize, inputBlocks.size, 2048)
// val containerTemporaryBlockArrays = ArrayDispatcher.getArrayContainers(PrimitiveTiledArray.type, blockSize, coroutineCount)
// val containerTemporaryBlockAbsArrays = ArrayDispatcher.getArrayContainers(PrimitiveTiledArray.type, blockSize, coroutineCount)
// val temporaryBlockArrays = Array(containerTemporaryBlockArrays.size) { i -> (containerTemporaryBlockArrays[i] as PrimitiveArrayContainer).array }
// val temporaryBlockAbsArrays = Array(containerTemporaryBlockAbsArrays.size) { i -> (containerTemporaryBlockAbsArrays[i] as PrimitiveArrayContainer).array }

// Constant 2048 was precomputed on M1 Max processor
// With this constant two launches work faster than single thread without launches
// TODO: (cupertank) Remove constants
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, coroutineIndex ->
// val temporaryBlock = temporaryBlockArrays[coroutineIndex]
// val temporaryBlockAbs = temporaryBlockAbsArrays[coroutineIndex]
val temporaryBlock = PrimitiveArray(blockSize)
val temporaryBlockAbs = PrimitiveArray(blockSize)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,10 @@ internal suspend fun fastGeluPrimitive(input: PrimitiveNDArray, bias: PrimitiveN

val blockSize = input.array.blockSize

// val coroutineCount = countCoroutinesByData(blockSize, inputBlocks.size, 2048)
// val containerArray = ArrayDispatcher.getArrayContainers(PrimitiveTiledArray.type, blockSize, coroutineCount)
// val temporaryBlockExpArrays = Array(containerArray.size) { i -> (containerArray[i] as PrimitiveArrayContainer).array }

// Constant 2048 was precomputed on M1 Max processor
// With this constant two launches work faster than single thread without launches
// TODO: (cupertank) Remove constants
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, coroutineIndex ->
// val temporaryBlockExp = temporaryBlockExpArrays[coroutineIndex]
parallelizeByBlocks(blockSize, inputBlocks.size, 2048) { blockStart, blockEnd, _ ->
val temporaryBlockExp = PrimitiveArray(blockSize)
for (blockIdx in blockStart until blockEnd) {
val outputBlock = outputBlocks[blockIdx]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.kinference.utils

expect class PlatformQueue<T>() {
expect class ConcurrentQueue<T>() {
fun removeFirstOrNull(): T?
fun addLast(element: T)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package io.kinference.utils

actual class PlatformQueue<T> actual constructor() {
actual class ConcurrentQueue<T> actual constructor() {
private val queue: ArrayDeque<T> = ArrayDeque()

actual fun removeFirstOrNull(): T? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.kinference.utils

import java.util.concurrent.ConcurrentLinkedQueue

actual class PlatformQueue<T> actual constructor() {
actual class ConcurrentQueue<T> actual constructor() {
private val queue: ConcurrentLinkedQueue<T> = ConcurrentLinkedQueue()

actual fun removeFirstOrNull(): T? {
Expand Down

0 comments on commit d2c9d97

Please sign in to comment.