Skip to content

Commit

Permalink
JBAI-4393 [core, ndarray] Refactored memory management and array handling: added new type for limiter which works with manually managed ndarrays, added manual ndarray handling in Attention and TensorExtensions, moved to use standard DataType enum instead of ArrayTypes.
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriyb committed Aug 19, 2024
1 parent 9caf75c commit f1a9296
Show file tree
Hide file tree
Showing 14 changed files with 317 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.kinference.core.data.tensor
import io.kinference.core.*
import io.kinference.data.ONNXTensor
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.memory.ManualAllocatorContext
import io.kinference.ndarray.arrays.tiled.*
import io.kinference.protobuf.FLOAT_TENSOR_TYPES
import io.kinference.protobuf.message.TensorProto
Expand All @@ -12,10 +13,11 @@ import io.kinference.types.ValueTypeInfo

//TODO: support segments
//TODO: support external data
class KITensor(name: String?, override val data: NDArrayCore, val info: ValueTypeInfo.TensorTypeInfo) : ONNXTensor<NDArrayCore, CoreBackend>(name, data) {
class KITensor(name: String?, override val data: NDArrayCore, val info: ValueTypeInfo.TensorTypeInfo, private var context: ManualAllocatorContext? = null) : ONNXTensor<NDArrayCore, CoreBackend>(name, data) {
constructor(data: NDArrayCore, info: ValueInfo) : this(info.name, data, info.typeInfo as ValueTypeInfo.TensorTypeInfo)

/**
 * Releases the array backing this tensor.
 *
 * When a `context` is attached, `data` is first handed to it via `returnNDArray`;
 * `data.close()` is then called regardless of whether a context is present.
 */
override suspend fun close() {
// NOTE(review): `data` is closed immediately after being returned to the allocator context —
// confirm that a returned array may be closed without affecting any later reuse of it.
context?.returnNDArray(data)
data.close()
}

Expand All @@ -41,7 +43,7 @@ class KITensor(name: String?, override val data: NDArrayCore, val info: ValueTyp
override val backend = CoreBackend

/**
 * Returns a new [KITensor] named [name] that wraps the same `data` and `info` as this tensor.
 */
override fun rename(name: String): KITensor {
// The two `return` lines below appear to be the removed and the added version of the same
// statement as rendered by the diff view; the second one additionally forwards `context`.
// NOTE(review): the renamed tensor shares `data` and `context` with this one — confirm that
// closing both tensors does not hand the same array back to the context twice.
return KITensor(name, data, info)
return KITensor(name, data, info, context)
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package io.kinference.core.data.tensor

import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.memory.ManualAllocatorContext
import io.kinference.ndarray.extensions.concat
import io.kinference.ndarray.extensions.splitWithAxis
import io.kinference.primitives.types.DataType
import io.kinference.protobuf.resolveProtoDataType
import io.kinference.types.TensorShape
import io.kinference.types.ValueTypeInfo

/**
 * Wraps this array in a [KITensor] with the given [name].
 *
 * The tensor's type info is built from this array's `shape` and from `type.resolveProtoDataType()`.
 */
fun NDArrayCore.asTensor(name: String? = null) = KITensor(name, this, ValueTypeInfo.TensorTypeInfo(TensorShape(this.shape), type.resolveProtoDataType()))
// Variant of the line above (apparently its post-change form in this diff view): identical construction,
// plus an optional `context` that is passed through as the last KITensor constructor argument.
fun NDArrayCore.asTensor(name: String? = null, context: ManualAllocatorContext? = null) = KITensor(name, this, ValueTypeInfo.TensorTypeInfo(TensorShape(this.shape), type.resolveProtoDataType()), context)

/**
 * Convenience overload for any [NDArray] subtype: casts the receiver to [NDArrayCore] and delegates
 * to `NDArrayCore.asTensor`.
 *
 * The cast is a plain `as`, so a receiver that is not an [NDArrayCore] fails with a ClassCastException.
 */
internal fun <T : NDArray> T.asTensor(name: String? = null) = (this as NDArrayCore).asTensor(name)
// Variant of the line above (apparently its post-change form in this diff view): same cast and delegation,
// with the optional `context` forwarded as well.
internal fun <T : NDArray> T.asTensor(name: String? = null, context: ManualAllocatorContext? = null) = (this as NDArrayCore).asTensor(name, context)

internal fun <T : NDArray> Collection<T>.asONNXTensors(names: List<String>): List<KITensor> {
return this.zip(names).map { (data, name) -> data.asTensor(name) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class KIModel(

// A view of Dispatchers.Default obtained through limitedParallelism(parallelismLimit);
// the opt-in is required because limitedParallelism is marked ExperimentalCoroutinesApi here.
@OptIn(ExperimentalCoroutinesApi::class)
private val dispatcher: CoroutineDispatcher = Dispatchers.Default.limitedParallelism(parallelismLimit)
// Storage from which the allocator contexts for a model run are created.
// The two declarations below appear to be the removed and the added version of the same property
// as rendered by the diff view: the first always uses MemoryLimiters.Default, the second uses
// the `memoryLimiter` value of this model.
private val modelArrayStorage: ModelArrayStorage = ModelArrayStorage(MemoryLimiters.Default)
private val modelArrayStorage: ModelArrayStorage = ModelArrayStorage(memoryLimiter)

/**
 * Creates a [ProfilingContext] named [name], adds it to `profiles`, and returns it.
 */
override fun addProfilingContext(name: String): ProfilingContext = ProfilingContext(name).apply { profiles.add(this) }
/**
 * Runs `analyze` over the collected `profiles`, labelling the result "Model <name>" with this model's `name`.
 */
override fun analyzeProfilingResults(): ProfileAnalysisEntry = profiles.analyze("Model $name")
Expand All @@ -44,20 +44,31 @@ class KIModel(
coreReserved = true
}

if (memoryLimiter == MemoryLimiters.NoAllocator) {
withContext(limiterContext) {
return@withContext graph.execute(input, contexts)
when (memoryLimiter) {
MemoryLimiters.NoAllocator -> {
withContext(limiterContext) {
return@withContext graph.execute(input, contexts)
}
}
} else {
val allocatorContext = modelArrayStorage.createAllocatorContext()
val mixedContext = allocatorContext + limiterContext
MemoryLimiters.DefaultManualAllocator -> {
val allocatorContext = modelArrayStorage.createManualAllocatorContext()
val mixedContext = allocatorContext + limiterContext

withContext(mixedContext) {
val coroutineContext = coroutineContext[AllocatorContext.Key]!!
val execResult = graph.execute(input, contexts)
val copies = execResult.map { it.clone(it.name) }.toList()
coroutineContext.closeAllocated()
return@withContext copies
withContext(mixedContext) {
return@withContext graph.execute(input, contexts)
}
}
else -> {
val allocatorContext = modelArrayStorage.createAutoAllocatorContext()
val mixedContext = allocatorContext + limiterContext

withContext(mixedContext) {
val coroutineContext = coroutineContext[AutoAllocatorContext.Key]!!
val execResult = graph.execute(input, contexts)
val copies = execResult.map { it.clone(it.name) }.toList()
coroutineContext.returnUsedArrays()
return@withContext copies
}
}
}
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,34 @@ import io.kinference.core.optimizer.rules.context.AttentionContextRule
import io.kinference.data.ONNXData
import io.kinference.graph.Contexts
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.memory.ManualAllocatorContext
import io.kinference.ndarray.arrays.pointers.accept
import io.kinference.ndarray.arrays.pointers.map
import io.kinference.ndarray.arrays.tiled.FloatTiledArray
import io.kinference.ndarray.extensions.allocateNDArray
import io.kinference.ndarray.extensions.dotTransposedWithAlpha
import io.kinference.ndarray.extensions.softmax.softmax
import io.kinference.operator.*
import io.kinference.optimizer.GraphOptimizer.Companion.isOpt
import io.kinference.primitives.types.DataType
import io.kinference.protobuf.message.AttributeProto
import io.kinference.protobuf.message.TensorProto
import io.kinference.utils.launchWithLimitOrDefault
import kotlinx.coroutines.coroutineScope
import kotlin.coroutines.coroutineContext
import kotlin.math.min
import kotlin.math.sqrt

sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String, Attribute<Any>>, inputs: List<String>, outputs: List<String>) : Operator<KITensor, KITensor>(name, info, attributes, inputs, outputs) {
companion object {
private suspend fun attentionScore(
scores: NDArrayCore, batchSize: Int, seqLen: Int,
numHeads: Int, hiddenSize: Int, present: NDArrayCore
numHeads: Int, hiddenSize: Int, present: NDArrayCore, context: ManualAllocatorContext? = null
): Pair<NDArrayCore, NDArrayCore> {
val headSize = hiddenSize / numHeads

val output = allocateNDArray(scores.type, Strides(intArrayOf(batchSize, numHeads, seqLen, headSize)))
val outputStrides = Strides(intArrayOf(batchSize, numHeads, seqLen, headSize))
val output = context?.getNDArray(scores.type, outputStrides, fillZeros = true) ?: allocateNDArray(scores.type, outputStrides)

coroutineScope {
for (batchNum in 0 until batchSize) {
Expand All @@ -46,6 +51,8 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,
}
}

context?.returnNDArray(scores)

val outputTransposed = output.transpose(intArrayOf(0, 2, 1, 3)).reshape(intArrayOf(batchSize, seqLen, hiddenSize))
return outputTransposed to present
}
Expand Down Expand Up @@ -108,26 +115,27 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,

/**
 * Runs the attention computation for already-prepared query/key/value arrays.
 *
 * Steps, as visible in this function:
 *  1. `makePresent` builds `present` from `past`, `k` and `v`;
 *  2. `normalizedScores` computes scores from `q`, `mask` and `present`;
 *  3. `attentionScore` receives those scores together with `present` and its result is returned.
 *
 * Several lines below occur twice, once without and once with a trailing `context` argument; they
 * appear to be the removed and the added version of the same line as rendered by the diff view.
 *
 * @param past optional past state; when non-null, `shape[3]` is taken as the past sequence length, otherwise 0 is used
 * @param maskFilterValue forwarded to `normalizedScores`; defaults to -10_000f
 * @param context optional manual allocator context, forwarded to `normalizedScores` and `attentionScore`
 * @return the pair produced by `attentionScore`
 */
internal suspend fun getScores(
unidir: Boolean, q: NDArrayCore, k: NDArrayCore, v: NDArrayCore, mask: IntNDArray?,
past: NDArrayCore?, batchSize: Int, seqLen: Int, numHeads: Int, hiddenSize: Int, maskFilterValue: Float = -10_000f
past: NDArrayCore?, batchSize: Int, seqLen: Int, numHeads: Int, hiddenSize: Int, maskFilterValue: Float = -10_000f, context: ManualAllocatorContext? = null
): Pair<NDArrayCore, NDArrayCore> {
// Integer division: any remainder of hiddenSize / numHeads is discarded.
val headSize = hiddenSize / numHeads

// No past state means a past sequence length of 0.
val pastSeqLen = past?.shape?.get(3) ?: 0
val present = makePresent(past, k, v, batchSize, seqLen, numHeads, hiddenSize)

// First pair: calls without `context`; second pair: the same calls with `context` forwarded.
val scores = normalizedScores(unidir, q, mask, batchSize, seqLen, pastSeqLen, headSize, numHeads, present, maskFilterValue)
return attentionScore(scores, batchSize, seqLen, numHeads, hiddenSize, present)
val scores = normalizedScores(unidir, q, mask, batchSize, seqLen, pastSeqLen, headSize, numHeads, present, maskFilterValue, context)
return attentionScore(scores, batchSize, seqLen, numHeads, hiddenSize, present, context)
}

private suspend fun normalizedScores(
unidir: Boolean, queries: NDArrayCore, maskIndices: IntNDArray?, batchSize: Int,
seqLen: Int, pastSeqLen: Int, headSize: Int, numHeads: Int, present: NDArrayCore, maskFilterValue: Float = -10_000f
seqLen: Int, pastSeqLen: Int, headSize: Int, numHeads: Int, present: NDArrayCore, maskFilterValue: Float = -10_000f, context: ManualAllocatorContext? = null
): NumberNDArrayCore {
val allSeqLen = present.shape[3]

val scores = allocateNDArray(queries.type, Strides(intArrayOf(batchSize, numHeads, seqLen, allSeqLen))) as MutableNumberNDArrayCore
val scoresStrides = Strides(intArrayOf(batchSize, numHeads, seqLen, allSeqLen))
val scores = (context?.getNDArray(queries.type, scoresStrides, fillZeros = true) ?: allocateNDArray(queries.type, scoresStrides)) as MutableNumberNDArrayCore

val maskData = maskIndices?.maskFromIndices(unidir, batchSize, seqLen, pastSeqLen, maskFilterValue)
val maskData = maskIndices?.maskFromIndices(unidir, batchSize, seqLen, pastSeqLen, maskFilterValue, context)

val alpha = 1.0 / sqrt(headSize.toDouble())

Expand All @@ -148,27 +156,38 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,
}
}

if (maskData != null) {
context?.returnNDArray(maskData)
}
context?.returnNDArray(queries)

val softmaxDest = (context?.getNDArray(scores.type, scoresStrides) ?: allocateNDArray(scores.type, scoresStrides)) as MutableNumberNDArrayCore

return softmax(input = scores, axis = -1, dest = softmaxDest)

//softmax for each result (normalize along last axis)
return scores.softmax(axis = -1)
// return scores.softmax(axis = -1)
}

private suspend fun IntNDArray?.maskFromIndices(unidir: Boolean, batchSize: Int, seqLen: Int, pastSeqLen: Int, maskFilterValue: Float = -10_000f): FloatNDArray {
private suspend fun IntNDArray?.maskFromIndices(unidir: Boolean, batchSize: Int, seqLen: Int, pastSeqLen: Int, maskFilterValue: Float = -10_000f, context: ManualAllocatorContext? = null): FloatNDArray {
val fullSeqLen = seqLen + pastSeqLen
val maskDataShape = intArrayOf(batchSize, seqLen, fullSeqLen)
val mask = MutableFloatNDArray(Strides(maskDataShape))
val maskStrides = Strides(maskDataShape)

val mask = context?.getNDArray(DataType.FLOAT, maskStrides) ?: MutableFloatNDArray(maskStrides)
val maskOffset = seqLen * fullSeqLen
repeat(batchSize) { i ->
if (this != null) {
//raw attention (no padding). only raw attention mask is 2-dimensional
if (this.rank == 2) {
val maskPointer = mask.array.pointer(maskOffset * i)
val maskPointer = (mask as MutableFloatNDArray).array.pointer(maskOffset * i)
val maskIndicesPointer = this.array.pointer(i * fullSeqLen)

maskPointer.accept(maskIndicesPointer, fullSeqLen) { _, src -> if (src > 0) 0f else maskFilterValue }
} else {
//for left/right-side padding
val maskIndicesPointer = this.array.pointer(i)
val maskPointer = mask.array.pointer(maskOffset * i + maskIndicesPointer.get())
val maskPointer = (mask as MutableFloatNDArray).array.pointer(maskOffset * i + maskIndicesPointer.get())
maskPointer.map(fullSeqLen - maskIndicesPointer.get()) { maskFilterValue }

if (this.rank == 1 && this.shape[0] == 2 * batchSize) {
Expand All @@ -186,15 +205,15 @@ sealed class Attention(name: String, info: OperatorInfo, attributes: Map<String,
}

if (unidir) {
val maskPointer = mask.array.pointer()
val maskPointer = (mask as MutableFloatNDArray).array.pointer()
for (seqIdx in 0 until seqLen - 1) {
val start = pastSeqLen + seqIdx + 1
maskPointer.linearIndex = seqIdx * fullSeqLen + maskOffset * i + start
maskPointer.map(fullSeqLen - start) { it + maskFilterValue }
}
}
}
return mask
return (mask as MutableFloatNDArray)
}

private val DEFAULT_VERSION = VersionInfo(sinceVersion = 1)
Expand Down Expand Up @@ -235,12 +254,13 @@ class AttentionVer1(name: String, attributes: Map<String, Attribute<Any>>, input

internal suspend fun initQueryKeyValue(
input: NDArrayCore, weights: NDArrayCore, bias: NDArrayCore,
batchSize: Int, seqLen: Int, hiddenSize: Int, numHeads: Int
batchSize: Int, seqLen: Int, hiddenSize: Int, numHeads: Int, context: ManualAllocatorContext? = null
): Array<MutableNDArrayCore> {
input as NumberNDArrayCore
val headSize = hiddenSize / numHeads

val qkv = Array(3) { allocateNDArray(input.type, Strides(intArrayOf(batchSize, numHeads, seqLen, headSize))) }
val qkvStrides = Strides(intArrayOf(batchSize, numHeads, seqLen, headSize))
val qkv = Array(3) { context?.getNDArray(input.type, qkvStrides, fillZeros = true) ?: allocateNDArray(input.type, qkvStrides) }

coroutineScope {
for (qkvIdx in 0 until 3) {
Expand Down Expand Up @@ -269,6 +289,8 @@ class AttentionVer1(name: String, attributes: Map<String, Attribute<Any>>, input
private val maskFilterValue: Float by attribute("mask_filter_value") { it: Number -> it.toFloat() }

override suspend fun <D : ONNXData<*, *>> apply(contexts: Contexts<D>, inputs: List<KITensor?>): List<KITensor?> {
val context = coroutineContext[ManualAllocatorContext.Key]

val input = inputs[0]!!
val weights = inputs[1]!!

Expand All @@ -286,10 +308,10 @@ class AttentionVer1(name: String, attributes: Map<String, Attribute<Any>>, input
input.data,
preparedWeights.data,
preparedBias.data,
batchSize, seqLen, hiddenSize, numHeads,
batchSize, seqLen, hiddenSize, numHeads, context
)

val (scores, present) = getScores(unidir, queries, keys, values, maskIndices, past, batchSize, seqLen, numHeads, hiddenSize, maskFilterValue)
return listOf(scores.asTensor(), present.asTensor())
val (scores, present) = getScores(unidir, queries, keys, values, maskIndices, past, batchSize, seqLen, numHeads, hiddenSize, maskFilterValue, context)
return listOf(scores.asTensor(context = context), present.asTensor(context = context))
}
}

This file was deleted.

This file was deleted.

Loading

0 comments on commit f1a9296

Please sign in to comment.