Skip to content

Commit

Permalink
JBAI-4393 [core, ndarray] Refactored coroutine contexts to be polymor…
Browse files Browse the repository at this point in the history
…phic, merge ParallelismLimiterContext and its thread limiter behavior into PredictionContext.
  • Loading branch information
dmitriyb committed Sep 2, 2024
1 parent a19fc9c commit c942273
Show file tree
Hide file tree
Showing 19 changed files with 73 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import io.kinference.core.graph.KIGraph
import io.kinference.graph.Contexts
import io.kinference.model.Model
import io.kinference.ndarray.arrays.memory.*
import io.kinference.ndarray.arrays.memory.contexts.finalizeAllocatorContext
import io.kinference.operator.OperatorSetRegistry
import io.kinference.profiler.*
import io.kinference.protobuf.message.ModelProto
Expand Down Expand Up @@ -47,7 +46,6 @@ class KIModel(
return@withContext graph.execute(input, contexts).map { it.clone(it.name) }.toList()
}

predictionContext.finalizeAllocatorContext()
predictionContextDispatcher.returnStorage(predictionContext)
output
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import io.kinference.optimizer.GraphOptimizer.Companion.isOpt
import io.kinference.primitives.types.DataType
import io.kinference.protobuf.message.AttributeProto
import io.kinference.protobuf.message.TensorProto
import io.kinference.utils.PredictionContext
import io.kinference.utils.launchWithLimitOrDefault
import kotlinx.coroutines.coroutineScope
import kotlin.coroutines.coroutineContext
Expand Down Expand Up @@ -287,7 +288,7 @@ class AttentionVer1(name: String, attributes: Map<String, Attribute<Any>>, input
private val maskFilterValue: Float by attribute("mask_filter_value") { it: Number -> it.toFloat() }

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val context = coroutineContext[ManualAllocatorContext.Key]
val context = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext

val input = inputs[0]!!
val weights = inputs[1]!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.kinference.operator.*
import io.kinference.primitives.types.DataType
import io.kinference.protobuf.message.AttributeProto.AttributeType
import io.kinference.protobuf.message.TensorProto
import io.kinference.utils.PredictionContext
import kotlin.coroutines.coroutineContext
import kotlin.math.sqrt

Expand Down Expand Up @@ -175,7 +176,7 @@ class EmbedLayerNormalizationVer1(
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[ManualAllocatorContext.Key]
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext

val inputIds = inputs[0]!!.data as IntNDArray
val segmentIds = inputs[1]?.data as IntNDArray?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import io.kinference.operator.*
import io.kinference.primitives.types.DataType
import io.kinference.protobuf.message.AttributeProto
import io.kinference.protobuf.message.TensorProto
import io.kinference.utils.PredictionContext
import kotlin.coroutines.coroutineContext
import kotlin.math.sqrt

Expand Down Expand Up @@ -107,7 +108,7 @@ class SkipLayerNormalizationVer1(name: String, attributes: Map<String, Attribute


override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[ManualAllocatorContext.Key]
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext

val input = inputs[0]!!.data as FloatNDArray
val output = (manualContext?.getNDArray(DataType.FLOAT, input.strides, fillZeros = false) ?: MutableFloatNDArray(input.strides)) as MutableFloatNDArray
Expand All @@ -119,7 +120,7 @@ class SkipLayerNormalizationVer1(name: String, attributes: Map<String, Attribute
epsilon = epsilon,
dst = output
)
// Do we need to pass context here??

return listOf(output.asTensor(context = manualContext))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.kinference.ndarray.arrays.memory.contexts.ManualAllocatorContext
import io.kinference.ndarray.extensions.allocateNDArray
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto
import io.kinference.utils.PredictionContext
import kotlin.coroutines.coroutineContext

sealed class Add(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
Expand Down Expand Up @@ -55,7 +56,7 @@ class AddVer7(name: String, attributes: Map<String, Attribute<Any>>, inputs: Lis
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[ManualAllocatorContext.Key]
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext

val left = inputs[0]!!.data as NumberNDArrayCore
val right = inputs[1]!!.data as NumberNDArrayCore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.kinference.ndarray.arrays.memory.contexts.ManualAllocatorContext
import io.kinference.ndarray.extensions.allocateNDArray
import io.kinference.ndarray.extensions.gelu.biasGelu
import io.kinference.operator.*
import io.kinference.utils.PredictionContext
import kotlin.coroutines.coroutineContext

sealed class BiasGelu(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
Expand Down Expand Up @@ -43,7 +44,7 @@ class BiasGeluVer1(name: String, attributes: Map<String, Attribute<Any>> = empty
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[ManualAllocatorContext.Key]
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext

val input = inputs[0]!!.data as NumberNDArrayCore
val bias = inputs[1]!!.data as NumberNDArrayCore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.kinference.ndarray.broadcasting.Broadcasting
import io.kinference.ndarray.extensions.allocateNDArray
import io.kinference.operator.*
import io.kinference.protobuf.message.TensorProto
import io.kinference.utils.PredictionContext
import kotlin.coroutines.coroutineContext

sealed class MatMul(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
Expand Down Expand Up @@ -50,7 +51,7 @@ class MatMulVer1(name: String, attributes: Map<String, Attribute<Any>>, inputs:
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[ManualAllocatorContext.Key]
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext

val first = inputs[0]!!.data as NumberNDArrayCore
val second = inputs[1]!!.data as NumberNDArrayCore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import io.kinference.primitives.types.DataType
import io.kinference.protobuf.FLOAT_TENSOR_TYPES
import io.kinference.protobuf.message.AttributeProto
import io.kinference.protobuf.message.TensorProto
import io.kinference.utils.PredictionContext
import kotlin.coroutines.coroutineContext

sealed class Cast(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
Expand Down Expand Up @@ -801,7 +802,7 @@ class CastVer6(name: String, attributes: Map<String, Attribute<Any>>, inputs: Li
private val toType: Int by attribute("to") { it: Number -> it.toInt() }

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[ManualAllocatorContext.Key]
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext

val tensor = inputs.first()!!
val to = TensorProto.DataType.fromValue(toType)!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,20 @@ package io.kinference.ndarray.arrays.memory
import io.kinference.ndarray.arrays.memory.contexts.*
import io.kinference.ndarray.arrays.memory.storage.*
import io.kinference.utils.*
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.*
import java.util.concurrent.ConcurrentLinkedQueue
import kotlin.coroutines.CoroutineContext

interface ArrayStorage {
fun resetState()
}

class PredictionContextDispatcher(private val predictionConfig: PredictionConfig) : Closeable {
private val limiter: MemoryManager = MemoryManager(
memoryLimit = predictionConfig.memoryThreshold,
cacheClearingInterval = predictionConfig.memoryClearingInterval,
onCacheClear = ::clearCache)

private val contextQueue: ConcurrentLinkedQueue<CoroutineContext> = ConcurrentLinkedQueue()
private val contextQueue: ConcurrentLinkedQueue<PredictionContext> = ConcurrentLinkedQueue()
val allocationMode
get() = predictionConfig.allocationMode

fun getPredictionContext(): CoroutineContext {
fun getPredictionContext(): PredictionContext {
val allocatorContext = when (predictionConfig.allocationMode) {
AllocationMode.NoAllocation -> getNoAllocatorContext()
AllocationMode.Manual -> getManualAllocatorContext()
Expand All @@ -31,21 +25,23 @@ class PredictionContextDispatcher(private val predictionConfig: PredictionConfig
return allocatorContext
}

@OptIn(ExperimentalCoroutinesApi::class)
private fun getNoAllocatorContext(): CoroutineContext {
return contextQueue.poll() ?: (NoAllocatorContext() + ParallelismLimiterContext(Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)))
private fun getNoAllocatorContext(): PredictionContext {
return contextQueue.poll() ?: (NoAllocatorContext(getDispatcher()))
}

@OptIn(ExperimentalCoroutinesApi::class)
private fun getAutoAllocatorContext(): CoroutineContext {
private fun getAutoAllocatorContext(): PredictionContext {
limiter.updateLastAccessTime()
return contextQueue.poll() ?: (AutoAllocatorContext(AutoArrayHandlingStorage(limiter)) + ParallelismLimiterContext(Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)))
return contextQueue.poll() ?: (AutoAllocatorContext(getDispatcher(), AutoArrayHandlingStorage(limiter)))
}

@OptIn(ExperimentalCoroutinesApi::class)
private fun getManualAllocatorContext(): CoroutineContext {
private fun getManualAllocatorContext(): PredictionContext {
limiter.updateLastAccessTime()
return contextQueue.poll() ?: (ManualAllocatorContext(ManualArrayHandlingStorage(limiter)) + ParallelismLimiterContext(Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)))
return contextQueue.poll() ?: (ManualAllocatorContext(getDispatcher(), ManualArrayHandlingStorage(limiter)))
}

@OptIn(ExperimentalCoroutinesApi::class)
private fun getDispatcher(): CoroutineDispatcher {
return Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)
}

fun clearCache() {
Expand All @@ -58,7 +54,10 @@ class PredictionContextDispatcher(private val predictionConfig: PredictionConfig
clearCache()
}

fun returnStorage(context: CoroutineContext) {
fun returnStorage(context: PredictionContext) {
if (context is AllocatorContext<*>) {
context.finalizeContext()
}
contextQueue.offer(context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ package io.kinference.ndarray.arrays.memory.contexts
import io.kinference.ndarray.arrays.memory.storage.AutoArrayHandlingStorage
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveArray
import io.kinference.utils.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlin.coroutines.*

internal class AutoAllocatorContext internal constructor(
dispatcher: CoroutineDispatcher,
storage: AutoArrayHandlingStorage,
) : BaseAllocatorContextWithStorage<AutoArrayHandlingStorage>(storage) {

companion object Key : CoroutineContext.Key<AutoAllocatorContext>
override val key: CoroutineContext.Key<*> get() = Key
}
) : AllocatorContext<AutoArrayHandlingStorage>(dispatcher, storage)

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.memory.storage.ManualArrayHandlingStorage
import io.kinference.ndarray.arrays.memory.storage.ManualStorage
import io.kinference.primitives.types.DataType
import kotlin.coroutines.CoroutineContext
import io.kinference.utils.AllocatorContext
import kotlinx.coroutines.CoroutineDispatcher

class ManualAllocatorContext internal constructor(
dispatcher: CoroutineDispatcher,
storage: ManualArrayHandlingStorage,
) : BaseAllocatorContextWithStorage<ManualStorage>(storage) {

companion object Key : CoroutineContext.Key<ManualAllocatorContext>
override val key: CoroutineContext.Key<*> get() = Key
) : AllocatorContext<ManualStorage>(dispatcher, storage) {

fun getNDArray(dataType: DataType, strides: Strides, fillZeros: Boolean = false): MutableNDArrayCore {
return storage.getNDArray(dataType, strides, fillZeros)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.kinference.ndarray.arrays.memory.storage

import io.kinference.ndarray.arrays.memory.*
import io.kinference.utils.ArrayStorage

internal interface TypedAutoHandlingStorage {
fun moveBlocksIntoUnused()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.kinference.ndarray.arrays.memory.storage
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.memory.*
import io.kinference.primitives.types.DataType
import io.kinference.utils.ArrayStorage

internal interface TypedManualHandlingStorage {
fun getNDArray(strides: Strides, fillZeros: Boolean = false, limiter: MemoryManager): MutableNDArrayCore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ internal class PrimitiveAutoHandlingArrayStorage : TypedAutoHandlingStorage {
val unusedQueue = unused.getOrPut(blockSize) { ArrayDeque(blocksNum) }
val usedQueue = used.getOrPut(blockSize) { ArrayDeque(blocksNum) }

val blocks = if (limiter.checkMemoryLimitAndAdd(type.getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
val blocks = if (limiter.checkMemoryLimitAndAdd(getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
Array(blocksNum) {
unusedQueue.removeFirstOrNull()?.apply {
fill(PrimitiveConstants.ZERO)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ internal fun AutoAllocatorContext.getPrimitiveBlock(blocksNum: Int, blockSize: I
}

@GenerateNameFromPrimitives
internal fun DataType.getPrimitiveArraySizeInBytes(arraySize: Int): Long {
internal fun getPrimitiveArraySizeInBytes(arraySize: Int): Long {
return PrimitiveConstants.SIZE_BYTES * arraySize
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal class PrimitiveManualHandlingArrayStorage : TypedManualHandlingStorage
override fun getNDArray(strides: Strides, fillZeros: Boolean, limiter: MemoryManager): MutableNDArrayCore {
val blockSize = blockSizeByStrides(strides)
val blocksNum = strides.linearSize / blockSize
val blocks = if (limiter.checkMemoryLimitAndAdd(type.getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
val blocks = if (limiter.checkMemoryLimitAndAdd(getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
val queue = storage.getOrPut(blockSize) { ArrayDeque(blocksNum) }
Array(blocksNum) {
queue.removeFirstOrNull()?.apply {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.kinference.ndarray.arrays.pointers.accept
import io.kinference.ndarray.blockSizeByStrides
import io.kinference.primitives.annotations.*
import io.kinference.primitives.types.*
import io.kinference.utils.PredictionContext
import io.kinference.utils.inlines.InlineInt
import kotlin.coroutines.coroutineContext
import kotlin.math.min
Expand Down Expand Up @@ -59,7 +60,8 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>) {
require(size % blockSize == 0) { "Size must divide blockSize" }

val blocksNum = if (blockSize == 0) 0 else size / blockSize
val blocks = coroutineContext[AutoAllocatorContext.Key]?.getPrimitiveBlock(blocksNum, blockSize) ?: Array(blocksNum) { PrimitiveArray(blockSize) }
val blocks = (coroutineContext[PredictionContext.Key] as? AutoAllocatorContext)?.getPrimitiveBlock(blocksNum, blockSize)
?: Array(blocksNum) { PrimitiveArray(blockSize) }

return PrimitiveTiledArray(blocks)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.kinference.utils

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext

object ResourcesDispatcher {
Expand All @@ -16,11 +17,30 @@ object ResourcesDispatcher {
}
}

class ParallelismLimiterContext(val dispatcher: CoroutineDispatcher) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<ParallelismLimiterContext>
override val key: CoroutineContext.Key<*> get() = Key
interface PredictionKey<T : PredictionContext> : CoroutineContext.Key<T>

sealed class PredictionContext(
val dispatcher: CoroutineDispatcher
) : AbstractCoroutineContextElement(PredictionContext) {
companion object Key : PredictionKey<PredictionContext>
}

interface ArrayStorage {
fun resetState()
}

abstract class AllocatorContext<T : ArrayStorage>(
dispatcher: CoroutineDispatcher,
val storage: T
) : PredictionContext(dispatcher) {

fun finalizeContext() {
storage.resetState()
}
}

class NoAllocatorContext(dispatcher: CoroutineDispatcher) : PredictionContext(dispatcher)

fun CoroutineScope.launchWithLimitOrDefault(block: suspend CoroutineScope.() -> Unit) {
this.launch(coroutineContext[ParallelismLimiterContext.Key]?.dispatcher ?: Dispatchers.Default, block = block)
this.launch(coroutineContext[PredictionContext]?.dispatcher ?: Dispatchers.Default, block = block)
}

0 comments on commit c942273

Please sign in to comment.