diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/model/KIModel.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/model/KIModel.kt
index 6611fc1c..3f78377d 100644
--- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/model/KIModel.kt
+++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/model/KIModel.kt
@@ -5,7 +5,6 @@ import io.kinference.core.graph.KIGraph
 import io.kinference.graph.Contexts
 import io.kinference.model.Model
 import io.kinference.ndarray.arrays.memory.*
-import io.kinference.ndarray.arrays.memory.contexts.finalizeAllocatorContext
 import io.kinference.operator.OperatorSetRegistry
 import io.kinference.profiler.*
 import io.kinference.protobuf.message.ModelProto
@@ -47,7 +46,6 @@ class KIModel(
                     return@withContext graph.execute(input, contexts).map { it.clone(it.name) }.toList()
                 }
 
-                predictionContext.finalizeAllocatorContext()
                 predictionContextDispatcher.returnStorage(predictionContext)
                 output
             } finally {
diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt
index 1add2b1b..73732877 100644
--- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt
+++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/attention/Attention.kt
@@ -19,6 +19,7 @@ import io.kinference.optimizer.GraphOptimizer.Companion.isOpt
 import io.kinference.primitives.types.DataType
 import io.kinference.protobuf.message.AttributeProto
 import io.kinference.protobuf.message.TensorProto
+import io.kinference.utils.PredictionContext
 import io.kinference.utils.launchWithLimitOrDefault
 import kotlinx.coroutines.coroutineScope
 import kotlin.coroutines.coroutineContext
@@ -287,7 +288,7 @@ class AttentionVer1(name: String, attributes: Map<String, Attribute<Any>>, input
     private val maskFilterValue: Float by attribute("mask_filter_value") { it: Number -> it.toFloat() }
 
     override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
-        val context = coroutineContext[ManualAllocatorContext.Key]
+        val context = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
 
         val input = inputs[0]!!
         val weights = inputs[1]!!
diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/EmbedLayerNormalization.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/EmbedLayerNormalization.kt
index 33a01c6d..09866772 100644
--- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/EmbedLayerNormalization.kt
+++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/EmbedLayerNormalization.kt
@@ -11,6 +11,7 @@ import io.kinference.operator.*
 import io.kinference.primitives.types.DataType
 import io.kinference.protobuf.message.AttributeProto.AttributeType
 import io.kinference.protobuf.message.TensorProto
+import io.kinference.utils.PredictionContext
 import kotlin.coroutines.coroutineContext
 import kotlin.math.sqrt
 
@@ -175,7 +176,7 @@ class EmbedLayerNormalizationVer1(
     }
 
     override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
-        val manualContext = coroutineContext[ManualAllocatorContext.Key]
+        val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
 
         val inputIds = inputs[0]!!.data as IntNDArray
         val segmentIds = inputs[1]?.data as IntNDArray?
diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/SkipLayerNormalization.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/SkipLayerNormalization.kt
index 08b8e7f1..aa246044 100644
--- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/SkipLayerNormalization.kt
+++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/layer/normalization/SkipLayerNormalization.kt
@@ -13,6 +13,7 @@ import io.kinference.operator.*
 import io.kinference.primitives.types.DataType
 import io.kinference.protobuf.message.AttributeProto
 import io.kinference.protobuf.message.TensorProto
+import io.kinference.utils.PredictionContext
 import kotlin.coroutines.coroutineContext
 import kotlin.math.sqrt
 
@@ -107,7 +108,7 @@ class SkipLayerNormalizationVer1(name: String, attributes: Map<String, Attribut
     }
 
     override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
-        val manualContext = coroutineContext[ManualAllocatorContext.Key]
+        val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
 
         val input = inputs[0]!!.data as FloatNDArray
         val output = (manualContext?.getNDArray(DataType.FLOAT, input.strides, fillZeros = false) ?: MutableFloatNDArray(input.strides)) as MutableFloatNDArray
@@ -119,7 +120,7 @@ class SkipLayerNormalizationVer1(name: String, attributes: Map<String, Attribut
diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/Add.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/Add.kt
@@ -55,7 +56,7 @@ class AddVer7(name: String, attributes: Map<String, Attribute<Any>>, inputs: Lis
     }
 
     override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
-        val manualContext = coroutineContext[ManualAllocatorContext.Key]
+        val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
 
         val left = inputs[0]!!.data as NumberNDArrayCore
         val right = inputs[1]!!.data as NumberNDArrayCore
diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/BiasGelu.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/BiasGelu.kt
index c6b21a77..65b5089e 100644
--- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/BiasGelu.kt
+++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/BiasGelu.kt
@@ -11,6 +11,7 @@ import io.kinference.ndarray.arrays.memory.contexts.ManualAllocatorContext
 import io.kinference.ndarray.extensions.allocateNDArray
 import io.kinference.ndarray.extensions.gelu.biasGelu
 import io.kinference.operator.*
+import io.kinference.utils.PredictionContext
 import kotlin.coroutines.coroutineContext
 
 sealed class BiasGelu(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
@@ -43,7 +44,7 @@ class BiasGeluVer1(name: String, attributes: Map<String, Attribute<Any>> = empty
     }
 
     override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
-        val manualContext = coroutineContext[ManualAllocatorContext.Key]
+        val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
 
         val input = inputs[0]!!.data as NumberNDArrayCore
         val bias = inputs[1]!!.data as NumberNDArrayCore
diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/MatMul.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/MatMul.kt
index 1d560845..aabce734 100644
--- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/MatMul.kt
+++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/MatMul.kt
@@ -11,6 +11,7 @@ import io.kinference.ndarray.broadcasting.Broadcasting
 import io.kinference.ndarray.extensions.allocateNDArray
 import io.kinference.operator.*
 import io.kinference.protobuf.message.TensorProto
+import io.kinference.utils.PredictionContext
 import kotlin.coroutines.coroutineContext
 
 sealed class MatMul(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
@@ -50,7 +51,7 @@ class MatMulVer1(name: String, attributes: Map<String, Attribute<Any>>, inputs: 
     }
 
     override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
-        val manualContext = coroutineContext[ManualAllocatorContext.Key]
+        val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
 
         val first = inputs[0]!!.data as NumberNDArrayCore
         val second = inputs[1]!!.data as NumberNDArrayCore
diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Cast.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Cast.kt
index d0bc9a56..acc9dfb9 100644
--- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Cast.kt
+++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Cast.kt
@@ -14,6 +14,7 @@ import io.kinference.primitives.types.DataType
 import io.kinference.protobuf.FLOAT_TENSOR_TYPES
 import io.kinference.protobuf.message.AttributeProto
 import io.kinference.protobuf.message.TensorProto
+import io.kinference.utils.PredictionContext
 import kotlin.coroutines.coroutineContext
 
 sealed class Cast(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
@@ -801,7 +802,7 @@ class CastVer6(name: String, attributes: Map<String, Attribute<Any>>, inputs: Li
     private val toType: Int by attribute("to") { it: Number -> it.toInt() }
 
     override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
-        val manualContext = coroutineContext[ManualAllocatorContext.Key]
+        val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
 
         val tensor = inputs.first()!!
         val to = TensorProto.DataType.fromValue(toType)!!
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/PredictionContextDispatcher.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/PredictionContextDispatcher.kt
index 10a2c4bc..801e5c66 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/PredictionContextDispatcher.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/PredictionContextDispatcher.kt
@@ -3,14 +3,8 @@ package io.kinference.ndarray.arrays.memory
 import io.kinference.ndarray.arrays.memory.contexts.*
 import io.kinference.ndarray.arrays.memory.storage.*
 import io.kinference.utils.*
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.*
 import java.util.concurrent.ConcurrentLinkedQueue
-import kotlin.coroutines.CoroutineContext
-
-interface ArrayStorage {
-    fun resetState()
-}
 
 class PredictionContextDispatcher(private val predictionConfig: PredictionConfig) : Closeable {
     private val limiter: MemoryManager = MemoryManager(
@@ -18,11 +12,11 @@ class PredictionContextDispatcher(private val predictionConfig: PredictionConfig
         cacheClearingInterval = predictionConfig.memoryClearingInterval,
         onCacheClear = ::clearCache)
 
-    private val contextQueue: ConcurrentLinkedQueue<CoroutineContext> = ConcurrentLinkedQueue()
+    private val contextQueue: ConcurrentLinkedQueue<PredictionContext> = ConcurrentLinkedQueue()
 
     val allocationMode get() = predictionConfig.allocationMode
 
-    fun getPredictionContext(): CoroutineContext {
+    fun getPredictionContext(): PredictionContext {
         val allocatorContext = when (predictionConfig.allocationMode) {
             AllocationMode.NoAllocation -> getNoAllocatorContext()
             AllocationMode.Manual -> getManualAllocatorContext()
@@ -31,21 +25,23 @@ class PredictionContextDispatcher(private val predictionConfig: PredictionConfig
         return allocatorContext
     }
 
-    @OptIn(ExperimentalCoroutinesApi::class)
-    private fun getNoAllocatorContext(): CoroutineContext {
-        return contextQueue.poll() ?: (NoAllocatorContext() + ParallelismLimiterContext(Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)))
+    private fun getNoAllocatorContext(): PredictionContext {
+        return contextQueue.poll() ?: (NoAllocatorContext(getDispatcher()))
     }
 
-    @OptIn(ExperimentalCoroutinesApi::class)
-    private fun getAutoAllocatorContext(): CoroutineContext {
+    private fun getAutoAllocatorContext(): PredictionContext {
         limiter.updateLastAccessTime()
-        return contextQueue.poll() ?: (AutoAllocatorContext(AutoArrayHandlingStorage(limiter)) + ParallelismLimiterContext(Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)))
+        return contextQueue.poll() ?: (AutoAllocatorContext(getDispatcher(), AutoArrayHandlingStorage(limiter)))
     }
 
-    @OptIn(ExperimentalCoroutinesApi::class)
-    private fun getManualAllocatorContext(): CoroutineContext {
+    private fun getManualAllocatorContext(): PredictionContext {
         limiter.updateLastAccessTime()
-        return contextQueue.poll() ?: (ManualAllocatorContext(ManualArrayHandlingStorage(limiter)) + ParallelismLimiterContext(Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)))
+        return contextQueue.poll() ?: (ManualAllocatorContext(getDispatcher(), ManualArrayHandlingStorage(limiter)))
+    }
+
+    @OptIn(ExperimentalCoroutinesApi::class)
+    private fun getDispatcher(): CoroutineDispatcher {
+        return Dispatchers.Default.limitedParallelism(predictionConfig.parallelismLimit)
     }
 
     fun clearCache() {
@@ -58,7 +54,10 @@ class PredictionContextDispatcher(private val predictionConfig: PredictionConfig
         clearCache()
     }
 
-    fun returnStorage(context: CoroutineContext) {
+    fun returnStorage(context: PredictionContext) {
+        if (context is AllocatorContext<*>) {
+            context.finalizeContext()
+        }
         contextQueue.offer(context)
     }
 }
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/AutoAllocatorContext.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/AutoAllocatorContext.kt
index a4d36b55..e69367f5 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/AutoAllocatorContext.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/AutoAllocatorContext.kt
@@ -3,12 +3,11 @@ package io.kinference.ndarray.arrays.memory.contexts
 import io.kinference.ndarray.arrays.memory.storage.AutoArrayHandlingStorage
 import io.kinference.primitives.types.DataType
 import io.kinference.primitives.types.PrimitiveArray
+import io.kinference.utils.*
+import kotlinx.coroutines.CoroutineDispatcher
 import kotlin.coroutines.*
 
 internal class AutoAllocatorContext internal constructor(
+    dispatcher: CoroutineDispatcher,
     storage: AutoArrayHandlingStorage,
-) : BaseAllocatorContextWithStorage<AutoArrayHandlingStorage>(storage) {
-
-    companion object Key : CoroutineContext.Key<AutoAllocatorContext>
-    override val key: CoroutineContext.Key<*> get() = Key
-}
+) : AllocatorContext<AutoArrayHandlingStorage>(dispatcher, storage)
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/BaseAllocatorContextWithStorage.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/BaseAllocatorContextWithStorage.kt
deleted file mode 100644
index e617c78d..00000000
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/BaseAllocatorContextWithStorage.kt
+++ /dev/null
@@ -1,24 +0,0 @@
-package io.kinference.ndarray.arrays.memory.contexts
-
-import io.kinference.ndarray.arrays.memory.ArrayStorage
-import kotlin.coroutines.CoroutineContext
-
-interface BaseAllocatorContext: CoroutineContext.Element
-
-abstract class BaseAllocatorContextWithStorage<T : ArrayStorage>(internal val storage: T) : BaseAllocatorContext {
-    fun finalizeContext() {
-        storage.resetState()
-    }
-}
-
-fun CoroutineContext.finalizeAllocatorContext() {
-    this.fold(Unit) { _, context ->
-        if (context is BaseAllocatorContextWithStorage<*>)
-            context.finalizeContext()
-    }
-}
-
-class NoAllocatorContext : BaseAllocatorContext {
-    companion object Key : CoroutineContext.Key<NoAllocatorContext>
-    override val key: CoroutineContext.Key<*> get() = Key
-}
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/ManualAllocatorContext.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/ManualAllocatorContext.kt
index a713f31f..9a6663c7 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/ManualAllocatorContext.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/contexts/ManualAllocatorContext.kt
@@ -4,14 +4,13 @@ import io.kinference.ndarray.arrays.*
 import io.kinference.ndarray.arrays.memory.storage.ManualArrayHandlingStorage
 import io.kinference.ndarray.arrays.memory.storage.ManualStorage
 import io.kinference.primitives.types.DataType
-import kotlin.coroutines.CoroutineContext
+import io.kinference.utils.AllocatorContext
+import kotlinx.coroutines.CoroutineDispatcher
 
 class ManualAllocatorContext internal constructor(
+    dispatcher: CoroutineDispatcher,
     storage: ManualArrayHandlingStorage,
-) : BaseAllocatorContextWithStorage<ManualArrayHandlingStorage>(storage) {
-
-    companion object Key : CoroutineContext.Key<ManualAllocatorContext>
-    override val key: CoroutineContext.Key<*> get() = Key
+) : AllocatorContext<ManualArrayHandlingStorage>(dispatcher, storage) {
 
     fun getNDArray(dataType: DataType, strides: Strides, fillZeros: Boolean = false): MutableNDArrayCore {
         return storage.getNDArray(dataType, strides, fillZeros)
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/AutoArrayHandlingStorage.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/AutoArrayHandlingStorage.kt
index 62364570..b0ffdbbb 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/AutoArrayHandlingStorage.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/AutoArrayHandlingStorage.kt
@@ -1,6 +1,7 @@
 package io.kinference.ndarray.arrays.memory.storage
 
 import io.kinference.ndarray.arrays.memory.*
+import io.kinference.utils.ArrayStorage
 
 internal interface TypedAutoHandlingStorage {
     fun moveBlocksIntoUnused()
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/ManualArrayHandlingStorage.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/ManualArrayHandlingStorage.kt
index 559334f8..227d2513 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/ManualArrayHandlingStorage.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/ManualArrayHandlingStorage.kt
@@ -3,6 +3,7 @@ package io.kinference.ndarray.arrays.memory.storage
 import io.kinference.ndarray.arrays.*
 import io.kinference.ndarray.arrays.memory.*
 import io.kinference.primitives.types.DataType
+import io.kinference.utils.ArrayStorage
 
 internal interface TypedManualHandlingStorage {
     fun getNDArray(strides: Strides, fillZeros: Boolean = false, limiter: MemoryManager): MutableNDArrayCore
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveAutoHandlingArrayStorage.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveAutoHandlingArrayStorage.kt
index c0b7d986..4cd5bb66 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveAutoHandlingArrayStorage.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveAutoHandlingArrayStorage.kt
@@ -24,7 +24,7 @@ internal class PrimitiveAutoHandlingArrayStorage : TypedAutoHandlingStorage {
         val unusedQueue = unused.getOrPut(blockSize) { ArrayDeque(blocksNum) }
         val usedQueue = used.getOrPut(blockSize) { ArrayDeque(blocksNum) }
 
-        val blocks = if (limiter.checkMemoryLimitAndAdd(type.getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
+        val blocks = if (limiter.checkMemoryLimitAndAdd(getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
             Array(blocksNum) {
                 unusedQueue.removeFirstOrNull()?.apply {
                     fill(PrimitiveConstants.ZERO)
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveGetBlockFunctionsExtension.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveGetBlockFunctionsExtension.kt
index 5da084dc..9280823d 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveGetBlockFunctionsExtension.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveGetBlockFunctionsExtension.kt
@@ -19,6 +19,6 @@ internal fun AutoAllocatorContext.getPrimitiveBlock(blocksNum: Int, blockSize: I
 }
 
 @GenerateNameFromPrimitives
-internal fun DataType.getPrimitiveArraySizeInBytes(arraySize: Int): Long {
+internal fun getPrimitiveArraySizeInBytes(arraySize: Int): Long {
     return PrimitiveConstants.SIZE_BYTES * arraySize
 }
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveManualHandlingArrayStorage.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveManualHandlingArrayStorage.kt
index 29060279..1c264be0 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveManualHandlingArrayStorage.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/memory/storage/PrimitiveManualHandlingArrayStorage.kt
@@ -25,7 +25,7 @@ internal class PrimitiveManualHandlingArrayStorage : TypedManualHandlingStorage
     override fun getNDArray(strides: Strides, fillZeros: Boolean, limiter: MemoryManager): MutableNDArrayCore {
         val blockSize = blockSizeByStrides(strides)
         val blocksNum = strides.linearSize / blockSize
-        val blocks = if (limiter.checkMemoryLimitAndAdd(type.getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
+        val blocks = if (limiter.checkMemoryLimitAndAdd(getPrimitiveArraySizeInBytes(arraySize = blockSize * blocksNum))) {
             val queue = storage.getOrPut(blockSize) { ArrayDeque(blocksNum) }
             Array(blocksNum) {
                 queue.removeFirstOrNull()?.apply {
diff --git a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt
index 600211e3..339f2fb8 100644
--- a/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt
+++ b/ndarray/ndarray-core/src/jvmMain/kotlin/io/kinference/ndarray/arrays/tiled/PrimitiveTiledArray.kt
@@ -11,6 +11,7 @@ import io.kinference.ndarray.arrays.pointers.accept
 import io.kinference.ndarray.blockSizeByStrides
 import io.kinference.primitives.annotations.*
 import io.kinference.primitives.types.*
+import io.kinference.utils.PredictionContext
 import io.kinference.utils.inlines.InlineInt
 import kotlin.coroutines.coroutineContext
 import kotlin.math.min
@@ -59,7 +60,8 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>) {
             require(size % blockSize == 0) { "Size must divide blockSize" }
 
             val blocksNum = if (blockSize == 0) 0 else size / blockSize
-            val blocks = coroutineContext[AutoAllocatorContext.Key]?.getPrimitiveBlock(blocksNum, blockSize) ?: Array(blocksNum) { PrimitiveArray(blockSize) }
+            val blocks = (coroutineContext[PredictionContext.Key] as? AutoAllocatorContext)?.getPrimitiveBlock(blocksNum, blockSize)
+                ?: Array(blocksNum) { PrimitiveArray(blockSize) }
 
             return PrimitiveTiledArray(blocks)
         }
diff --git a/utils/utils-common/src/commonMain/kotlin/io/kinference/utils/ResourcesDispatcher.kt b/utils/utils-common/src/commonMain/kotlin/io/kinference/utils/ResourcesDispatcher.kt
index 66b5cea9..b17df1f7 100644
--- a/utils/utils-common/src/commonMain/kotlin/io/kinference/utils/ResourcesDispatcher.kt
+++ b/utils/utils-common/src/commonMain/kotlin/io/kinference/utils/ResourcesDispatcher.kt
@@ -2,6 +2,7 @@ package io.kinference.utils
 
 import kotlinx.coroutines.*
 import kotlinx.coroutines.channels.Channel
+import kotlin.coroutines.AbstractCoroutineContextElement
 import kotlin.coroutines.CoroutineContext
 
 object ResourcesDispatcher {
@@ -16,11 +17,30 @@ object ResourcesDispatcher {
         }
     }
 
-class ParallelismLimiterContext(val dispatcher: CoroutineDispatcher) : CoroutineContext.Element {
-    companion object Key : CoroutineContext.Key<ParallelismLimiterContext>
-    override val key: CoroutineContext.Key<*> get() = Key
+interface PredictionKey<T : PredictionContext> : CoroutineContext.Key<T>
+
+sealed class PredictionContext(
+    val dispatcher: CoroutineDispatcher
+) : AbstractCoroutineContextElement(PredictionContext) {
+    companion object Key : PredictionKey<PredictionContext>
+}
+
+interface ArrayStorage {
+    fun resetState()
+}
+
+abstract class AllocatorContext<T : ArrayStorage>(
+    dispatcher: CoroutineDispatcher,
+    val storage: T
+) : PredictionContext(dispatcher) {
+
+    fun finalizeContext() {
+        storage.resetState()
+    }
 }
 
+class NoAllocatorContext(dispatcher: CoroutineDispatcher) : PredictionContext(dispatcher)
+
 fun CoroutineScope.launchWithLimitOrDefault(block: suspend CoroutineScope.() -> Unit) {
-    this.launch(coroutineContext[ParallelismLimiterContext.Key]?.dispatcher ?: Dispatchers.Default, block = block)
+    this.launch(coroutineContext[PredictionContext]?.dispatcher ?: Dispatchers.Default, block = block)
 }
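Usage note (not part of the patch): a minimal sketch of how the reworked API fits together, using only names visible in this diff (PredictionContextDispatcher, PredictionContext, ManualAllocatorContext, getNDArray, returnStorage). The helpers runPrediction and allocateOutput and the exact import paths are illustrative assumptions, not code from the repository.

    import io.kinference.ndarray.arrays.*                                   // Strides, MutableNDArrayCore, MutableFloatNDArray (assumed wildcard)
    import io.kinference.ndarray.arrays.memory.PredictionContextDispatcher
    import io.kinference.ndarray.arrays.memory.contexts.ManualAllocatorContext
    import io.kinference.primitives.types.DataType
    import io.kinference.utils.PredictionContext
    import kotlin.coroutines.coroutineContext
    import kotlinx.coroutines.withContext

    // Assumed wrapper: lease a PredictionContext for one prediction, run the work inside it, then recycle it.
    suspend fun <T> runPrediction(dispatcher: PredictionContextDispatcher, block: suspend () -> T): T {
        val context = dispatcher.getPredictionContext()   // NoAllocator / Auto / Manual, per PredictionConfig.allocationMode
        return try {
            withContext(context) { block() }              // PredictionContext is a single CoroutineContext element
        } finally {
            dispatcher.returnStorage(context)             // resets AllocatorContext storage and re-queues the element
        }
    }

    // Assumed operator-side helper: manual allocation is now reached through the shared PredictionContext.Key.
    suspend fun allocateOutput(strides: Strides): MutableNDArrayCore {
        val manual = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
        return manual?.getNDArray(DataType.FLOAT, strides, fillZeros = false)
            ?: MutableFloatNDArray(strides)
    }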