From c8ec49f1e2c9603498ca679727a499dc0b296e26 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Thu, 25 Jul 2024 10:45:44 -0700 Subject: [PATCH] Apply https://github.com/opensearch-project/k-NN/pull/1804 (#1880) Signed-off-by: Heemin Kim --- .../knn/index/codec/transfer/VectorTransferByte.java | 10 ++++++---- .../knn/index/codec/transfer/VectorTransferFloat.java | 10 ++++++---- .../index/codec/transfer/VectorTransferByteTests.java | 2 +- .../index/codec/transfer/VectorTransferFloatTests.java | 2 +- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java index 5e9831708..e81ac35fc 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferByte.java @@ -34,11 +34,13 @@ public void init(final long totalLiveDocs) { public void transfer(final BytesRef bytesRef) { dimension = bytesRef.length * 8; if (vectorsPerTransfer == Integer.MIN_VALUE) { - vectorsPerTransfer = (bytesRef.length * totalLiveDocs) / vectorsStreamingMemoryLimit; - // This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer - // Doing this will reduce 1 extra trip to JNI layer. + // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with length of 5, then per + // transfer we have to send 100/5 => 20 vectors. + vectorsPerTransfer = vectorsStreamingMemoryLimit / bytesRef.length; + // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that + // we are sending minimum number of vectors. if (vectorsPerTransfer == 0) { - vectorsPerTransfer = totalLiveDocs; + vectorsPerTransfer = 1; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java index af6d9490e..a9c792398 100644 --- a/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java +++ b/src/main/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloat.java @@ -38,11 +38,13 @@ public void transfer(final BytesRef bytesRef) { dimension = vector.length; if (vectorsPerTransfer == Integer.MIN_VALUE) { - vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; - // This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer - // Doing this will reduce 1 extra trip to JNI layer. + // if vectorsStreamingMemoryLimit is 100 bytes and we have 50 vectors with 5 dimension, then per + // transfer we have to send 100/(5 * 4) => 5 vectors. + vectorsPerTransfer = vectorsStreamingMemoryLimit / ((long) dimension * Float.BYTES); + // If vectorsPerTransfer comes out to be 0, then we set number of vectors per transfer to 1, to ensure that + // we are sending minimum number of vectors. if (vectorsPerTransfer == 0) { - vectorsPerTransfer = totalLiveDocs; + vectorsPerTransfer = 1; } } diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java index 7e837fbf2..2f091a035 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferByteTests.java @@ -21,7 +21,7 @@ public class VectorTransferByteTests extends TestCase { public void testTransfer_whenCalled_thenAdded() { final BytesRef bytesRef1 = getByteArrayOfVectors(20); final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferByte vectorTransfer = new VectorTransferByte(1000); + VectorTransferByte vectorTransfer = new VectorTransferByte(40); try { vectorTransfer.init(2); diff --git a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java index 1f36f320d..620fd7c65 100644 --- a/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/transfer/VectorTransferFloatTests.java @@ -24,7 +24,7 @@ public class VectorTransferFloatTests extends TestCase { public void testTransfer_whenCalled_thenAdded() { final BytesRef bytesRef1 = getByteArrayOfVectors(20); final BytesRef bytesRef2 = getByteArrayOfVectors(20); - VectorTransferFloat vectorTransfer = new VectorTransferFloat(1000); + VectorTransferFloat vectorTransfer = new VectorTransferFloat(160); try { vectorTransfer.init(2);