From e3e057bdc74ecf599b85cd69759c0633a8608f76 Mon Sep 17 00:00:00 2001 From: Jay Deng Date: Tue, 18 Feb 2025 14:57:43 -0800 Subject: [PATCH] Introduce RemoteIndexBuildStrategy, refactor NativeIndexBuildStrategy to accept vector value supplier Signed-off-by: Jay Deng --- .../common/featureflags/KNNFeatureFlags.java | 20 +- .../org/opensearch/knn/index/KNNSettings.java | 31 ++- .../codec/BasePerFieldKnnVectorsFormat.java | 31 ++- .../KNN80Codec/KNN80DocValuesConsumer.java | 14 +- .../KNN9120PerFieldKnnVectorsFormat.java | 12 +- .../NativeEngines990KnnVectorsFormat.java | 24 ++- .../NativeEngines990KnnVectorsWriter.java | 79 ++++---- .../knn/index/codec/KNNCodecService.java | 11 +- .../knn/index/codec/KNNCodecVersion.java | 25 +-- .../DefaultIndexBuildStrategy.java | 4 +- .../MemOptimizedNativeIndexBuildStrategy.java | 4 +- .../codec/nativeindex/NativeIndexWriter.java | 59 ++++-- .../nativeindex/model/BuildIndexParams.java | 5 +- .../remote/RemoteIndexBuildStrategy.java | 169 ++++++++++++++++ .../knn/index/engine/KNNEngine.java | 5 + .../knn/index/engine/KNNLibrary.java | 8 + .../knn/index/engine/faiss/Faiss.java | 5 + .../vectorvalues/KNNVectorValuesFactory.java | 91 ++++++++- .../org/opensearch/knn/plugin/KNNPlugin.java | 70 +++---- ...eEngines990KnnVectorsWriterFlushTests.java | 184 ++++++++---------- ...eEngines990KnnVectorsWriterMergeTests.java | 50 +++-- .../knn/index/codec/KNNCodecServiceTests.java | 8 +- .../DefaultIndexBuildStrategyTests.java | 6 +- ...ptimizedNativeIndexBuildStrategyTests.java | 4 +- .../org/opensearch/knn/KNNRestTestCase.java | 81 ++++---- 25 files changed, 703 insertions(+), 297 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java diff --git a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java index bab5b97bb1..328e47d8cf 100644 --- a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java +++ b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java @@ -26,6 +26,7 @@ public class KNNFeatureFlags { // Feature flags private static final String KNN_FORCE_EVICT_CACHE_ENABLED = "knn.feature.cache.force_evict.enabled"; + private static final String KNN_REMOTE_VECTOR_BUILD = "knn.feature.remote_index_build.enabled"; @VisibleForTesting public static final Setting KNN_FORCE_EVICT_CACHE_ENABLED_SETTING = Setting.boolSetting( @@ -35,8 +36,18 @@ public class KNNFeatureFlags { Dynamic ); + /** + * Feature flag to control remote index build at the cluster level + */ + public static final Setting KNN_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting( + KNN_REMOTE_VECTOR_BUILD, + false, + NodeScope, + Dynamic + ); + public static List> getFeatureFlags() { - return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING); + return ImmutableList.of(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING, KNN_REMOTE_VECTOR_BUILD_SETTING); } /** @@ -46,4 +57,11 @@ public static List> getFeatureFlags() { public static boolean isForceEvictCacheEnabled() { return Booleans.parseBoolean(KNNSettings.state().getSettingValue(KNN_FORCE_EVICT_CACHE_ENABLED).toString(), false); } + + /** + * @return true if remote vector index build feature flag is enabled + */ + public static boolean isKNNRemoteVectorBuildEnabled() { + return Booleans.parseBooleanStrict(KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_BUILD).toString(), false); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index ebb55ea2ba..c33f3ea63c 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -41,9 +41,9 @@ import static java.util.stream.Collectors.toUnmodifiableMap; import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.Final; import static org.opensearch.common.settings.Setting.Property.IndexScope; import static org.opensearch.common.settings.Setting.Property.NodeScope; -import static org.opensearch.common.settings.Setting.Property.Final; import static org.opensearch.common.settings.Setting.Property.UnmodifiableOnRestore; import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio; import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue; @@ -94,6 +94,8 @@ public class KNNSettings { public static final String KNN_FAISS_AVX512_SPR_DISABLED = "knn.faiss.avx512_spr.disabled"; public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled"; public static final String KNN_DERIVED_SOURCE_ENABLED = "index.knn.derived_source.enabled"; + public static final String KNN_INDEX_REMOTE_VECTOR_BUILD = "index.knn.remote_index_build.enabled"; + public static final String KNN_REMOTE_VECTOR_REPO = "knn.remote_index_build.vector_repo"; /** * Default setting values @@ -371,6 +373,21 @@ public class KNNSettings { NodeScope ); + /** + * Index level setting to control whether remote index build is enabled or not. + */ + public static final Setting KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING = Setting.boolSetting( + KNN_INDEX_REMOTE_VECTOR_BUILD, + false, + Dynamic, + IndexScope + ); + + /** + * Cluster level setting which indicates the repository that the remote index build should write to. + */ + public static final Setting KNN_REMOTE_VECTOR_REPO_SETTING = Setting.simpleString(KNN_REMOTE_VECTOR_REPO, Dynamic, NodeScope); + /** * Dynamic settings */ @@ -525,6 +542,14 @@ private Setting getSetting(String key) { return KNN_DERIVED_SOURCE_ENABLED_SETTING; } + if (KNN_INDEX_REMOTE_VECTOR_BUILD.equals(key)) { + return KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; + } + + if (KNN_REMOTE_VECTOR_REPO.equals(key)) { + return KNN_REMOTE_VECTOR_REPO_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -550,7 +575,9 @@ public List> getSettings() { QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, - KNN_DERIVED_SOURCE_ENABLED_SETTING + KNN_DERIVED_SOURCE_ENABLED_SETTING, + KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING, + KNN_REMOTE_VECTOR_REPO_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index f3a1258387..f193ce4c64 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -21,6 +21,7 @@ import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; +import org.opensearch.repositories.RepositoriesService; import java.util.Map; import java.util.Optional; @@ -44,6 +45,7 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor private final Supplier defaultFormatSupplier; private final Function vectorsFormatSupplier; private Function scalarQuantizedVectorsFormatSupplier; + private final Supplier repositoriesServiceSupplier; private static final String MAX_CONNECTIONS = "max_connections"; private static final String BEAM_WIDTH = "beam_width"; @@ -54,11 +56,26 @@ public BasePerFieldKnnVectorsFormat( Supplier defaultFormatSupplier, Function vectorsFormatSupplier ) { - this.mapperService = mapperService; - this.defaultMaxConnections = defaultMaxConnections; - this.defaultBeamWidth = defaultBeamWidth; - this.defaultFormatSupplier = defaultFormatSupplier; - this.vectorsFormatSupplier = vectorsFormatSupplier; + this(mapperService, defaultMaxConnections, defaultBeamWidth, defaultFormatSupplier, vectorsFormatSupplier, null); + } + + public BasePerFieldKnnVectorsFormat( + Optional mapperService, + int defaultMaxConnections, + int defaultBeamWidth, + Supplier defaultFormatSupplier, + Function vectorsFormatSupplier, + Function scalarQuantizedVectorsFormatSupplier + ) { + this( + mapperService, + defaultMaxConnections, + defaultBeamWidth, + defaultFormatSupplier, + vectorsFormatSupplier, + scalarQuantizedVectorsFormatSupplier, + null + ); } @Override @@ -141,7 +158,9 @@ private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { int approximateThreshold = getApproximateThresholdValue(); return new NativeEngines990KnnVectorsFormat( new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()), - approximateThreshold + approximateThreshold, + repositoriesServiceSupplier, + mapperService.get().getIndexSettings() ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 443b12b9c4..6b8e52eb74 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -6,11 +6,6 @@ package org.opensearch.knn.index.codec.KNN80Codec; import lombok.extern.log4j.Log4j2; -import org.opensearch.common.StopWatch; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; -import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; @@ -19,8 +14,13 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.common.StopWatch; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.plugin.stats.KNNGraphValue; import java.io.IOException; @@ -72,9 +72,9 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, // For BDV it is fine to use knnVectorValues.totalLiveDocs() as we already run the full loop to calculate total // live docs if (isMerge) { - NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs()); + NativeIndexWriter.getWriter(field, state).mergeIndex(() -> knnVectorValues, (int) knnVectorValues.totalLiveDocs()); } else { - NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs()); + NativeIndexWriter.getWriter(field, state).flushIndex(() -> knnVectorValues, (int) knnVectorValues.totalLiveDocs()); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java index afebae2e6f..f2563ad1e9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120PerFieldKnnVectorsFormat.java @@ -13,10 +13,12 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.repositories.RepositoriesService; import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.function.Supplier; /** * Class provides per field format implementation for Lucene Knn vector type @@ -25,6 +27,13 @@ public class KNN9120PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsForma private static final Tuple DEFAULT_MERGE_THREAD_COUNT_AND_EXECUTOR_SERVICE = Tuple.tuple(1, null); public KNN9120PerFieldKnnVectorsFormat(final Optional mapperService) { + this(mapperService, null); + } + + public KNN9120PerFieldKnnVectorsFormat( + final Optional mapperService, + Supplier repositoriesServiceSupplier + ) { super( mapperService, Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, @@ -67,7 +76,8 @@ public KNN9120PerFieldKnnVectorsFormat(final Optional mapperServi // Executor service mergeThreadCountAndExecutorService.v2() ); - } + }, + repositoriesServiceSupplier ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index 17304c1462..695f336090 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -19,10 +19,13 @@ import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.index.IndexSettings; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.repositories.RepositoriesService; import java.io.IOException; +import java.util.function.Supplier; /** * This is a Vector format that will be used for Native engines like Faiss and Nmslib for reading and writing vector @@ -33,6 +36,8 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { private static FlatVectorsFormat flatVectorsFormat; private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; private static int approximateThreshold; + private final Supplier repositoriesServiceSupplier; + private final IndexSettings indexSettings; public NativeEngines990KnnVectorsFormat() { this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer())); @@ -47,9 +52,20 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma } public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold) { + this(flatVectorsFormat, approximateThreshold, null, null); + } + + public NativeEngines990KnnVectorsFormat( + final FlatVectorsFormat flatVectorsFormat, + int approximateThreshold, + Supplier repositoriesServiceSupplier, + IndexSettings indexSettings + ) { super(FORMAT_NAME); NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat; NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold; + this.repositoriesServiceSupplier = repositoriesServiceSupplier; + this.indexSettings = indexSettings; } /** @@ -59,7 +75,13 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma */ @Override public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException { - return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold); + return new NativeEngines990KnnVectorsWriter( + state, + flatVectorsFormat.fieldsWriter(state), + approximateThreshold, + repositoriesServiceSupplier, + indexSettings + ); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index fb93bfc073..be8c5ce3b2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -15,9 +15,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -25,6 +23,7 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; +import org.opensearch.index.IndexSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.quantizationservice.QuantizationService; @@ -32,6 +31,7 @@ import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.repositories.RepositoriesService; import java.io.IOException; import java.util.ArrayList; @@ -39,7 +39,8 @@ import java.util.function.Supplier; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; -import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues; +import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge; +import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValuesSupplier; /** * A KNNVectorsWriter class for writing the vector data strcutures and flat vectors for Native Engines. @@ -54,15 +55,29 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final List> fields = new ArrayList<>(); private boolean finished; private final Integer approximateThreshold; + private final Supplier repositoriesServiceSupplier; + private final IndexSettings indexSettings; public NativeEngines990KnnVectorsWriter( SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter, Integer approximateThreshold + ) { + this(segmentWriteState, flatVectorsWriter, approximateThreshold, null, null); + } + + public NativeEngines990KnnVectorsWriter( + SegmentWriteState segmentWriteState, + FlatVectorsWriter flatVectorsWriter, + Integer approximateThreshold, + Supplier repositoriesServiceSupplier, + IndexSettings indexSettings ) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; this.approximateThreshold = approximateThreshold; + this.repositoriesServiceSupplier = repositoriesServiceSupplier; + this.indexSettings = indexSettings; } /** @@ -98,7 +113,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { log.debug("[Flush] No live docs for field {}", fieldInfo.getName()); continue; } - final Supplier> knnVectorValuesSupplier = () -> getVectorValues( + final Supplier> knnVectorValuesSupplier = getVectorValuesSupplier( vectorDataType, field.getFlatFieldVectorsWriter().getDocsWithFieldSet(), field.getVectors() @@ -114,11 +129,16 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { ); continue; } - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + final NativeIndexWriter writer = NativeIndexWriter.getWriter( + fieldInfo, + segmentWriteState, + quantizationState, + repositoriesServiceSupplier, + indexSettings + ); StopWatch stopWatch = new StopWatch().start(); - writer.flushIndex(knnVectorValues, totalLiveDocs); + writer.flushIndex(knnVectorValuesSupplier, totalLiveDocs); long time_in_millis = stopWatch.stop().totalTime().millis(); KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName()); @@ -131,7 +151,7 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState flatVectorsWriter.mergeOneField(fieldInfo, mergeState); final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - final Supplier> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge( + final Supplier> knnVectorValuesSupplier = getKNNVectorValuesSupplierForMerge( vectorDataType, fieldInfo, mergeState @@ -153,12 +173,17 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState ); return; } - final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); - final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); + final NativeIndexWriter writer = NativeIndexWriter.getWriter( + fieldInfo, + segmentWriteState, + quantizationState, + repositoriesServiceSupplier, + indexSettings + ); StopWatch stopWatch = new StopWatch().start(); - writer.mergeIndex(knnVectorValues, totalLiveDocs); + writer.mergeIndex(knnVectorValuesSupplier, totalLiveDocs); long time_in_millis = stopWatch.stop().totalTime().millis(); KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis); @@ -211,38 +236,6 @@ public long ramBytesUsed() { .sum(); } - /** - * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. - * - * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. - * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. - * @param mergeState The {@link MergeState} representing the state of the merge operation. - * @param The type of vectors being processed. - * @return The {@link KNNVectorValues} associated with the field during the merge. - * @throws IOException If an I/O error occurs during the retrieval. - */ - private KNNVectorValues getKNNVectorValuesForMerge( - final VectorDataType vectorDataType, - final FieldInfo fieldInfo, - final MergeState mergeState - ) { - try { - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedFloats); - case BYTE: - ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - return getVectorValues(vectorDataType, mergedBytes); - default: - throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); - } - } catch (final IOException e) { - log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e); - throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e); - } - } - private QuantizationState train( final FieldInfo fieldInfo, final Supplier> knnVectorValuesSupplier, diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java index 9e210fcd93..01bcf35195 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecService.java @@ -5,10 +5,13 @@ package org.opensearch.knn.index.codec; -import org.opensearch.index.codec.CodecServiceConfig; import org.apache.lucene.codecs.Codec; import org.opensearch.index.codec.CodecService; +import org.opensearch.index.codec.CodecServiceConfig; import org.opensearch.index.mapper.MapperService; +import org.opensearch.repositories.RepositoriesService; + +import java.util.function.Supplier; /** * KNNCodecService to inject the right KNNCodec version @@ -16,10 +19,12 @@ public class KNNCodecService extends CodecService { private final MapperService mapperService; + private final Supplier repositoriesServiceSupplier; - public KNNCodecService(CodecServiceConfig codecServiceConfig) { + public KNNCodecService(CodecServiceConfig codecServiceConfig, Supplier repositoriesServiceSupplier) { super(codecServiceConfig.getMapperService(), codecServiceConfig.getIndexSettings(), codecServiceConfig.getLogger()); mapperService = codecServiceConfig.getMapperService(); + this.repositoriesServiceSupplier = repositoriesServiceSupplier; } /** @@ -30,6 +35,6 @@ public KNNCodecService(CodecServiceConfig codecServiceConfig) { */ @Override public Codec codec(String name) { - return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService); + return KNNCodecVersion.current().getKnnCodecSupplier().apply(super.codec(name), mapperService, repositoriesServiceSupplier); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 0f03170c25..b2b7a7b461 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -8,14 +8,15 @@ import lombok.AllArgsConstructor; import lombok.Getter; import org.apache.lucene.backward_codecs.lucene91.Lucene91Codec; +import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; import org.apache.lucene.backward_codecs.lucene92.Lucene92Codec; import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec; -import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec; import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec; -import org.apache.lucene.backward_codecs.lucene912.Lucene912Codec; +import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.common.TriFunction; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNN10010Codec.KNN10010Codec; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; @@ -31,9 +32,9 @@ import org.opensearch.knn.index.codec.KNN950Codec.KNN950PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec; import org.opensearch.knn.index.codec.KNN990Codec.KNN990PerFieldKnnVectorsFormat; +import org.opensearch.repositories.RepositoriesService; import java.util.Optional; -import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -53,7 +54,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> new KNN910Codec(userCodec), + (userCodec, mapperService, remoteIndexBuilder) -> new KNN910Codec(userCodec), KNN910Codec::new ), @@ -65,7 +66,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN920Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN920Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN920PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -80,7 +81,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN940Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN940Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN940PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -95,7 +96,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN950Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN950Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN950PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -110,7 +111,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN990Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN990Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), @@ -125,7 +126,7 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN9120Codec.builder() + (userCodec, mapperService, remoteIndexBuilder) -> KNN9120Codec.builder() .delegate(userCodec) .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .mapperService(mapperService) @@ -140,9 +141,9 @@ public enum KNNCodecVersion { new KNN80DocValuesFormat(delegate.docValuesFormat()), new KNN80CompoundFormat(delegate.compoundFormat()) ), - (userCodec, mapperService) -> KNN10010Codec.builder() + (userCodec, mapperService, repositoriesServiceSupplier) -> KNN10010Codec.builder() .delegate(userCodec) - .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .knnVectorsFormat(new KNN9120PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService), repositoriesServiceSupplier)) .mapperService(mapperService) .build(), KNN10010Codec::new @@ -154,7 +155,7 @@ public enum KNNCodecVersion { private final Codec defaultCodecDelegate; private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; private final Function knnFormatFacadeSupplier; - private final BiFunction knnCodecSupplier; + private final TriFunction, Codec> knnCodecSupplier; private final Supplier defaultKnnCodecSupplier; public static final KNNCodecVersion current() { diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 15d38a0791..9c31419c71 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -23,8 +23,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; /** * Transfers all vectors to off heap and then builds an index @@ -50,7 +50,7 @@ public static DefaultIndexBuildStrategy getInstance() { * @throws IOException If an I/O error occurs during the process of building and writing the index. */ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException { - final KNNVectorValues knnVectorValues = indexInfo.getVectorValues(); + final KNNVectorValues knnVectorValues = indexInfo.getKnnVectorValuesSupplier().get(); // Needed to make sure we don't get 0 dimensions while initializing index initializeVectorValues(knnVectorValues); IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index 2864be6d2c..689c090eb2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -22,8 +22,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.KNNVectorUtil.intListToArray; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; /** * Iteratively builds the index. Iterative builds are memory optimized as it does not require all vectors @@ -51,7 +51,7 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() { * @throws IOException If an I/O error occurs during the process of building and writing the index. */ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException { - final KNNVectorValues knnVectorValues = indexInfo.getVectorValues(); + final KNNVectorValues knnVectorValues = indexInfo.getKnnVectorValuesSupplier().get(); // Needed to make sure we don't get 0 dimensions while initializing index initializeVectorValues(knnVectorValues); KNNEngine engine = indexInfo.getKnnEngine(); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index de535c39e8..48288a3074 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -17,12 +17,14 @@ import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexSettings; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.codec.nativeindex.remote.RemoteIndexBuildStrategy; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.quantizationservice.QuantizationService; @@ -33,18 +35,20 @@ import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.plugin.stats.KNNGraphValue; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.repositories.RepositoriesService; import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.function.Supplier; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues; import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; /** @@ -68,7 +72,7 @@ public class NativeIndexWriter { * @return correct NativeIndexWriter to make index specified in fieldInfo */ public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWriteState state) { - return createWriter(fieldInfo, state, null); + return createWriter(fieldInfo, state, null, null, null); } /** @@ -88,29 +92,31 @@ public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWrit public static NativeIndexWriter getWriter( final FieldInfo fieldInfo, final SegmentWriteState state, - final QuantizationState quantizationState + final QuantizationState quantizationState, + final Supplier repositoriesServiceSupplier, + final IndexSettings indexSettings ) { - return createWriter(fieldInfo, state, quantizationState); + return createWriter(fieldInfo, state, quantizationState, repositoriesServiceSupplier, indexSettings); } /** * flushes the index * - * @param knnVectorValues + * @param knnVectorValuesSupplier * @throws IOException */ - public void flushIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { - initializeVectorValues(knnVectorValues); - buildAndWriteIndex(knnVectorValues, totalLiveDocs); + public void flushIndex(final Supplier> knnVectorValuesSupplier, int totalLiveDocs) throws IOException { + buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs); recordRefreshStats(); } /** * Merges kNN index - * @param knnVectorValues + * @param knnVectorValuesSupplier * @throws IOException */ - public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + public void mergeIndex(final Supplier> knnVectorValuesSupplier, int totalLiveDocs) throws IOException { + KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); initializeVectorValues(knnVectorValues); if (knnVectorValues.docId() == NO_MORE_DOCS) { // This is in place so we do not add metrics @@ -120,11 +126,11 @@ public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDo long bytesPerVector = knnVectorValues.bytesPerVector(); startMergeStats(totalLiveDocs, bytesPerVector); - buildAndWriteIndex(knnVectorValues, totalLiveDocs); + buildAndWriteIndex(knnVectorValuesSupplier, totalLiveDocs); endMergeStats(totalLiveDocs, bytesPerVector); } - private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + private void buildAndWriteIndex(final Supplier> knnVectorValuesSupplier, int totalLiveDocs) throws IOException { if (totalLiveDocs == 0) { log.debug("No live docs for field {}", fieldInfo.name); return; @@ -143,7 +149,7 @@ private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int to fieldInfo, indexOutputWithBuffer, knnEngine, - knnVectorValues, + knnVectorValuesSupplier, totalLiveDocs ); indexBuilder.buildAndWriteIndex(nativeIndexParams); @@ -158,7 +164,7 @@ private BuildIndexParams indexParams( FieldInfo fieldInfo, IndexOutputWithBuffer indexOutputWithBuffer, KNNEngine knnEngine, - KNNVectorValues vectorValues, + Supplier> knnVectorValuesSupplier, int totalLiveDocs ) throws IOException { final Map parameters; @@ -182,8 +188,9 @@ private BuildIndexParams indexParams( .knnEngine(knnEngine) .indexOutputWithBuffer(indexOutputWithBuffer) .quantizationState(quantizationState) - .vectorValues(vectorValues) + .knnVectorValuesSupplier(knnVectorValuesSupplier) .totalLiveDocs(totalLiveDocs) + .segmentWriteState(state) .build(); } @@ -304,12 +311,16 @@ private void recordRefreshStats() { * @param fieldInfo The FieldInfo object containing metadata about the field for which the writer is needed. * @param state The SegmentWriteState representing the current segment's writing context. * @param quantizationState The QuantizationState that contains quantization state required for quantization, can be null. + * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService}, intended to be used by {@link RemoteIndexBuildStrategy} + * @param indexSettings * @return A NativeIndexWriter instance appropriate for the specified field, configured with or without quantization. */ private static NativeIndexWriter createWriter( final FieldInfo fieldInfo, final SegmentWriteState state, - @Nullable final QuantizationState quantizationState + @Nullable final QuantizationState quantizationState, + final Supplier repositoriesServiceSupplier, + final IndexSettings indexSettings ) { final KNNEngine knnEngine = extractKNNEngine(fieldInfo); boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); @@ -317,6 +328,20 @@ private static NativeIndexWriter createWriter( NativeIndexBuildStrategy strategy = iterative ? MemOptimizedNativeIndexBuildStrategy.getInstance() : DefaultIndexBuildStrategy.getInstance(); - return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState); + + // TODO: Add threshold checks for if/when we should use remote index build strategy + if (knnEngine.supportsRemoteIndexBuild() + && repositoriesServiceSupplier != null + && RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings)) { + return new NativeIndexWriter( + state, + fieldInfo, + new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy), + quantizationState + ); + } else { + return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState); + } + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java index 36e874c43f..cf5d5be70b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.ToString; import lombok.Value; +import org.apache.lucene.index.SegmentWriteState; import org.opensearch.common.Nullable; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; @@ -16,6 +17,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.util.Map; +import java.util.function.Supplier; @Value @Builder @@ -31,6 +33,7 @@ public class BuildIndexParams { */ @Nullable QuantizationState quantizationState; - KNNVectorValues vectorValues; + Supplier> knnVectorValuesSupplier; int totalLiveDocs; + SegmentWriteState segmentWriteState; } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java new file mode 100644 index 0000000000..8555e2ad68 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/remote/RemoteIndexBuildStrategy.java @@ -0,0 +1,169 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex.remote; + +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.NotImplementedException; +import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.common.StopWatch; +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.index.IndexSettings; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.repositories.Repository; +import org.opensearch.repositories.RepositoryMissingException; +import org.opensearch.repositories.blobstore.BlobStoreRepository; + +import java.io.IOException; +import java.util.function.Supplier; + +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING; +import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING; + +/** + * This class orchestrates building vector indices. It handles uploading data to a repository, submitting a remote + * build request, awaiting upon the build request to complete, and finally downloading the data from a repository. + */ +@Log4j2 +@ExperimentalApi +public class RemoteIndexBuildStrategy implements NativeIndexBuildStrategy { + + private final Supplier repositoriesServiceSupplier; + private final NativeIndexBuildStrategy fallbackStrategy; + private static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec"; + private static final String DOC_ID_FILE_EXTENSION = ".knndid"; + + /** + * Public constructor + * + * @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used for interacting with repository + */ + public RemoteIndexBuildStrategy(Supplier repositoriesServiceSupplier, NativeIndexBuildStrategy fallbackStrategy) { + this.repositoriesServiceSupplier = repositoriesServiceSupplier; + this.fallbackStrategy = fallbackStrategy; + } + + /** + * @return whether to use the remote build feature + */ + public static boolean shouldBuildIndexRemotely(IndexSettings indexSettings) { + String vectorRepo = KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); + return KNNFeatureFlags.isKNNRemoteVectorBuildEnabled() + && indexSettings.getValue(KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING) + && vectorRepo != null + && !vectorRepo.isEmpty(); + } + + /** + * Entry point for flush/merge operations. This method orchestrates the following: + * 1. Writes required data to repository + * 2. Triggers index build + * 3. Awaits on vector build to complete + * 4. Downloads index file and writes to indexOutput + * + * @param indexInfo + * @throws IOException + */ + @Override + public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException { + // TODO: Metrics Collection + StopWatch stopWatch; + long time_in_millis; + try { + stopWatch = new StopWatch().start(); + writeToRepository( + indexInfo.getFieldName(), + indexInfo.getKnnVectorValuesSupplier(), + indexInfo.getTotalLiveDocs(), + indexInfo.getSegmentWriteState() + ); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Repository write took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); + + stopWatch = new StopWatch().start(); + submitVectorBuild(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Submit vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); + + stopWatch = new StopWatch().start(); + awaitVectorBuild(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); + + stopWatch = new StopWatch().start(); + readFromRepository(); + time_in_millis = stopWatch.stop().totalTime().millis(); + log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName()); + } catch (Exception e) { + // TODO: This needs more robust failure handling + log.warn("Failed to build index remotely", e); + fallbackStrategy.buildAndWriteIndex(indexInfo); + } + } + + /** + * Gets the KNN repository container from the repository service. + * + * @return {@link RepositoriesService} + * @throws RepositoryMissingException if repository is not registered or if {@link KNN_REMOTE_VECTOR_REPO_SETTING} is not set + */ + private BlobStoreRepository getRepository() throws RepositoryMissingException { + RepositoriesService repositoriesService = repositoriesServiceSupplier.get(); + assert repositoriesService != null; + String vectorRepo = KNNSettings.state().getSettingValue(KNN_REMOTE_VECTOR_REPO_SETTING.getKey()); + if (vectorRepo == null || vectorRepo.isEmpty()) { + throw new RepositoryMissingException("Vector repository " + KNN_REMOTE_VECTOR_REPO_SETTING.getKey() + " is not registered"); + } + final Repository repository = repositoriesService.repository(vectorRepo); + assert repository instanceof BlobStoreRepository : "Repository should be instance of BlobStoreRepository"; + return (BlobStoreRepository) repository; + } + + /** + * Write relevant vector data to repository + * + * @param fieldName + * @param knnVectorValuesSupplier + * @param totalLiveDocs + * @param segmentWriteState + * @throws IOException + * @throws InterruptedException + */ + private void writeToRepository( + String fieldName, + Supplier> knnVectorValuesSupplier, + int totalLiveDocs, + SegmentWriteState segmentWriteState + ) throws IOException, InterruptedException { + throw new NotImplementedException(); + } + + /** + * Submit vector build request to remote vector build service + * + */ + private void submitVectorBuild() { + throw new NotImplementedException(); + } + + /** + * Wait on remote vector build to complete + */ + private void awaitVectorBuild() { + throw new NotImplementedException(); + } + + /** + * Read constructed vector file from remote repository and write to IndexOutput + */ + private void readFromRepository() { + throw new NotImplementedException(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index 1e560a11ba..0bd4b0f27a 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -211,4 +211,9 @@ public ResolvedMethodContext resolveMethod( ) { return knnLibrary.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType); } + + @Override + public boolean supportsRemoteIndexBuild() { + return knnLibrary.supportsRemoteIndexBuild(); + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java index cf7c4ad82f..29e6442f48 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java @@ -139,4 +139,12 @@ KNNLibraryIndexingContext getKNNLibraryIndexingContext( default List mmapFileExtensions() { return Collections.emptyList(); } + + /** + * Returns whether or not the engine implementation supports remote index build + * @return true if remote index build is supported, false otherwise + */ + default boolean supportsRemoteIndexBuild() { + return false; + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java index 5a02582798..d23a475aa7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java @@ -118,4 +118,9 @@ public ResolvedMethodContext resolveMethod( ) { return methodResolver.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType); } + + @Override + public boolean supportsRemoteIndexBuild() { + return true; + } } diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java index 835425b2a2..e52332435d 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValuesFactory.java @@ -5,26 +5,35 @@ package org.opensearch.knn.index.vectorvalues; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.MergeState; import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.index.VectorDataType; import java.io.IOException; import java.util.Map; +import java.util.function.Supplier; /** * A factory class that provides various methods to create the {@link KNNVectorValues}. */ public final class KNNVectorValuesFactory { + private static final Logger log = LogManager.getLogger(KNNVectorValuesFactory.class); + /** * Returns a {@link KNNVectorValues} for the given {@link DocIdSetIterator} and {@link VectorDataType} * @@ -36,6 +45,21 @@ public static KNNVectorValues getVectorValues(final VectorDataType vector return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(knnVectorValues)); } + /** + * Returns a {@link Supplier} for {@link #getVectorValues(VectorDataType, KnnVectorValues)} + * Note: This class is public static so that it can be mocked for testing. + * + * @param vectorDataType {@link VectorDataType} + * @param knnVectorValues {@link KnnVectorValues} + * @return {@link KNNVectorValues} + */ + public static Supplier> getVectorValuesSupplier( + final VectorDataType vectorDataType, + final KnnVectorValues knnVectorValues + ) { + return () -> getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(knnVectorValues)); + } + public static KNNVectorValues getVectorValues(final VectorDataType vectorDataType, final DocIdSetIterator docIdSetIterator) { return getVectorValues(vectorDataType, new KNNVectorValuesIterator.DocIdsIteratorValues(docIdSetIterator)); } @@ -55,6 +79,22 @@ public static KNNVectorValues getVectorValues( return getVectorValues(vectorDataType, new KNNVectorValuesIterator.FieldWriterIteratorValues(docIdWithFieldSet, vectors)); } + /** + * Returns a {@link Supplier} for {@link #getVectorValuesSupplier(VectorDataType, DocsWithFieldSet, Map)}. + * Note: This class is public static so that it can be mocked for testing. + * + * @param vectorDataType {@link VectorDataType} + * @param docIdWithFieldSet {@link DocsWithFieldSet} + * @return {@link KNNVectorValues} + */ + public static Supplier> getVectorValuesSupplier( + final VectorDataType vectorDataType, + final DocsWithFieldSet docIdWithFieldSet, + final Map vectors + ) { + return () -> getVectorValues(vectorDataType, docIdWithFieldSet, vectors); + } + /** * Returns a {@link KNNVectorValues} for the given {@link FieldInfo} and {@link LeafReader} * @@ -135,4 +175,53 @@ private static KNNVectorValues getVectorValues( } throw new IllegalArgumentException("Invalid Vector data type provided, hence cannot return VectorValues"); } + + /** + * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. + * + * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. + * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. + * @param mergeState The {@link org.apache.lucene.index.MergeState} representing the state of the merge operation. + * @param The type of vectors being processed. + * @return The {@link KNNVectorValues} associated with the field during the merge. + * @throws IOException If an I/O error occurs during the retrieval. + */ + private static KNNVectorValues getKNNVectorValuesForMerge( + final VectorDataType vectorDataType, + final FieldInfo fieldInfo, + final MergeState mergeState + ) { + try { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + FloatVectorValues mergedFloats = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + return getVectorValues(vectorDataType, mergedFloats); + case BYTE: + ByteVectorValues mergedBytes = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + return getVectorValues(vectorDataType, mergedBytes); + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + } catch (final IOException e) { + log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e); + throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e); + } + } + + /** + * Returns a {@link Supplier} for {@link #getKNNVectorValuesForMerge(VectorDataType, FieldInfo, MergeState)}. + * Note: This class is public static so that it can be mocked for testing. + * + * @param vectorDataType + * @param fieldInfo + * @param mergeState + * @return + */ + public static Supplier> getKNNVectorValuesSupplierForMerge( + final VectorDataType vectorDataType, + final FieldInfo fieldInfo, + final MergeState mergeState + ) { + return () -> getKNNVectorValuesForMerge(vectorDataType, fieldInfo, mergeState); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 44c824862a..cd77ee6ddb 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -5,37 +5,55 @@ package org.opensearch.knn.plugin; +import com.google.common.collect.ImmutableList; +import org.opensearch.action.ActionRequest; import org.opensearch.cluster.NamedDiff; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.index.IndexModule; +import org.opensearch.index.IndexSettings; import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; +import org.opensearch.index.mapper.Mapper; import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; -import org.opensearch.knn.index.util.KNNClusterUtil; -import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; - -import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; -import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.codec.KNNCodecService; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; -import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; +import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelGraveyard; +import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.rest.RestDeleteModelHandler; import org.opensearch.knn.plugin.rest.RestGetModelHandler; import org.opensearch.knn.plugin.rest.RestKNNStatsHandler; import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler; import org.opensearch.knn.plugin.rest.RestSearchModelHandler; import org.opensearch.knn.plugin.rest.RestTrainModelHandler; -import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; +import org.opensearch.knn.plugin.search.KNNConcurrentSearchRequestDecider; import org.opensearch.knn.plugin.stats.KNNStats; +import org.opensearch.knn.plugin.transport.ClearCacheAction; +import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; import org.opensearch.knn.plugin.transport.DeleteModelAction; import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; import org.opensearch.knn.plugin.transport.GetModelAction; @@ -44,27 +62,6 @@ import org.opensearch.knn.plugin.transport.KNNStatsTransportAction; import org.opensearch.knn.plugin.transport.KNNWarmupAction; import org.opensearch.knn.plugin.transport.KNNWarmupTransportAction; -import org.opensearch.knn.plugin.transport.ClearCacheAction; -import org.opensearch.knn.plugin.transport.ClearCacheTransportAction; -import com.google.common.collect.ImmutableList; - -import org.opensearch.action.ActionRequest; -import org.opensearch.transport.client.Client; -import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.IndexScopedSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.settings.SettingsFilter; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.env.Environment; -import org.opensearch.env.NodeEnvironment; -import org.opensearch.index.IndexModule; -import org.opensearch.index.IndexSettings; -import org.opensearch.index.mapper.Mapper; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheTransportAction; import org.opensearch.knn.plugin.transport.SearchModelAction; @@ -76,10 +73,10 @@ import org.opensearch.knn.plugin.transport.TrainingModelAction; import org.opensearch.knn.plugin.transport.TrainingModelRequest; import org.opensearch.knn.plugin.transport.TrainingModelTransportAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache; import org.opensearch.knn.training.TrainingJobClusterStateListener; import org.opensearch.knn.training.TrainingJobRunner; @@ -102,10 +99,11 @@ import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; import org.opensearch.watcher.ResourceWatcherService; -import java.util.Arrays; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -163,6 +161,7 @@ public class KNNPlugin extends Plugin private KNNStats knnStats; private ClusterService clusterService; + private Supplier repositoriesServiceSupplier; @Override public Map getMappers() { @@ -192,6 +191,7 @@ public Collection createComponents( Supplier repositoriesServiceSupplier ) { this.clusterService = clusterService; + this.repositoriesServiceSupplier = repositoriesServiceSupplier; // Initialize Native Memory loading strategies VectorReader vectorReader = new VectorReader(client); @@ -284,7 +284,7 @@ public Optional getEngineFactory(IndexSettings indexSettings) { @Override public Optional getCustomCodecServiceFactory(IndexSettings indexSettings) { if (indexSettings.getValue(KNNSettings.IS_KNN_INDEX_SETTING)) { - return Optional.of(KNNCodecService::new); + return Optional.of((config) -> new KNNCodecService(config, repositoriesServiceSupplier)); } return Optional.empty(); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index f87ed6bcf6..c44600ed1e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -40,7 +40,7 @@ import java.util.List; import java.util.Map; import java.util.function.Predicate; -import java.util.stream.Collectors; +import java.util.function.Supplier; import java.util.stream.IntStream; import static com.carrotsearch.randomizedtesting.RandomizedTest.$; @@ -122,16 +122,12 @@ public static Collection data() { @SneakyThrows public void testFlush() { // Given - final List> expectedVectorValues = vectorsPerField.stream().map(vectors -> { + final List>> expectedVectorValuesSuppliers = vectorsPerField.stream().map(vectors -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectors.values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); - return knnVectorValues; - }).collect(Collectors.toList()); + return KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues); + }).toList(); try ( MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); @@ -163,11 +159,11 @@ public void testFlush() { DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSuppliers.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); }); @@ -189,9 +185,9 @@ public void testFlush() { IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { if (vectorsPerField.get(i).isEmpty()) { - verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } else { - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } } catch (Exception e) { throw new RuntimeException(e); @@ -199,7 +195,7 @@ public void testFlush() { }); final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + () -> KNNVectorValuesFactory.getVectorValuesSupplier(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) ); } @@ -208,17 +204,12 @@ public void testFlush() { @SneakyThrows public void testFlush_WithQuantization() { // Given - List> expectedVectorValues = new ArrayList<>(); + List>> expectedVectorValuesSuppliers = new ArrayList<>(); IntStream.range(0, vectorsPerField.size()).forEach(i -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectorsPerField.get(i).values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); - expectedVectorValues.add(knnVectorValues); - + expectedVectorValuesSuppliers.add(KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues)); }); try ( @@ -252,19 +243,25 @@ public void testFlush_WithQuantization() { DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSuppliers.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train( + quantizationParams, + expectedVectorValuesSuppliers.get(i).get(), + vectorsPerField.get(i).size() + ) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState, null, null) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -287,10 +284,10 @@ public void testFlush_WithQuantization() { try { if (vectorsPerField.get(i).isEmpty()) { verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); - verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } else { verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } } catch (Exception e) { throw new RuntimeException(e); @@ -298,24 +295,20 @@ public void testFlush_WithQuantization() { }); final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) + () -> KNNVectorValuesFactory.getVectorValuesSupplier(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) ); } } public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { // Given - List> expectedVectorValues = new ArrayList<>(); + List>> expectedVectorValuesSupplier = new ArrayList<>(); IntStream.range(0, vectorsPerField.size()).forEach(i -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectorsPerField.get(i).values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); - expectedVectorValues.add(knnVectorValues); + expectedVectorValuesSupplier.add(KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues)); }); @@ -354,11 +347,11 @@ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSupplier.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); }); @@ -381,18 +374,14 @@ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriterIsNeverCalled() throws IOException { // Given - List> expectedVectorValues = new ArrayList<>(); + List>> expectedVectorValuesSuppliers = new ArrayList<>(); final Map sizeMap = new HashMap<>(); IntStream.range(0, vectorsPerField.size()).forEach(i -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectorsPerField.get(i).values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); sizeMap.put(i, randomVectorValues.size()); - expectedVectorValues.add(knnVectorValues); + expectedVectorValuesSuppliers.add(KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues)); }); final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); @@ -431,11 +420,11 @@ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriter DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSuppliers.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); }); @@ -458,19 +447,14 @@ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriter public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { // Given - List> expectedVectorValues = new ArrayList<>(); + List>> expectedVectorValuesSuppliers = new ArrayList<>(); final Map sizeMap = new HashMap<>(); IntStream.range(0, vectorsPerField.size()).forEach(i -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectorsPerField.get(i).values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); sizeMap.put(i, randomVectorValues.size()); - expectedVectorValues.add(knnVectorValues); - + expectedVectorValuesSuppliers.add(KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues)); }); final int minThreshold = sizeMap.values().stream().filter(count -> count != 0).min(Integer::compareTo).orElse(0); @@ -509,11 +493,11 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSuppliers.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); }); @@ -534,9 +518,9 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { if (vectorsPerField.get(i).size() > 0) { - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } else { - verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } } catch (Exception e) { throw new RuntimeException(e); @@ -547,16 +531,12 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWriterIsCalled() throws IOException { // Given - List> expectedVectorValues = new ArrayList<>(); + List>> expectedVectorValuesSuppliers = new ArrayList<>(); IntStream.range(0, vectorsPerField.size()).forEach(i -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectorsPerField.get(i).values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); - expectedVectorValues.add(knnVectorValues); + expectedVectorValuesSuppliers.add(KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues)); }); final int threshold = 4; @@ -595,11 +575,11 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSuppliers.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); }); @@ -619,9 +599,9 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { if (vectorsPerField.get(i).size() >= threshold) { - verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } else { - verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValuesSuppliers.get(i), vectorsPerField.get(i).size()); } } catch (Exception e) { throw new RuntimeException(e); @@ -632,18 +612,14 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNotMet_thenStillBuildGraph() throws IOException { // Given - List> expectedVectorValues = new ArrayList<>(); + List>> expectedVectorValuesSuppliers = new ArrayList<>(); final Map sizeMap = new HashMap<>(); IntStream.range(0, vectorsPerField.size()).forEach(i -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectorsPerField.get(i).values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); sizeMap.put(i, randomVectorValues.size()); - expectedVectorValues.add(knnVectorValues); + expectedVectorValuesSuppliers.add(KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues)); }); final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); @@ -684,19 +660,25 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSuppliers.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train( + quantizationParams, + expectedVectorValuesSuppliers.get(i).get(), + vectorsPerField.get(i).size() + ) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState, null, null) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -726,8 +708,8 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres }); final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) + () -> KNNVectorValuesFactory.getVectorValuesSupplier(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) ); } } @@ -735,18 +717,14 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNegative_thenStillBuildGraph() throws IOException { // Given - List> expectedVectorValues = new ArrayList<>(); + List>> expectedVectorValuesSuppliers = new ArrayList<>(); final Map sizeMap = new HashMap<>(); IntStream.range(0, vectorsPerField.size()).forEach(i -> { final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( new ArrayList<>(vectorsPerField.get(i).values()) ); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - VectorDataType.FLOAT, - randomVectorValues - ); sizeMap.put(i, randomVectorValues.size()); - expectedVectorValues.add(knnVectorValues); + expectedVectorValuesSuppliers.add(KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, randomVectorValues)); }); final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( @@ -786,19 +764,25 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet(); knnVectorValuesFactoryMockedStatic.when( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) - ).thenReturn(expectedVectorValues.get(i)); + () -> KNNVectorValuesFactory.getVectorValuesSupplier(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValuesSuppliers.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train( + quantizationParams, + expectedVectorValuesSuppliers.get(i).get(), + vectorsPerField.get(i).size() + ) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState, null, null) + ).thenReturn(nativeIndexWriter); }); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -828,8 +812,8 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres }); final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) + () -> KNNVectorValuesFactory.getVectorValuesSupplier(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) ); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index cdc372bda2..77e0dd7310 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -39,6 +39,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Map; +import java.util.function.Supplier; import static com.carrotsearch.randomizedtesting.RandomizedTest.$; import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; @@ -108,6 +109,7 @@ public void testMerge() { new ArrayList<>(mergedVectors.values()) ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final Supplier> knnVectorValuesSupplier = () -> knnVectorValues; try ( MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); @@ -135,11 +137,12 @@ public void testMerge() { mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValuesSupplier); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -153,11 +156,11 @@ public void testMerge() { verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); if (!mergedVectors.isEmpty()) { - verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + verify(nativeIndexWriter).mergeIndex(knnVectorValuesSupplier, mergedVectors.size()); assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), - times(2) + () -> KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge(VectorDataType.FLOAT, fieldInfo, mergeState), + times(1) ); } else { verifyNoInteractions(nativeIndexWriter); @@ -171,6 +174,7 @@ public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled new ArrayList<>(mergedVectors.values()) ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final Supplier> knnVectorValuesSupplier = () -> knnVectorValues; final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( segmentWriteState, flatVectorsWriter, @@ -202,11 +206,12 @@ public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValuesSupplier); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -229,6 +234,7 @@ public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWrite new ArrayList<>(mergedVectors.values()) ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final Supplier> knnVectorValuesSupplier = () -> knnVectorValues; final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( segmentWriteState, flatVectorsWriter, @@ -260,11 +266,12 @@ public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWrite mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValuesSupplier); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null, null, null)) .thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion @@ -278,7 +285,7 @@ public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWrite verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); if (!mergedVectors.isEmpty()) { - verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + verify(nativeIndexWriter).mergeIndex(knnVectorValuesSupplier, mergedVectors.size()); } else { verifyNoInteractions(nativeIndexWriter); } @@ -292,6 +299,7 @@ public void testMerge_WithQuantization() { new ArrayList<>(mergedVectors.values()) ); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final Supplier> knnVectorValuesSupplier = () -> knnVectorValues; try ( MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); @@ -320,8 +328,9 @@ public void testMerge_WithQuantization() { mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); - knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) - .thenReturn(knnVectorValues); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge(VectorDataType.FLOAT, fieldInfo, mergeState) + ).thenReturn(knnVectorValuesSupplier); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { @@ -330,8 +339,9 @@ public void testMerge_WithQuantization() { throw new RuntimeException(e); } - nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) - .thenReturn(nativeIndexWriter); + nativeIndexWriterMockedStatic.when( + () -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState, null, null) + ).thenReturn(nativeIndexWriter); doAnswer(answer -> { Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion return null; @@ -345,11 +355,11 @@ public void testMerge_WithQuantization() { if (!mergedVectors.isEmpty()) { verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(0, quantizationState); - verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + verify(nativeIndexWriter).mergeIndex(knnVectorValuesSupplier, mergedVectors.size()); assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); knnVectorValuesFactoryMockedStatic.verify( - () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), - times(3) + () -> KNNVectorValuesFactory.getKNNVectorValuesSupplierForMerge(VectorDataType.FLOAT, fieldInfo, mergeState), + times(1) ); } else { assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java index dfe4e7f22a..43c86b94fd 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecServiceTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec; +import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.Codec; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; @@ -13,8 +14,7 @@ import org.opensearch.index.codec.CodecServiceConfig; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; - -import org.apache.logging.log4j.Logger; +import org.opensearch.repositories.RepositoriesService; import java.util.UUID; @@ -46,7 +46,7 @@ public void testGetCodecByName() { MapperService mapperService = mock(MapperService.class); Logger loggerMock = mock(Logger.class); CodecServiceConfig codecServiceConfig = new CodecServiceConfig(indexSettings, mapperService, loggerMock); - KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig); + KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig, () -> mock(RepositoriesService.class)); Codec codec = knnCodecService.codec(KNNCodecVersion.current().getCodecName()); assertNotNull(codec); } @@ -61,7 +61,7 @@ public void testGetCodecByName() { public void testGetCodecByNameWithNoMapperService() { Logger loggerMock = mock(Logger.class); CodecServiceConfig codecServiceConfig = new CodecServiceConfig(indexSettings, null, loggerMock); - KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig); + KNNCodecService knnCodecService = new KNNCodecService(codecServiceConfig, () -> mock(RepositoriesService.class)); Codec codec = knnCodecService.codec(KNNCodecVersion.current().getCodecName()); assertNotNull(codec); } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index 35c54f3b35..56c641e3c4 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -74,7 +74,7 @@ public void testBuildAndWrite() { .knnEngine(KNNEngine.NMSLIB) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) - .vectorValues(knnVectorValues) + .knnVectorValuesSupplier(() -> knnVectorValues) .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); @@ -169,7 +169,7 @@ public void testBuildAndWrite_withQuantization() { .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) .quantizationState(quantizationState) - .vectorValues(knnVectorValues) + .knnVectorValuesSupplier(() -> knnVectorValues) .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); @@ -254,7 +254,7 @@ public void testBuildAndWriteWithModel() { .knnEngine(KNNEngine.NMSLIB) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("model_id", "id", "model_blob", modelBlob)) - .vectorValues(knnVectorValues) + .knnVectorValuesSupplier(() -> knnVectorValues) .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 08942fe7f5..4a04b77926 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -73,7 +73,7 @@ public void testBuildAndWrite() { .knnEngine(KNNEngine.FAISS) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) - .vectorValues(knnVectorValues) + .knnVectorValuesSupplier(() -> knnVectorValues) .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); @@ -193,7 +193,7 @@ public void testBuildAndWrite_withQuantization() { .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) .quantizationState(quantizationState) - .vectorValues(knnVectorValues) + .knnVectorValuesSupplier(() -> knnVectorValues) .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index ed2ffc54a6..76aa2efc92 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -16,36 +16,37 @@ import org.apache.commons.lang.StringUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.net.URIBuilder; +import org.junit.AfterClass; +import org.junit.Before; import org.opensearch.Version; -import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.ExistsQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.common.featureflags.KNNFeatureFlags; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.derivedsource.ParentChildHelper; import org.opensearch.knn.index.query.KNNQueryBuilder; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; -import org.junit.AfterClass; -import org.junit.Before; -import org.opensearch.client.Request; -import org.opensearch.client.Response; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.MediaType; -import org.opensearch.index.query.ExistsQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; -import org.opensearch.core.rest.RestStatus; import org.opensearch.script.Script; import org.opensearch.search.SearchService; import org.opensearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; @@ -75,43 +76,42 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import static org.opensearch.knn.TestUtils.FIELD; +import static org.opensearch.knn.TestUtils.INDEX_KNN; +import static org.opensearch.knn.TestUtils.KNN_VECTOR; +import static org.opensearch.knn.TestUtils.NUMBER_OF_REPLICAS; +import static org.opensearch.knn.TestUtils.NUMBER_OF_SHARDS; +import static org.opensearch.knn.TestUtils.PROPERTIES; +import static org.opensearch.knn.TestUtils.QUERY_VALUE; +import static org.opensearch.knn.TestUtils.VECTOR_TYPE; +import static org.opensearch.knn.TestUtils.computeGroundTruthValues; +import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; -import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; import static org.opensearch.knn.common.KNNConstants.NAME; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; -import static org.opensearch.knn.common.KNNConstants.CLEAR_CACHE; - -import static org.opensearch.knn.TestUtils.NUMBER_OF_REPLICAS; -import static org.opensearch.knn.TestUtils.NUMBER_OF_SHARDS; -import static org.opensearch.knn.TestUtils.INDEX_KNN; -import static org.opensearch.knn.TestUtils.PROPERTIES; -import static org.opensearch.knn.TestUtils.VECTOR_TYPE; -import static org.opensearch.knn.TestUtils.KNN_VECTOR; -import static org.opensearch.knn.TestUtils.FIELD; -import static org.opensearch.knn.TestUtils.QUERY_VALUE; -import static org.opensearch.knn.TestUtils.computeGroundTruthValues; - +import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD; import static org.opensearch.knn.index.SpaceType.L2; -import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; import static org.opensearch.knn.index.engine.KNNEngine.FAISS; +import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; import static org.opensearch.knn.plugin.stats.StatNames.INDICES_IN_CACHE; /** @@ -170,6 +170,18 @@ public void cleanUpCache() throws Exception { clearCache(); } + /** + * Set up cluster settings for remote index build feature. We do this for all tests to ensure the fallback mechanisms are working correctly. + */ + @Before + public void setupRemoteIndexBuildSettings() throws Exception { + log.info("Setting remote index build settings"); + if (randomBoolean()) { + updateClusterSettings(KNNFeatureFlags.KNN_REMOTE_VECTOR_BUILD_SETTING.getKey(), true); + updateClusterSettings(KNNSettings.KNN_REMOTE_VECTOR_REPO, "integ-test-repo"); + } + } + /** * Gives the ability for certain, more exhaustive checks, to be disabled by default * @@ -971,6 +983,7 @@ protected Settings buildKNNIndexSettings(int approximateThreshold) { .put("number_of_replicas", 0) .put(KNN_INDEX, true) .put(INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, approximateThreshold) + .put(KNN_INDEX_REMOTE_VECTOR_BUILD, randomBoolean()) .build(); }