Skip to content

Commit

Permalink
JBAI-4393 [core] Rework context keys
Browse files Browse the repository at this point in the history
  • Loading branch information
cupertank committed Sep 2, 2024
1 parent c942273 commit 3caa0bc
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class AttentionVer1(name: String, attributes: Map<String, Attribute<Any>>, input
private val maskFilterValue: Float by attribute("mask_filter_value") { it: Number -> it.toFloat() }

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val context = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
val context = coroutineContext[ManualAllocatorContext]

val input = inputs[0]!!
val weights = inputs[1]!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class EmbedLayerNormalizationVer1(
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
val manualContext = coroutineContext[ManualAllocatorContext]

val inputIds = inputs[0]!!.data as IntNDArray
val segmentIds = inputs[1]?.data as IntNDArray?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class SkipLayerNormalizationVer1(name: String, attributes: Map<String, Attribute


override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
val manualContext = coroutineContext[ManualAllocatorContext]

val input = inputs[0]!!.data as FloatNDArray
val output = (manualContext?.getNDArray(DataType.FLOAT, input.strides, fillZeros = false) ?: MutableFloatNDArray(input.strides)) as MutableFloatNDArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class AddVer7(name: String, attributes: Map<String, Attribute<Any>>, inputs: Lis
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
val manualContext = coroutineContext[ManualAllocatorContext]

val left = inputs[0]!!.data as NumberNDArrayCore
val right = inputs[1]!!.data as NumberNDArrayCore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class BiasGeluVer1(name: String, attributes: Map<String, Attribute<Any>> = empty
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
val manualContext = coroutineContext[ManualAllocatorContext]

val input = inputs[0]!!.data as NumberNDArrayCore
val bias = inputs[1]!!.data as NumberNDArrayCore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class MatMulVer1(name: String, attributes: Map<String, Attribute<Any>>, inputs:
}

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
val manualContext = coroutineContext[ManualAllocatorContext]

val first = inputs[0]!!.data as NumberNDArrayCore
val second = inputs[1]!!.data as NumberNDArrayCore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ class CastVer6(name: String, attributes: Map<String, Attribute<Any>>, inputs: Li
private val toType: Int by attribute("to") { it: Number -> it.toInt() }

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val manualContext = coroutineContext[PredictionContext.Key] as? ManualAllocatorContext
val manualContext = coroutineContext[ManualAllocatorContext]

val tensor = inputs.first()!!
val to = TensorProto.DataType.fromValue(toType)!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import io.kinference.utils.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlin.coroutines.*

@OptIn(ExperimentalStdlibApi::class)
internal class AutoAllocatorContext internal constructor(
dispatcher: CoroutineDispatcher,
storage: AutoArrayHandlingStorage,
) : AllocatorContext<AutoArrayHandlingStorage>(dispatcher, storage)
) : AllocatorContext<AutoArrayHandlingStorage>(dispatcher, storage) {
companion object Key : AbstractCoroutineContextKey<AllocatorContext<*>, AutoAllocatorContext>(
AllocatorContext.Key, { it as? AutoAllocatorContext }
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@ import io.kinference.ndarray.arrays.memory.storage.ManualStorage
import io.kinference.primitives.types.DataType
import io.kinference.utils.AllocatorContext
import kotlinx.coroutines.CoroutineDispatcher
import kotlin.coroutines.AbstractCoroutineContextKey

@OptIn(ExperimentalStdlibApi::class)
class ManualAllocatorContext internal constructor(
dispatcher: CoroutineDispatcher,
storage: ManualArrayHandlingStorage,
) : AllocatorContext<ManualStorage>(dispatcher, storage) {
companion object Key : AbstractCoroutineContextKey<AllocatorContext<*>, ManualAllocatorContext>(
AllocatorContext.Key, { it as? ManualAllocatorContext }
)


fun getNDArray(dataType: DataType, strides: Strides, fillZeros: Boolean = false): MutableNDArrayCore {
return storage.getNDArray(dataType, strides, fillZeros)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ internal class PrimitiveTiledArray(val blocks: Array<PrimitiveArray>) {
require(size % blockSize == 0) { "Size must divide blockSize" }

val blocksNum = if (blockSize == 0) 0 else size / blockSize
val blocks = (coroutineContext[PredictionContext.Key] as? AutoAllocatorContext)?.getPrimitiveBlock(blocksNum, blockSize)
val blocks = coroutineContext[AutoAllocatorContext]?.getPrimitiveBlock(blocksNum, blockSize)
?: Array(blocksNum) { PrimitiveArray(blockSize) }

return PrimitiveTiledArray(blocks)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
@file:OptIn(ExperimentalStdlibApi::class)
package io.kinference.utils

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.*

object ResourcesDispatcher {
private val tokenChannel = Channel<Unit>(capacity = PlatformUtils.cores)
Expand All @@ -17,12 +17,17 @@ object ResourcesDispatcher {
}
}

interface PredictionKey<T : PredictionContext> : CoroutineContext.Key<T>

sealed class PredictionContext(
val dispatcher: CoroutineDispatcher
) : AbstractCoroutineContextElement(PredictionContext) {
companion object Key : PredictionKey<PredictionContext>
companion object Key : CoroutineContext.Key<PredictionContext>

override val key
get() = Key

override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = getPolymorphicElement(key)

override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext = minusPolymorphicKey(key)
}

interface ArrayStorage {
Expand All @@ -33,13 +38,22 @@ abstract class AllocatorContext<T : ArrayStorage>(
dispatcher: CoroutineDispatcher,
val storage: T
) : PredictionContext(dispatcher) {
companion object Key : AbstractCoroutineContextKey<PredictionContext, AllocatorContext<*>>(
PredictionContext.Key,
{ it as? AllocatorContext<*> }
)

fun finalizeContext() {
storage.resetState()
}
}

class NoAllocatorContext(dispatcher: CoroutineDispatcher) : PredictionContext(dispatcher)
class NoAllocatorContext(dispatcher: CoroutineDispatcher) : PredictionContext(dispatcher) {
companion object Key : AbstractCoroutineContextKey<PredictionContext, NoAllocatorContext>(
PredictionContext.Key,
{ it as? NoAllocatorContext }
)
}

fun CoroutineScope.launchWithLimitOrDefault(block: suspend CoroutineScope.() -> Unit) {
this.launch(coroutineContext[PredictionContext]?.dispatcher ?: Dispatchers.Default, block = block)
Expand Down

0 comments on commit 3caa0bc

Please sign in to comment.