From 3d2a3932780e89575345bf373dc37c6b11af1fdc Mon Sep 17 00:00:00 2001 From: will-hwang Date: Thu, 30 Jan 2025 15:32:29 -0800 Subject: [PATCH] add initial implementation for supporting inference call skip in text embedding processor --- .../neuralsearch/plugin/NeuralSearch.java | 7 +- .../processor/InferenceProcessor.java | 181 +++++++++++++++++- .../processor/TextEmbeddingProcessor.java | 37 +++- .../TextEmbeddingProcessorFactory.java | 24 ++- .../processor/util/ProcessorUtils.java | 124 +++++++++++- .../TextEmbeddingProcessorTests.java | 7 + 6 files changed, 357 insertions(+), 23 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index f7ac5d19f..722ae8145 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -127,7 +127,12 @@ public Map getProcessors(Processor.Parameters paramet clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); return Map.of( TextEmbeddingProcessor.TYPE, - new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), + new TextEmbeddingProcessorFactory( + parameters.client, + clientAccessor, + parameters.env, + parameters.ingestService.getClusterService() + ), SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), TextImageEmbeddingProcessor.TYPE, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 6ee54afe7..7658af1df 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -15,6 +15,7 @@ import java.util.List; import 
java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -27,6 +28,7 @@ import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.common.collect.Tuple; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; @@ -41,6 +43,7 @@ import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.util.ProcessorUtils; import org.opensearch.neuralsearch.util.ProcessorDocumentUtils; /** @@ -278,7 +281,7 @@ private static class DataForInference { } @SuppressWarnings({ "unchecked" }) - private List createInferenceList(Map knnKeyMap) { + protected List createInferenceList(Map knnKeyMap) { List texts = new ArrayList<>(); knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { Object sourceValue = knnMapEntry.getValue(); @@ -434,6 +437,160 @@ Map buildNLPResult(Map processorMap, List res return result; } + /** + * This method traverses the ProcessMap generated with IngestDocument to filter out the entries that can be skipped for inference. 
+ * @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document + * @param sourceAndMetadataMap SourceAndMetadataMap of IngestDocument + * @param processMap processMap built from ingestDocument + * @return filtered ProcessMap with only entries that require model inferencing + */ + + protected Map filterProcessMap( + Map existingSourceAndMetadataMap, + Map sourceAndMetadataMap, + Map processMap + ) { + Iterator> iterator = processMap.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + Pair processedNestedKey = processNestedKey(entry); + String key = processedNestedKey.getKey(); + Object sourceValue = processedNestedKey.getValue(); + traverseProcessMap(key, "", key, sourceValue, iterator, existingSourceAndMetadataMap, sourceAndMetadataMap, 0); + } + return processMap; + } + + /** + * Recursively traverses sourceValue with Map and List iterators and removes every entry whose embedding could be copied from the existing document + * @param pathKey traversed path, segments separated by '.' + * @param prevPathKey previous traversed path, segments separated by '.' + * @param currKey current key in ProcessMap in traversal + * @param sourceValue current object in traversal. 
Can be List, Map or String + * @param iterator Map/List Iterator used for iterating over and removing eligible entries + * @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document + * @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument + * @param listIndex index of sourceValue within its enclosing list, if it is a list element (0 at the top level) + */ + + private void traverseProcessMap( + String pathKey, + String prevPathKey, + String currKey, + Object sourceValue, + Iterator iterator, + Map existingSourceAndMetadataMap, + Map sourceAndMetadataMap, + int listIndex + ) { + if (sourceValue instanceof String) { + if (copyEmbeddings( + prevPathKey, + currKey, + sourceValue.toString(), + existingSourceAndMetadataMap, + sourceAndMetadataMap, + listIndex + )) { + iterator.remove(); + } + } else if (sourceValue instanceof List) { + Iterator listIterator = ((List) sourceValue).iterator(); + IntStream.range(0, ((List) sourceValue).size()).forEach(index -> { + Object item = listIterator.next(); + traverseProcessMap( + pathKey, + prevPathKey, + currKey, + item, + listIterator, + existingSourceAndMetadataMap, + sourceAndMetadataMap, + index + ); + }); + } else if (sourceValue instanceof Map) { + Iterator> nestedIterator = ((Map) sourceValue).entrySet().iterator(); + while (nestedIterator.hasNext()) { + Map.Entry nestedEntry = nestedIterator.next(); + Pair processedNestedKey = processNestedKey(nestedEntry); + String nextPathKey = pathKey + "." 
+ processedNestedKey.getKey(); + Object nextSourceValue = processedNestedKey.getValue(); + traverseProcessMap( + nextPathKey, + pathKey, + processedNestedKey.getKey(), + nextSourceValue, + nestedIterator, + existingSourceAndMetadataMap, + sourceAndMetadataMap, + listIndex + ); + } + } + } + + /** + * Checks the following requirements and, when both hold, copies the embedding from existingSourceAndMetadataMap to sourceAndMetadataMap + * - inference text is the same between existingSourceAndMetadataMap and sourceAndMetadataMap + * - existingSourceAndMetadataMap has embeddings for inference text + * @param embeddingPath path to embedding field + * @param embeddingField name of the embedding field + * @param text inference text in IngestDocument + * @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document + * @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument + * @param listIndex index of sourceValue if in list + * + * @return true if the existing embedding was successfully copied into sourceAndMetadataMap after passing the required checks, false otherwise + */ + + private boolean copyEmbeddings( + String embeddingPath, + String embeddingField, + String text, + Map existingSourceAndMetadataMap, + Map sourceAndMetadataMap, + int listIndex + ) { + Optional textKeyValue = ProcessorUtils.findKeyFromFromValue(fieldMap, embeddingPath, embeddingField); + if (textKeyValue.isPresent()) { + String textKey = textKeyValue.get(); + Optional inferenceText = ProcessorUtils.getValueFromSource( + existingSourceAndMetadataMap, + String.join(".", embeddingPath, textKey), + listIndex + ); + if (inferenceText.isPresent() && inferenceText.get().equals(text)) { + Optional embeddings = ProcessorUtils.getValueFromSource( + existingSourceAndMetadataMap, + String.join(".", embeddingPath, embeddingField), + listIndex + ); + if (embeddings.isPresent()) { + return ProcessorUtils.setValueToSource( + sourceAndMetadataMap, + String.join(".", embeddingPath, 
embeddingField), + embeddings.get(), + listIndex + ); + } + } + } + return false; + } + + protected void makeInferenceCall( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ) { + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } + @SuppressWarnings({ "unchecked" }) private void putNLPResultToSourceMapForMapType( String processorKey, @@ -512,7 +669,8 @@ private void processMapEntryValue( ) { // build nlp output for object in sourceValue which is map type Iterator> iterator = sourceAndMetadataMapValueInList.iterator(); - IntStream.range(0, sourceAndMetadataMapValueInList.size()).forEach(index -> { + IndexWrapper listIndexWrapper = new IndexWrapper(0); + for (int i = 0; i < sourceAndMetadataMapValueInList.size(); i++) { Map nestedElement = iterator.next(); putNLPResultToSingleSourceMapInList( inputNestedMapEntryKey, @@ -520,9 +678,9 @@ private void processMapEntryValue( results, indexWrapper, nestedElement, - index + listIndexWrapper ); - }); + } } /** @@ -534,7 +692,7 @@ private void processMapEntryValue( * @param results * @param indexWrapper * @param sourceAndMetadataMap - * @param nestedElementIndex index of the element in the list field of source document + * @param listIndexWrapper wrapper holding the index of the element in the list field of source document */ @SuppressWarnings("unchecked") private void putNLPResultToSingleSourceMapInList( @@ -543,7 +701,7 @@ private void putNLPResultToSingleSourceMapInList( List results, IndexWrapper indexWrapper, Map sourceAndMetadataMap, - int nestedElementIndex + IndexWrapper listIndexWrapper ) { if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; if (sourceValue instanceof Map) { @@ -556,12 +714,17 @@ private void putNLPResultToSingleSourceMapInList( 
results, indexWrapper, sourceMap, - nestedElementIndex + listIndexWrapper ); } } else { - if (sourceValue instanceof List && ((List) sourceValue).get(nestedElementIndex) != null) { - sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION); + if (sourceValue instanceof List) { + if (sourceAndMetadataMap.containsKey(processorKey)) { + return; + } + if (((List) sourceValue).get(listIndexWrapper.index++) != null) { + sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION); + } } } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index c8f9f080d..5b0f7ee14 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -9,6 +9,9 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; +import org.opensearch.client.OpenSearchClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; @@ -26,6 +29,10 @@ public final class TextEmbeddingProcessor extends InferenceProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + public static final String IGNORE_UNALTERED = "ignore_unaltered"; + public static final boolean DEFAULT_IGNORE_UNALTERED = false; + private final boolean ignoreUnaltered; + private final OpenSearchClient openSearchClient; public TextEmbeddingProcessor( String tag, @@ -33,11 +40,15 @@ public TextEmbeddingProcessor( int batchSize, String modelId, Map fieldMap, + boolean ignoreUnaltered, + OpenSearchClient openSearchClient, MLCommonsClientAccessor clientAccessor, Environment 
environment, ClusterService clusterService ) { super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + this.openSearchClient = openSearchClient; + this.ignoreUnaltered = ignoreUnaltered; } @Override @@ -47,10 +58,28 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + if (ignoreUnaltered && ingestDocument.getSourceAndMetadata().get("_index") != null && ingestDocument.getSourceAndMetadata().get("_id") != null) { + String index = ingestDocument.getSourceAndMetadata().get("_index").toString(); + String id = ingestDocument.getSourceAndMetadata().get("_id").toString(); + openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> { + final Map document = response.getSourceAsMap(); + if (document == null || document.isEmpty()) { + makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler); + } else { + Map filteredProcessMap = filterProcessMap(document, ingestDocument.getSourceAndMetadata(), ProcessMap); + List filteredInferenceList = createInferenceList(filteredProcessMap); + if (!filteredInferenceList.isEmpty()) { + log.debug("making inference call for: {}", filteredInferenceList); + makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler); + } else { + log.debug("skipping inference call"); + handler.accept(ingestDocument, null); + } + } + }, e -> { makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler); })); + } else { + makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler); + } } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index 6b442b56c..e2431491e 100644 
--- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -4,14 +4,18 @@ */ package org.opensearch.neuralsearch.processor.factory; +import static org.opensearch.ingest.ConfigurationUtils.readBooleanProperty; import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.DEFAULT_IGNORE_UNALTERED; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.IGNORE_UNALTERED; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; import java.util.Map; +import org.opensearch.client.OpenSearchClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.env.Environment; import org.opensearch.ingest.AbstractBatchingProcessor; @@ -23,6 +27,8 @@ */ public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcessor.Factory { + private final OpenSearchClient openSearchClient; + private final MLCommonsClientAccessor clientAccessor; private final Environment environment; @@ -30,11 +36,13 @@ public final class TextEmbeddingProcessorFactory extends AbstractBatchingProcess private final ClusterService clusterService; public TextEmbeddingProcessorFactory( + final OpenSearchClient openSearchClient, final MLCommonsClientAccessor clientAccessor, final Environment environment, final ClusterService clusterService ) { super(TYPE); + this.openSearchClient = openSearchClient; this.clientAccessor = clientAccessor; this.environment = environment; this.clusterService = clusterService; @@ -43,7 +51,19 @@ public TextEmbeddingProcessorFactory( 
@Override protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map config) { String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD); - Map filedMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); - return new TextEmbeddingProcessor(tag, description, batchSize, modelId, filedMap, clientAccessor, environment, clusterService); + Map fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD); + boolean ignoreUnaltered = readBooleanProperty(TYPE, tag, config, IGNORE_UNALTERED, DEFAULT_IGNORE_UNALTERED); + return new TextEmbeddingProcessor( + tag, + description, + batchSize, + modelId, + fieldMap, + ignoreUnaltered, + openSearchClient, + clientAccessor, + environment, + clusterService + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java index d799f323f..eeb0ec95c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java @@ -8,6 +8,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.search.SearchHit; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; @@ -151,26 +152,81 @@ public static void removeTargetFieldFromSource(final Map sourceA * @return A possible result within an optional */ public static Optional getValueFromSource(final Map sourceAsMap, final String targetField) { - String[] keys = targetField.split("\\."); - Optional currentValue = Optional.of(sourceAsMap); + return getValueFromSource(sourceAsMap, targetField, -1); + } + public static Optional getValueFromSource( + final Map sourceAsMap, + final String targetField, + final int listIndex + ) { + // split the dotted path and keep only non-blank segments + String[] keys = Arrays.stream(targetField.split("\\.")).filter(key -> !key.isBlank()).toArray(String[]::new); + 
Optional currentValue = Optional.of(sourceAsMap); for (String key : keys) { currentValue = currentValue.flatMap(value -> { - if (!(value instanceof Map)) { - return Optional.empty(); + if (value instanceof List && listIndex >= 0) { + List currentList = (List) value; + Object listValue = listIndex < currentList.size() ? currentList.get(listIndex) : null; + if (listValue instanceof Map) { + Map currentMap = (Map) listValue; + return Optional.ofNullable(currentMap.get(key)); + } } - Map currentMap = (Map) value; - return Optional.ofNullable(currentMap.get(key)); + if ((value instanceof Map)) { + Map currentMap = (Map) value; + return Optional.ofNullable(currentMap.get(key)); + } + return Optional.empty(); }); if (currentValue.isEmpty()) { return Optional.empty(); } } - return currentValue; } + /** + * Sets targetValue at targetPath in sourceAsMap, descending through the Maps (and, when listIndex is non-negative, through the listIndex-th element of each List) already present on the path + * + * @param sourceAsMap The Source map (a map of maps) to iterate through + * @param targetPath The dot-separated path to the key to set; must contain at least one non-blank segment + * @param targetValue The value to insert at targetPath + * @param listIndex The index to use when a List is found on the path; a negative value (callers pass -1) disables descending into lists + * @return true if targetValue was set; false if a segment before the last one is missing, is not a Map, or is a List without a Map at listIndex + */ + public static boolean setValueToSource( + final Map sourceAsMap, + final String targetPath, + final Object targetValue, + final int listIndex + ) { + // split the dotted path and keep only non-blank segments + String[] keys = Arrays.stream(targetPath.split("\\.")).filter(key -> !key.isBlank()).toArray(String[]::new); + Map currentMap = sourceAsMap; + String lastKey = keys[keys.length - 1]; + for (int i = 0; i < keys.length - 1; i++) { + String key = keys[i]; + Object value = currentMap.get(key); + if (value instanceof List && listIndex >= 0) { + Object listValue = listIndex < ((List<?>) value).size() ? ((List<?>) value).get(listIndex) : null; + if (!(listValue instanceof Map)) { + return false; + } + currentMap = (Map) listValue; + } else if (value instanceof Map) { + currentMap = (Map) value; + } else { + return false; + } + } + currentMap.put(lastKey, targetValue); + 
return true; + } + + + /** * Determines whether there exists a value that has a mapping according to the pathToValue. This is particularly * useful when the source map is a map of maps and when the pathToValue is of the form key[.key]. @@ -209,4 +265,58 @@ public static boolean isNumeric(Object value) { return false; } + + /** + * Given a map, path and value, return the key mapped with given value. + * + * e.g: + * + * sourceAsMap: { + * "level1": { + * "level2" : { + * "text": "passage_embedding" + * } + * } + * } + * path: "level1.level2" + * targetValue: "passage_embedding" + * + * returns "text" + * + * if there are multiple keys mapping with same value, return the last key in the map's iteration order + * @param sourceAsMap The Source map (a map of maps) to iterate through + * @param targetPath The dot-separated path to the nested map whose entries are searched + * @param targetValue The target value to find among the entries at targetPath + * @return the matching key, or empty if the path does not lead to a map or no entry has targetValue + */ + public static Optional findKeyFromFromValue( + final Map sourceAsMap, + final String targetPath, + final String targetValue + ) { + // split the dotted path and keep only non-blank segments + String[] keys = Arrays.stream(targetPath.split("\\.")).filter(key -> !key.isBlank()).toArray(String[]::new); + Optional currentValue = Optional.of(sourceAsMap); + + for (String key : keys) { + currentValue = currentValue.flatMap(value -> { + if (!(value instanceof Map)) { + return Optional.empty(); + } + Map currentMap = (Map) value; + return Optional.ofNullable(currentMap.get(key)); + }); + + if (currentValue.isEmpty() || !(currentValue.get() instanceof Map)) { + return Optional.empty(); + } + } + String targetKey = null; + for (Map.Entry entry : (Iterable>) ((Map) currentValue.get()).entrySet()) { + if (targetValue != null && targetValue.equals(entry.getValue())) { + targetKey = entry.getKey(); + } + } + return Optional.ofNullable(targetKey); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index e42c9023b..d965075a4 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -36,6 +36,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; +import org.opensearch.client.OpenSearchClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; @@ -67,6 +68,10 @@ public class TextEmbeddingProcessorTests extends InferenceProcessorTestCase { protected static final String TEXT_VALUE_2 = "text_value2"; protected static final String TEXT_VALUE_3 = "text_value3"; protected static final String TEXT_FIELD_2 = "abc"; + + @Mock + private OpenSearchClient openSearchClient; + @Mock private MLCommonsClientAccessor mlCommonsClientAccessor; @@ -175,6 +180,7 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory( + openSearchClient, accessor, environment, clusterService @@ -203,6 +209,7 @@ public void testExecute_whenInferenceTextListEmpty_SuccessWithoutEmbedding() { Map registry = new HashMap<>(); MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); TextEmbeddingProcessorFactory textEmbeddingProcessorFactory = new TextEmbeddingProcessorFactory( + openSearchClient, accessor, environment, clusterService