Skip to content

Commit

Permalink
Serialize all models into cluster metadata (#1499) (#1644)
Browse files Browse the repository at this point in the history
* Remove transport calls in TrainingJobRunner and TrainingJobClusterStateListener

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix tests

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add changelog

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix CMake Faiss bug

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Add state checks for existing cluster metadata calls

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Remove CMake bug fix

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix changelog

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Fix failing tests

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Refactor and add two more created state checks

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Rebase and fix new tests

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Refactor created checks and modify error messages

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

* Refactor cluster state listener transport calls

Signed-off-by: Ryan Bogan <rbogan@amazon.com>

---------

Signed-off-by: Ryan Bogan <rbogan@amazon.com>
(cherry picked from commit dc8eb6b)

Co-authored-by: Ryan Bogan <rbogan@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and ryanbogan authored Apr 23, 2024
1 parent ec9ddb4 commit d9d2d10
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604)
* Remove unnecessary toString conversion of vector field and added some minor optimization in KNNCodec [1613](https://github.com/opensearch-project/k-NN/pull/1613)
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
### Bug Fixes
* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630)
### Infrastructure
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;

import java.io.File;
Expand Down Expand Up @@ -199,8 +200,8 @@ public static ValidationException validateKnnField(
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, field));
if (!ModelUtil.isModelCreated(modelMetadata)) {
exception.addValidationError(String.format("Model \"%s\" for field \"%s\" is not created.", modelId, field));
return exception;
}

Expand Down Expand Up @@ -286,4 +287,5 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod
}
return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.io.IOException;

Expand Down Expand Up @@ -50,10 +51,10 @@ protected void parseCreateField(ParseContext context) throws IOException {
// model when ingestion starts.
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);

if (modelMetadata == null) {
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalStateException(
String.format(
"Model \"%s\" from %s's mapping does not exist. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.",
"Model \"%s\" from %s's mapping is not created. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.",
modelId,
context.mapperService().index().getName(),
MODEL_ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -548,8 +549,8 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new IllegalArgumentException(String.format("Model ID '%s' does not exist.", modelId));
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;

Expand Down Expand Up @@ -213,8 +214,8 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new RuntimeException("Model \"" + modelId + "\" is not created.");
}

knnEngine = modelMetadata.getKnnEngine();
Expand Down
8 changes: 1 addition & 7 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
if (ModelState.CREATED.equals(model.getModelMetadata().getState())) {
onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
} else {
onIndexListener = onMetaListener;
}
ActionListener<IndexResponse> onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);

// Create the model index if it does not already exist
Runnable indexModelRunnable = () -> indexRequestBuilder.execute(onIndexListener);
Expand Down
30 changes: 30 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.indices;

/**
* A utility class for models.
*/
public class ModelUtil {

public static boolean isModelPresent(ModelMetadata modelMetadata) {
return modelMetadata != null;
}

public static boolean isModelCreated(ModelMetadata modelMetadata) {
if (!isModelPresent(modelMetadata)) {
return false;
}
return modelMetadata.getState().equals(ModelState.CREATED);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ protected void updateModelsNewCluster() throws IOException, InterruptedException
if (modelDao.isCreated()) {
List<String> modelIds = searchModelIds();
for (String modelId : modelIds) {
Model model = modelDao.get(modelId);
ModelMetadata modelMetadata = model.getModelMetadata();
ModelMetadata modelMetadata = getModelMetadata(modelId);
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
updateModelStateAsFailed(model, "Training failed to complete as cluster crashed");
updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as cluster crashed");
}
}
}
Expand All @@ -123,11 +122,10 @@ protected void updateModelsNodesRemoved(List<DiscoveryNode> removedNodes) throws
List<String> modelIds = searchModelIds();
for (DiscoveryNode removedNode : removedNodes) {
for (String modelId : modelIds) {
Model model = modelDao.get(modelId);
ModelMetadata modelMetadata = model.getModelMetadata();
ModelMetadata modelMetadata = getModelMetadata(modelId);
if (modelMetadata.getNodeAssignment().equals(removedNode.getEphemeralId())
&& modelMetadata.getState().equals(ModelState.TRAINING)) {
updateModelStateAsFailed(model, "Training failed to complete as node dropped");
updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as node dropped");
}
}
}
Expand Down Expand Up @@ -158,9 +156,11 @@ public void onFailure(Exception e) {
return modelIds;
}

private void updateModelStateAsFailed(Model model, String msg) throws IOException {
model.getModelMetadata().setState(ModelState.FAILED);
model.getModelMetadata().setError(msg);
private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException,
InterruptedException {
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(msg);
Model model = new Model(modelMetadata, null, modelId);
modelDao.update(model, new ActionListener<IndexResponse>() {
@Override
public void onResponse(IndexResponse indexResponse) {
Expand All @@ -173,4 +173,17 @@ public void onFailure(Exception e) {
}
});
}

private ModelMetadata getModelMetadata(String modelId) throws ExecutionException, InterruptedException {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
// On versions prior to 2.14, only models in created state are present in model metadata.
if (modelMetadata == null) {
log.info(
"Model metadata is null in cluster metadata. This can happen for models training on nodes prior to OpenSearch version 2.14.0. Fetching model information from system index."
);
Model model = modelDao.get(modelId);
return model.getModelMetadata();
}
return modelMetadata;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
Expand Down Expand Up @@ -683,6 +684,7 @@ public void testDoToQuery_FromModel() {
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -712,6 +714,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -744,6 +747,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.jni.JNIService;

import java.io.IOException;
Expand Down Expand Up @@ -167,6 +168,7 @@ public void testQueryScoreForFaissWithModel() {
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(spaceType);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata);

KNNWeight.initialize(modelDao);
Expand Down Expand Up @@ -254,7 +256,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException {
when(fieldInfo.getAttribute(eq(MODEL_ID))).thenReturn(modelId);

RuntimeException ex = expectThrows(RuntimeException.class, () -> knnWeight.scorer(leafReaderContext));
assertEquals(String.format("Model \"%s\" does not exist.", modelId), ex.getMessage());
assertEquals(String.format("Model \"%s\" is not created.", modelId), ex.getMessage());
}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() {
// Mock the model dao to return metadata for modelId to recognize it is a duplicate
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(trainingFieldModelMetadata.getState()).thenReturn(ModelState.CREATED);

ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public void testUpdateModelsNewCluster() throws IOException, InterruptedExceptio
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
doAnswer(invocationOnMock -> {
SearchResponse searchResponse = mock(SearchResponse.class);
SearchHits searchHits = mock(SearchHits.class);
Expand Down Expand Up @@ -144,6 +145,7 @@ public void testUpdateModelsNodesRemoved() throws IOException, InterruptedExcept
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
DiscoveryNode node1 = mock(DiscoveryNode.class);
when(node1.getEphemeralId()).thenReturn("test-node-model-match");
DiscoveryNode node2 = mock(DiscoveryNode.class);
Expand Down

0 comments on commit d9d2d10

Please sign in to comment.