Skip to content

Commit

Permalink
JBAI-4832 [ndarray] Refactored memory management and simplify contain…
Browse files Browse the repository at this point in the history
…er logic: added MemoryLimiter to ArrayStorage constructor for better memory control. Removed isNewlyCreated flag from ArrayContainer and streamlined getArrayContainers logic in AllocatorContext.
  • Loading branch information
dmitriyb committed Jul 30, 2024
1 parent 3039faf commit 502cbba
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,12 @@ data class AllocatorContext internal constructor(
override val key: CoroutineContext.Key<*> get() = Key

internal fun getArrayContainers(type: ArrayTypes, size: Int, count: Int): Array<ArrayContainer> {
if (limiter !is NoAllocatorMemoryLimiter) {
val arrayContainers = arrayOfNulls<ArrayContainer>(count)
for (i in 0 until count) {
val container = unusedContainers.getArrayContainer(type, size)
if (!container.isNewlyCreated)
limiter.deductMemory(container.sizeBytes.toLong())
arrayContainers[i] = container
usedContainers.add(container)
}
return arrayContainers as Array<ArrayContainer>
return if (limiter !is NoAllocatorMemoryLimiter) {
val result = Array(count) { unusedContainers.getArrayContainer(type, size) }
usedContainers.addAll(result)
result
} else {
return Array(count) { ArrayContainer(type, size) }
Array(count) { ArrayContainer(type, size) }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ sealed class ArrayContainer(
var isOutput: Boolean = false
private set

var isNewlyCreated: Boolean = true
private set

val markAsOutput = {
isOutput = true
}
Expand All @@ -39,7 +36,6 @@ sealed class ArrayContainer(
}

fun resetArray(arrayContainer: ArrayContainer) {
arrayContainer.isNewlyCreated = false
when (arrayContainer) {
is ByteArrayContainer -> arrayContainer.array.fill(0) // 8-bit signed
is UByteArrayContainer -> arrayContainer.array.fill(0u) // 8-bit unsigned
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.kinference.ndarray.arrays.memory

import io.kinference.ndarray.arrays.ArrayTypes

internal class ArrayStorage(typeLength: Int, sizeLength: Int) {
internal class ArrayStorage(typeLength: Int, sizeLength: Int, private val limiter: MemoryLimiter) {
/**
* Structure is as follows:
* 1. Array by predefined types (all types are known compiled time)
Expand All @@ -29,6 +29,7 @@ internal class ArrayStorage(typeLength: Int, sizeLength: Int) {
val array = storage[tIndex][sIndex].removeFirstOrNull()
array?.let {
ArrayContainer.resetArray(it)
limiter.deductMemory(it.sizeBytes.toLong())
return it
}
sIndex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ModelArrayStorage(private val limiter: MemoryLimiter = MemoryLimiters.NoAl
}

private fun getStorage(): ArrayStorage {
return unusedArrays.poll() ?: ArrayStorage(typeSize, INIT_SIZE_VALUE)
return unusedArrays.poll() ?: ArrayStorage(typeSize, INIT_SIZE_VALUE, limiter)
}

private fun returnStorage(storage: ArrayStorage) {
Expand Down

0 comments on commit 502cbba

Please sign in to comment.