Skip to content

Commit

Permalink
Merge pull request #201 from JetBrains-Research/broadcast-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriyb authored Sep 24, 2024
2 parents 0d7094f + 4bdb061 commit 5944ca8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import io.kinference.primitives.annotations.GenerateNameFromPrimitives
import io.kinference.primitives.annotations.GeneratePrimitives
import io.kinference.primitives.types.DataType
import io.kinference.primitives.types.PrimitiveType
import io.kinference.utils.inlines.InlineInt

@GenerateNameFromPrimitives
internal fun broadcastTwoTensorsPrimitive(
Expand Down Expand Up @@ -45,16 +44,20 @@ internal fun broadcastTwoTensorsPrimitive(
val rightBlocks = right.array.blocks
val destBlocks = dest.array.blocks

val leftIsScalarFun = { leftOffset: InlineInt, rightOffset: InlineInt, destOffset: InlineInt, axisToBroadcastIdx: InlineInt ->
val shapeIdx = axisToBroadcastIdx.value * 2
val leftIsScalarFun = ScalarBroadcastFun { leftOffset, rightOffset, destOffset, axisToBroadcastIdx ->
val shapeIdx = axisToBroadcastIdx * 2
val batchSize = destBroadcastingShape[shapeIdx]

for (batchIdx in 0 until batchSize) {
val leftScalar = leftBlocks[leftOffset.value][0]
val leftBatchOffset = leftOffset + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset + destOffsets[shapeIdx] * batchIdx

val leftScalar = leftBlocks[leftBatchOffset][0]

for (blockIdx in 0 until destBlocksInRow) {
val destBlock = destBlocks[destOffset.value + blockIdx]
val rightBlock = rightBlocks[rightOffset.value + blockIdx]
val destBlock = destBlocks[destBatchOffset + blockIdx]
val rightBlock = rightBlocks[rightBatchOffset + blockIdx]

for (idx in destBlock.indices) {
destBlock[idx] = op(leftScalar, rightBlock[idx])
Expand All @@ -63,16 +66,20 @@ internal fun broadcastTwoTensorsPrimitive(
}
}

val rightIsScalarFun = { leftOffset: InlineInt, rightOffset: InlineInt, destOffset: InlineInt, axisToBroadcastIdx: InlineInt ->
val shapeIdx = axisToBroadcastIdx.value * 2
val rightIsScalarFun = ScalarBroadcastFun { leftOffset, rightOffset, destOffset, axisToBroadcastIdx ->
val shapeIdx = axisToBroadcastIdx * 2
val batchSize = destBroadcastingShape[shapeIdx]

for (batchIdx in 0 until batchSize) {
val rightScalar = rightBlocks[rightOffset.value][0]
val leftBatchOffset = leftOffset + leftOffsets[shapeIdx] * batchIdx
val rightBatchOffset = rightOffset + rightOffsets[shapeIdx] * batchIdx
val destBatchOffset = destOffset + destOffsets[shapeIdx] * batchIdx

val rightScalar = rightBlocks[rightBatchOffset][0]

for (blockIdx in 0 until destBlocksInRow) {
val destBlock = destBlocks[destOffset.value + blockIdx]
val leftBlock = leftBlocks[leftOffset.value + blockIdx]
val destBlock = destBlocks[destBatchOffset + blockIdx]
val leftBlock = leftBlocks[leftBatchOffset + blockIdx]

for (idx in destBlock.indices) {
destBlock[idx] = op(leftBlock[idx], rightScalar)
Expand All @@ -81,27 +88,27 @@ internal fun broadcastTwoTensorsPrimitive(
}
}

val defaultFun = { leftOffset: InlineInt, rightOffset: InlineInt, destOffset: InlineInt, axisToBroadcastIdx: InlineInt ->
val defaultFun = ScalarBroadcastFun { leftOffset, rightOffset, destOffset, _ ->
for (blockIdx in 0 until destBlocksInRow) {
val leftBlock = leftBlocks[leftOffset.value + blockIdx]
val rightBlock = rightBlocks[rightOffset.value + blockIdx]
val destBlock = destBlocks[destOffset.value + blockIdx]
val leftBlock = leftBlocks[leftOffset + blockIdx]
val rightBlock = rightBlocks[rightOffset + blockIdx]
val destBlock = destBlocks[destOffset + blockIdx]

for (idx in destBlock.indices) {
destBlock[idx] = op(leftBlock[idx], rightBlock[idx])
}
}
}

val broadcastingFun = when {
val broadcastingFun: ScalarBroadcastFun = when {
leftIsScalar -> leftIsScalarFun
rightIsScalar -> rightIsScalarFun
else -> defaultFun
}

fun broadcast(leftOffset: Int, rightOffset: Int, destOffset: Int, axisToBroadcastIdx: Int) {
if (axisToBroadcastIdx == totalAxesToBroadcast) {
broadcastingFun(InlineInt(leftOffset), InlineInt(rightOffset), InlineInt(destOffset), InlineInt(axisToBroadcastIdx))
broadcastingFun(leftOffset, rightOffset, destOffset, axisToBroadcastIdx)
} else {
val shapeIdx = axisToBroadcastIdx * 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package io.kinference.ndarray.extensions.broadcasting
import io.kinference.ndarray.arrays.NDArrayCore
import io.kinference.ndarray.extensions.utils.calculateBlock

internal fun interface ScalarBroadcastFun {
operator fun invoke(leftOffset: Int, rightOffset: Int, destOffset: Int, axisToBroadcastIdx: Int)
}

internal data class BroadcastingInfo(
val broadcastingShapes: Array<IntArray>,
val broadcastingDestShape: IntArray,
Expand Down Expand Up @@ -89,8 +93,6 @@ internal data class BroadcastingInfo(
}
}



internal fun makeOffsets(shape: IntArray, blocksInRow: Int): IntArray {
val offsets = IntArray(shape.size)
offsets[offsets.lastIndex - 1] = blocksInRow
Expand Down

0 comments on commit 5944ca8

Please sign in to comment.