Commit

Merge pull request #182 from JetBrains-Research/coroutines-support
Coroutines support
dmitriyb authored May 15, 2024
2 parents 953cb6d + d2c9d97 commit e566cfc
Showing 22 changed files with 176 additions and 303 deletions.
@@ -5,11 +5,10 @@ import io.kinference.core.graph.KIGraph
import io.kinference.core.markOutput
import io.kinference.graph.Contexts
import io.kinference.model.Model
import io.kinference.ndarray.arrays.memory.ArrayDispatcher
import io.kinference.operator.OperatorSetRegistry
import io.kinference.profiler.*
import io.kinference.protobuf.message.ModelProto
import io.kinference.utils.ModelContext
import io.kinference.ndarray.arrays.memory.AllocatorContext
import kotlinx.coroutines.withContext
import kotlinx.atomicfu.atomic

@@ -20,23 +19,21 @@ class KIModel(val id: String, val name: String, val opSet: OperatorSetRegistry,
override fun analyzeProfilingResults(): ProfileAnalysisEntry = profiles.analyze("Model $name")
override fun resetProfiles() = profiles.clear()

override suspend fun predict(input: List<KIONNXData<*>>, profile: Boolean): Map<String, KIONNXData<*>> = withContext(ModelContext(id, getInferenceCycleId().toString())) {
val contexts = Contexts<KIONNXData<*>>(
null,
if (profile) addProfilingContext("Model $name") else null
)
val modelName = coroutineContext[ModelContext.Key]!!.modelName
val inferenceCycle = coroutineContext[ModelContext.Key]!!.cycleId
ArrayDispatcher.addInferenceContext(modelName, inferenceCycle)
val execResult = graph.execute(input, contexts)
execResult.forEach { it.markOutput() }
ArrayDispatcher.closeInferenceContext(modelName, inferenceCycle)
execResult.associateBy { it.name!! }
}
override suspend fun predict(input: List<KIONNXData<*>>, profile: Boolean): Map<String, KIONNXData<*>> =
withContext(AllocatorContext(id, getInferenceCycleId())) {
val contexts = Contexts<KIONNXData<*>>(
null,
if (profile) addProfilingContext("Model $name") else null
)
val coroutineContext = coroutineContext[AllocatorContext.Key]!!
val execResult = graph.execute(input, contexts)
execResult.forEach { it.markOutput() }
coroutineContext.closeAllocated()
execResult.associateBy { it.name!! }
}

override suspend fun close() {
graph.close()
ArrayDispatcher.removeModelContext(id)
}

private fun getInferenceCycleId(): Long = inferenceCycleCounter.incrementAndGet()
@@ -51,7 +48,6 @@ class KIModel(val id: String, val name: String, val opSet: OperatorSetRegistry,
val id = "$name:${generateModelId()}"
val opSet = OperatorSetRegistry(proto.opSetImport)
val graph = KIGraph(proto.graph!!, opSet)
ArrayDispatcher.addModelContext(id)
return KIModel(id, name, opSet, graph)
}
}
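
For reference, a minimal sketch of the allocation flow the new predict() follows: the allocator now travels with the coroutine context instead of being looked up in the removed string-keyed ArrayDispatcher registry. Only AllocatorContext(modelName, cycleId), AllocatorContext.Key and closeAllocated() are taken from this diff; runWithAllocator is a hypothetical helper, not part of the change.

import io.kinference.ndarray.arrays.memory.AllocatorContext
import kotlinx.coroutines.withContext

// Hypothetical helper mirroring the structure of KIModel.predict above.
suspend fun <T> runWithAllocator(modelId: String, cycleId: Long, block: suspend () -> T): T =
    withContext(AllocatorContext(modelId, cycleId)) {
        val allocator = coroutineContext[AllocatorContext.Key]!!
        val result = block()
        // Containers that were not marked as output are returned to the shared pool.
        allocator.closeAllocated()
        result
    }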
@@ -1,15 +1,6 @@
package io.kinference.ndarray.arrays

typealias StateMarker = (ArrayUsageMarker) -> Unit

const val NO_MODEL_CONTEXT = "NoContext"
const val NO_INFERENCE_CONTEXT = "NoInferenceContext"

enum class ArrayUsageMarker {
Unused,
Used,
Output,
}
typealias StateMarker = () -> Unit

enum class ArrayTypes(val index: Int) {
ByteArray(0),
@@ -76,7 +76,7 @@ open class BooleanNDArray(var array: BooleanTiledArray, strides: Strides) : NDAr
}

override fun markOutput() {
array.marker.forEach { it.invoke(ArrayUsageMarker.Output) }
array.marker.forEach { it.invoke() }
}

override suspend fun toMutable(): MutableBooleanNDArray {
@@ -86,7 +86,7 @@ internal open class PrimitiveNDArray(array: PrimitiveTiledArray, strides: Stride
}

override fun markOutput() {
array.marker.forEach { it.invoke(ArrayUsageMarker.Output) }
array.marker.forEach { it.invoke() }
}

override suspend fun clone(): PrimitiveNDArray {
@@ -540,7 +540,7 @@ internal open class PrimitiveNDArray(array: PrimitiveTiledArray, strides: Stride
}

override suspend fun reduceSum(axes: IntArray, keepDims: Boolean): PrimitiveNDArray =
reduceOperationPrimitive(axes, keepDims) { output: PrimitiveType, input: PrimitiveType -> (output + input).toPrimitive() }
reduceOperationPrimitive(axes, keepDims) { output: InlinePrimitive, input: InlinePrimitive -> (output + input) }

override suspend fun topK(axis: Int, k: Int, largest: Boolean, sorted: Boolean): Pair<PrimitiveNDArray, LongNDArray> {
val actualAxis = indexAxis(axis)
@@ -0,0 +1,28 @@
package io.kinference.ndarray.arrays.memory

import io.kinference.ndarray.arrays.*
import kotlin.coroutines.CoroutineContext

data class AllocatorContext(val modelName: String, val cycleId: Long) : CoroutineContext.Element {
private val usedContainers: ArrayDeque<ArrayContainer> = ArrayDeque()
private val unusedContainers: ArrayStorage = ArrayDispatcher.getStorage()

companion object Key : CoroutineContext.Key<AllocatorContext>
override val key: CoroutineContext.Key<*> get() = Key

internal fun getArrayContainers(type: ArrayTypes, size: Int, count: Int): Array<ArrayContainer> {
val arrayContainers = Array(count) { unusedContainers.getArrayContainer(type, size) }
usedContainers.addAll(arrayContainers)
return arrayContainers
}


fun closeAllocated() {
usedContainers.forEach {
if (!it.isOutput) {
unusedContainers[it.arrayTypeIndex, it.arraySizeIndex].addLast(it)
}
}
ArrayDispatcher.returnStorage(unusedContainers)
}
}
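
A rough illustration of how code inside the ndarray module might obtain pooled arrays through this context. getArrayContainers is internal to the module; the fallback to a plain allocation when no AllocatorContext is installed is an assumption modelled on the removed ArrayDispatcher path, and acquireByteContainers is a hypothetical helper.

import io.kinference.ndarray.arrays.ArrayTypes
import kotlin.coroutines.coroutineContext

// Hypothetical caller: request `count` byte containers of the given size from the allocator
// bound to the current coroutine, or allocate plainly when no allocator is installed.
internal suspend fun acquireByteContainers(size: Int, count: Int): Array<ArrayContainer> {
    val allocator = coroutineContext[AllocatorContext.Key]
    return allocator?.getArrayContainers(ArrayTypes.ByteArray, size, count)
        ?: Array(count) { ArrayContainer(ArrayTypes.ByteArray, size) }
}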
@@ -4,22 +4,20 @@ import io.kinference.ndarray.arrays.*

internal sealed class ArrayContainer(
val arrayTypeIndex: Int,
val arraySizeIndex: Int,
var marker: ArrayUsageMarker = ArrayUsageMarker.Used,
val arraySizeIndex: Int
) {
val markAsOutput: StateMarker = {
marker = it
var isOutput: Boolean = false
private set

val markAsOutput = {
isOutput = true
}

var next: ArrayContainer? = null

private class EmptyArrayContainer : ArrayContainer(EMPTY_INDEX, EMPTY_INDEX)

companion object {
private const val EMPTY_INDEX = -1

fun emptyContainer(): ArrayContainer = EmptyArrayContainer()

operator fun invoke(type: ArrayTypes, size: Int, sizeIndex: Int = EMPTY_INDEX): ArrayContainer {
return when (type) {
ArrayTypes.ByteArray -> ByteArrayContainer(type.index, sizeIndex, ByteArray(size)) // 8-bit signed
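
The marker plumbing above shrinks from a callback taking an ArrayUsageMarker to a plain () -> Unit that flips isOutput. A small sketch of the intended interaction, assuming (as the markOutput() hunks above suggest) that a tiled array collects the markAsOutput lambdas of its containers; markAllAsOutput is hypothetical.

// Hypothetical wiring inside the ndarray module, mirroring the markOutput() bodies above.
internal fun markAllAsOutput(containers: List<ArrayContainer>) {
    val markers: List<StateMarker> = containers.map { it.markAsOutput }
    // Flipping isOutput means AllocatorContext.closeAllocated() will not recycle these containers.
    markers.forEach { it.invoke() }
}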
@@ -1,230 +1,19 @@
package io.kinference.ndarray.arrays.memory

import io.kinference.ndarray.arrays.*
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import io.kinference.ndarray.arrays.ArrayTypes
import io.kinference.utils.ConcurrentQueue

object ArrayDispatcher {
private val modelDispatchers = mutableMapOf<String, ModelArrayDispatcher>()
private val mutex = Mutex()
internal object ArrayDispatcher {
private const val INIT_SIZE_VALUE: Int = 2
private val typeSize: Int = ArrayTypes.entries.size

suspend fun addModelContext(modelContext: String) {
mutex.withLock {
modelDispatchers[modelContext] = ModelArrayDispatcher()
}
}

suspend fun removeModelContext(modelContext: String) {
val modelDispatcher = mutex.withLock {
modelDispatchers.remove(modelContext)
}
modelDispatcher?.close()
}

suspend fun addInferenceContext(modelContext: String, inferenceContext: String) {
modelDispatchers[modelContext]!!.addInferenceContext(inferenceContext)
}

suspend fun closeInferenceContext(modelContext: String, inferenceContext: String) {
modelDispatchers[modelContext]!!.closeInferenceContext(inferenceContext)
}

internal suspend fun getArrayContainers(
type: ArrayTypes,
size: Int,
count: Int,
modelContext: String = NO_MODEL_CONTEXT,
inferenceContext: String = NO_INFERENCE_CONTEXT
): Array<ArrayContainer> {
if (modelContext == NO_MODEL_CONTEXT || inferenceContext == NO_MODEL_CONTEXT)
return Array(count) { ArrayContainer(type, size) }

return modelDispatchers[modelContext]!!.getArrayContainers(inferenceContext, type, size, count)
}
}

private class ModelArrayDispatcher {
companion object {
private const val INIT_SIZE_VALUE: Int = 2
private val typeSize: Int = ArrayTypes.entries.size
}

private val usedArrays: HashMap<String, ConcurrentArrayContainerQueue> = hashMapOf()
private val unusedArrays: ArrayStorage = ArrayStorage(typeSize, INIT_SIZE_VALUE)
private val mutex = Mutex()

class ConcurrentArrayContainerQueue {
// Initialize the head with the emptyContainer sentinel node
private var head: ArrayContainer? = ArrayContainer.emptyContainer()
private var tail: ArrayContainer? = head
private val isClosed = atomic(false)
private val lock = atomic(false)

fun addLast(container: ArrayContainer) {
while (true) {
if (lock.compareAndSet(expect = false, update = true)) {
if (isClosed.value) {
lock.value = false
throw IllegalStateException("Cannot add to a closed queue.")
}

container.next = null
tail?.next = container
tail = container
lock.value = false
return
}
}
}

fun removeFirstOrNull(): ArrayContainer? {
while (true) {
if (lock.compareAndSet(expect = false, update = true)) {
if (isClosed.value) {
lock.value = false
throw IllegalStateException("Cannot remove from a closed queue.")
}

val first = head?.next
if (first == null) {
lock.value = false
return null
}

head?.next = first.next
if (first.next == null) {
tail = head
}
lock.value = false
return first
}
}
}

fun close() {
while (true) {
if (lock.compareAndSet(expect = false, update = true)) {
isClosed.value = true
var current = head
while (current != null) {
val next = current.next
current.next = null
current = next
}
lock.value = false
return
}
}
}
}

private class ArrayStorage(typeLength: Int, sizeLength: Int) {
/**
* Structure is as follows:
* 1. Array by predefined types (all types are known compiled time)
* 2. Array by size. Starting with 'INIT_SIZE_VALUE' element and grow it doubling (typically there are no more than 16 different sizes)
* 3. Queue of array containers (used as FIFO)
*/
private var storage: Array<Array<ConcurrentArrayContainerQueue>> =
Array(typeLength) { Array(sizeLength) { ConcurrentArrayContainerQueue() } }

private var sizeIndices: IntArray = IntArray(typeLength)
private var sizes: Array<IntArray> = Array(typeLength) { IntArray(sizeLength) }
private val mutex = Mutex()

operator fun get(typeIndex: Int, sizeIndex: Int): ConcurrentArrayContainerQueue {
return storage[typeIndex][sizeIndex]
}

suspend fun getArrayContainer(type: ArrayTypes, size: Int): ArrayContainer {
val tIndex = type.index
val sIndex = sizes[tIndex].indexOf(size)

// Checking that we have this array size in our storage for this type
val idx = if (sIndex != -1) {
val array = storage[tIndex][sIndex].removeFirstOrNull()
array?.let {
it.marker = ArrayUsageMarker.Used
ArrayContainer.resetArray(it)
return it
}
sIndex
} else {
mutex.withLock {
if (sizeIndices[tIndex] >= storage[tIndex].size)
grow(tIndex)

val idx = sizeIndices[tIndex]++
sizes[tIndex][idx] = size
idx
}
}

return ArrayContainer(type, size, idx)
}

fun grow(typeIndex: Int) {
val newSize = sizes[typeIndex].size * 2
val newStorage: Array<ConcurrentArrayContainerQueue> = Array(newSize) { ConcurrentArrayContainerQueue() }

for (i in storage[typeIndex].indices) {
newStorage[i] = storage[typeIndex][i]
}

storage[typeIndex] = newStorage
sizes[typeIndex] = sizes[typeIndex].copyOf(newSize)
}

fun close() {
for (i in storage.indices) {
for (j in storage[i].indices) {
storage[i][j].close()
}
}
}
}

suspend fun addInferenceContext(inferenceContext: String) {
mutex.withLock {
usedArrays[inferenceContext] = ConcurrentArrayContainerQueue()
}
}

suspend fun getArrayContainers(inferenceContext: String, type: ArrayTypes, size: Int, count: Int): Array<ArrayContainer> {
return Array(count) { getArrayContainer(inferenceContext, type, size) }
}

suspend fun closeInferenceContext(inferenceContext: String) {
val usedArrays = mutex.withLock {
usedArrays.remove(inferenceContext)!!
}
var isProcessed = false

while (!isProcessed) {
val container = usedArrays.removeFirstOrNull()
if (container != null) {
if (container.marker != ArrayUsageMarker.Output) {
container.marker = ArrayUsageMarker.Unused
unusedArrays[container.arrayTypeIndex, container.arraySizeIndex].addLast(container)
}
} else {
isProcessed = true
}
}

usedArrays.close()
}
private val unusedArrays: ConcurrentQueue<ArrayStorage> = ConcurrentQueue()

fun close() {
unusedArrays.close()
usedArrays.forEach { it.value.close() }
usedArrays.clear()
fun getStorage(): ArrayStorage {
return unusedArrays.removeFirstOrNull() ?: ArrayStorage(typeSize, INIT_SIZE_VALUE)
}

private suspend fun getArrayContainer(inferenceContext: String, type: ArrayTypes, size: Int): ArrayContainer {
val newArray = unusedArrays.getArrayContainer(type, size)
usedArrays[inferenceContext]!!.addLast(newArray)
return newArray
fun returnStorage(storage: ArrayStorage) {
unusedArrays.addLast(storage)
}
}
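
With the per-model and per-inference bookkeeping gone, the dispatcher is now just a pool of ArrayStorage instances handed out for the duration of an inference cycle. A sketch of that handshake, assuming the caller sits in the same module; in the actual change, AllocatorContext takes the storage in its field initializer and gives it back in closeAllocated(), so withPooledStorage below is purely hypothetical.

// Hypothetical helper showing the storage pooling handshake behind getStorage()/returnStorage().
internal suspend fun <T> withPooledStorage(block: suspend (ArrayStorage) -> T): T {
    val storage = ArrayDispatcher.getStorage()   // reuse a pooled ArrayStorage or create a fresh one
    try {
        return block(storage)                    // containers are drawn from `storage` during the cycle
    } finally {
        ArrayDispatcher.returnStorage(storage)   // return the storage so the next cycle can reuse it
    }
}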