[Draft PR] proposed implementation for supporting inference call skip in text embedding processor #1155

Draft: wants to merge 1 commit into base: main
@@ -127,7 +127,12 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Map.of(
TextEmbeddingProcessor.TYPE,
- new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
+ new TextEmbeddingProcessorFactory(
+ parameters.client,
+ clientAccessor,
+ parameters.env,
+ parameters.ingestService.getClusterService()
+ ),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()),
TextImageEmbeddingProcessor.TYPE,
File: src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java
@@ -15,6 +15,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
@@ -27,6 +28,7 @@
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.env.Environment;
@@ -41,6 +43,7 @@
import com.google.common.collect.ImmutableMap;

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.util.ProcessorUtils;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/**
@@ -278,7 +281,7 @@
}

@SuppressWarnings({ "unchecked" })
- private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
+ protected List<String> createInferenceList(Map<String, Object> knnKeyMap) {
List<String> texts = new ArrayList<>();
knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
Object sourceValue = knnMapEntry.getValue();
@@ -434,6 +437,160 @@
return result;
}

/**
* This method traverses the ProcessMap generated with IngestDocument to filter out the entries that can be skipped for inference.
* @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document
* @param sourceAndMetadataMap SourceAndMetadataMap of IngestDocument
* @param processMap processMap built from ingestDocument
* @return filtered ProcessMap with only entries that require model inferencing
*/

protected Map<String, Object> filterProcessMap(
Review comment (Collaborator): Defining a shared method in the parent class may not be ideal for simplifying unit testing. A better approach would be to create a separate class and inject it into the processor. This allows the injected instance to be easily mocked during unit testing, making the process of writing unit tests much more straightforward.

Map<String, Object> existingSourceAndMetadataMap,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> processMap
) {
Iterator<Map.Entry<String, Object>> iterator = processMap.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, Object> entry = iterator.next();
Pair<String, Object> processedNestedKey = processNestedKey(entry);
String key = processedNestedKey.getKey();
Object sourceValue = processedNestedKey.getValue();
traverseProcessMap(key, "", key, sourceValue, iterator, existingSourceAndMetadataMap, sourceAndMetadataMap, 0);
}
return processMap;
}
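
The review comment on `filterProcessMap` suggests moving the filtering into a separate class that is injected into the processor. A minimal sketch of that shape follows. The names `InferenceFilter` and `SketchProcessor` are hypothetical and do not exist in the PR; the real processor has many more constructor arguments.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical collaborator: the PR keeps this logic inside InferenceProcessor.
interface InferenceFilter {
    Map<String, Object> filter(
        Map<String, Object> existingSource,
        Map<String, Object> incomingSource,
        Map<String, Object> processMap
    );
}

// Hypothetical stand-in for a processor that receives the filter via its constructor.
final class SketchProcessor {
    private final InferenceFilter filter;

    SketchProcessor(InferenceFilter filter) {
        this.filter = filter;
    }

    // Returns the top-level texts that still need inference after filtering.
    List<String> textsToInfer(
        Map<String, Object> existingSource,
        Map<String, Object> incomingSource,
        Map<String, Object> processMap
    ) {
        Map<String, Object> remaining = filter.filter(existingSource, incomingSource, processMap);
        List<String> texts = new ArrayList<>();
        for (Object value : remaining.values()) {
            if (value instanceof String) {
                texts.add((String) value);
            }
        }
        return texts;
    }
}
```

Because the filter is a single-method interface, a unit test can pass a lambda or a mock in place of the real implementation and exercise the processor's branching without building nested source maps.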

/**
* This method recursively traverses sourceValue with Map and List iterators and removes entries whose embeddings were copied from the existing document
* @param pathKey traversed path separated by .
* @param prevPathKey previous traversed path separated by .
* @param currKey current key in ProcessMap in traversal
* @param sourceValue current object in traversal. Can be List, Map or String
* @param iterator Map/List Iterator used for iterating over and removing eligible entries
* @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document
* @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument
* @param listIndex index of sourceValue if in list
*/

private void traverseProcessMap(
String pathKey,
String prevPathKey,
String currKey,
Object sourceValue,
Iterator<?> iterator,
Map<String, Object> existingSourceAndMetadataMap,
Map<String, Object> sourceAndMetadataMap,
int listIndex
) {
if (sourceValue instanceof String) {
if (copyEmbeddings(
prevPathKey,
currKey,
sourceValue.toString(),
existingSourceAndMetadataMap,
sourceAndMetadataMap,
listIndex
)) {
iterator.remove();
}
} else if (sourceValue instanceof List) {
Iterator<?> listIterator = ((List) sourceValue).iterator();
IntStream.range(0, ((List) sourceValue).size()).forEach(index -> {
Object item = listIterator.next();
traverseProcessMap(
pathKey,
prevPathKey,
currKey,
item,
listIterator,
existingSourceAndMetadataMap,
sourceAndMetadataMap,
index
);
});
} else if (sourceValue instanceof Map) {
Iterator<Map.Entry<String, Object>> nestedIterator = ((Map<String, Object>) sourceValue).entrySet().iterator();
while (nestedIterator.hasNext()) {
Map.Entry<String, Object> nestedEntry = nestedIterator.next();
Pair<String, Object> processedNestedKey = processNestedKey(nestedEntry);
String nextPathKey = pathKey + "." + processedNestedKey.getKey();
Object nextSourceValue = processedNestedKey.getValue();
traverseProcessMap(
nextPathKey,
pathKey,
processedNestedKey.getKey(),
nextSourceValue,
nestedIterator,
existingSourceAndMetadataMap,
sourceAndMetadataMap,
listIndex
);
}
}
}

/**
* This method checks for the following requirements to determine whether embeddings can be copied from existingSourceAndMetadataMap to sourceAndMetadataMap
* - inference text is the same between existingSourceAndMetadataMap and sourceAndMetadataMap
* - existingSourceAndMetadataMap has embeddings for the inference text
* @param embeddingPath path to embedding field
* @param embeddingField name of the embedding field
* @param text inference text in IngestDocument
* @param existingSourceAndMetadataMap SourceAndMetadataMap of existing Document
* @param sourceAndMetadataMap SourceAndMetadataMap of ingestDocument
* @param listIndex index of sourceValue if in list
* @return true if the existing embedding was copied after passing the required checks, false otherwise
*/

private boolean copyEmbeddings(
String embeddingPath,
String embeddingField,
String text,
Map<String, Object> existingSourceAndMetadataMap,
Map<String, Object> sourceAndMetadataMap,
int listIndex
) {
Optional<String> textKeyValue = ProcessorUtils.findKeyFromFromValue(fieldMap, embeddingPath, embeddingField);
if (textKeyValue.isPresent()) {
String textKey = textKeyValue.get();
Optional<Object> inferenceText = ProcessorUtils.getValueFromSource(
existingSourceAndMetadataMap,
String.join(".", embeddingPath, textKey),
listIndex
);
if (inferenceText.isPresent() && inferenceText.get().equals(text)) {
Optional<Object> embeddings = ProcessorUtils.getValueFromSource(
existingSourceAndMetadataMap,
String.join(".", embeddingPath, embeddingField),
listIndex
);
if (embeddings.isPresent()) {
return ProcessorUtils.setValueToSource(
sourceAndMetadataMap,
String.join(".", embeddingPath, embeddingField),
embeddings.get(),
listIndex
);
}
}
}
return false;
}
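
The two conditions listed in the `copyEmbeddings` Javadoc can be illustrated on a flat document. This is a simplified sketch under stated assumptions: `CopyEmbeddingSketch` and `copyIfUnchanged` are hypothetical names, the text and embedding fields are passed in directly instead of being resolved through `fieldMap`, and nested paths and list indices (which the PR handles through `ProcessorUtils`) are ignored.

```java
import java.util.Map;

// Flat-document simplification; not the PR's implementation.
final class CopyEmbeddingSketch {

    // Copies the existing embedding into incomingDoc only when
    // (1) the inference text is unchanged and (2) an embedding already exists.
    static boolean copyIfUnchanged(
        String textField,
        String embeddingField,
        Map<String, Object> existingDoc,
        Map<String, Object> incomingDoc
    ) {
        Object incomingText = incomingDoc.get(textField);
        Object existingEmbedding = existingDoc.get(embeddingField);
        boolean sameText = incomingText != null && incomingText.equals(existingDoc.get(textField));
        if (!sameText || existingEmbedding == null) {
            return false;
        }
        incomingDoc.put(embeddingField, existingEmbedding);
        return true;
    }
}
```

When the method returns true the caller can drop that entry from the inference input, which is the role `iterator.remove()` plays in `traverseProcessMap`.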

protected void makeInferenceCall(
IngestDocument ingestDocument,
Map<String, Object> ProcessMap,
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> {
setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
handler.accept(ingestDocument, null);
}, e -> { handler.accept(null, e); }));
}

@SuppressWarnings({ "unchecked" })
private void putNLPResultToSourceMapForMapType(
String processorKey,
@@ -512,17 +669,18 @@
) {
// build nlp output for object in sourceValue which is map type
Iterator<Map<String, Object>> iterator = sourceAndMetadataMapValueInList.iterator();
- IntStream.range(0, sourceAndMetadataMapValueInList.size()).forEach(index -> {
+ IndexWrapper listIndexWrapper = new IndexWrapper(0);
+ for (int i = 0; i < sourceAndMetadataMapValueInList.size(); i++) {
Map<String, Object> nestedElement = iterator.next();
putNLPResultToSingleSourceMapInList(
inputNestedMapEntryKey,
inputNestedMapEntryValue,
results,
indexWrapper,
nestedElement,
- index
+ listIndexWrapper
);
- });
+ }
}

/**
@@ -534,7 +692,7 @@
* @param results
* @param indexWrapper
* @param sourceAndMetadataMap
- * @param nestedElementIndex index of the element in the list field of source document
+ * @param listIndexWrapper index of the element in the list field of source document
*/
@SuppressWarnings("unchecked")
private void putNLPResultToSingleSourceMapInList(
@@ -543,7 +701,7 @@
List<?> results,
IndexWrapper indexWrapper,
Map<String, Object> sourceAndMetadataMap,
- int nestedElementIndex
+ IndexWrapper listIndexWrapper
) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
if (sourceValue instanceof Map) {
@@ -556,12 +714,17 @@
results,
indexWrapper,
sourceMap,
- nestedElementIndex
+ listIndexWrapper
);
}
} else {
- if (sourceValue instanceof List && ((List<Object>) sourceValue).get(nestedElementIndex) != null) {
- sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION);
+ if (sourceValue instanceof List) {
+ if (sourceAndMetadataMap.containsKey(processorKey)) {
+ return;
+ }
+ if (((List<Object>) sourceValue).get(listIndexWrapper.index++) != null) {
+ sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION);
+ }
}
}
}
File: src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
@@ -9,6 +9,9 @@
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.client.OpenSearchClient;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.env.Environment;
@@ -26,18 +29,26 @@

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
public static final String IGNORE_UNALTERED = "ignore_unaltered";
public static final boolean DEFAULT_IGNORE_UNALTERED = false;
private final boolean ignoreUnaltered;
private final OpenSearchClient openSearchClient;

public TextEmbeddingProcessor(
String tag,
String description,
int batchSize,
String modelId,
Map<String, Object> fieldMap,
boolean ignoreUnaltered,
OpenSearchClient openSearchClient,
MLCommonsClientAccessor clientAccessor,
Environment environment,
ClusterService clusterService
) {
super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
this.openSearchClient = openSearchClient;
this.ignoreUnaltered = ignoreUnaltered;
}

@Override
@@ -47,10 +58,28 @@
List<String> inferenceList,
BiConsumer<IngestDocument, Exception> handler
) {
- mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> {
- setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors);
- handler.accept(ingestDocument, null);
- }, e -> { handler.accept(null, e); }));
+ if (ignoreUnaltered == true) {
+ String index = ingestDocument.getSourceAndMetadata().get("_index").toString();
+ String id = ingestDocument.getSourceAndMetadata().get("_id").toString();
+ openSearchClient.execute(GetAction.INSTANCE, new GetRequest(index, id), ActionListener.wrap(response -> {
+ final Map<String, Object> document = response.getSourceAsMap();
+ if (document == null || document.isEmpty()) {
+ makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler);
+ } else {
+ Map<String, Object> filteredProcessMap = filterProcessMap(document, ingestDocument.getSourceAndMetadata(), ProcessMap);
+ List<String> filteredInferenceList = createInferenceList(filteredProcessMap);
+ if (!filteredInferenceList.isEmpty()) {
+ log.info("making inference call for: {}", filteredInferenceList);
+ makeInferenceCall(ingestDocument, filteredProcessMap, filteredInferenceList, handler);
+ } else {
+ log.info("skipping inference call");
+ handler.accept(ingestDocument, null);
+ }
+ }
+ }, e -> { makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler); }));
+ } else {
+ makeInferenceCall(ingestDocument, ProcessMap, inferenceList, handler);
+ }
}
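
The branching added to `makeInferenceCall` reduces to three outcomes: run inference on everything, run it on the filtered subset, or skip it. The sketch below isolates that decision so it can be read and tested without an OpenSearch client. `SkipDecisionSketch`, `Decision`, and `decide` are hypothetical names, and a failed or empty lookup is modeled here as a `null` or empty `existingDoc`, mirroring the fallback to full inference in the diff.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Decision logic only; the asynchronous GetRequest and the model call are left out.
final class SkipDecisionSketch {

    enum Decision { INFER_ALL, INFER_FILTERED, SKIP }

    // remainingTexts stands in for filterProcessMap followed by createInferenceList.
    static Decision decide(
        boolean ignoreUnaltered,
        Map<String, Object> existingDoc,
        Function<Map<String, Object>, List<String>> remainingTexts
    ) {
        // Option disabled, or no existing document to compare against: infer everything.
        if (!ignoreUnaltered || existingDoc == null || existingDoc.isEmpty()) {
            return Decision.INFER_ALL;
        }
        // Nothing left after filtering means every embedding was reused.
        return remainingTexts.apply(existingDoc).isEmpty() ? Decision.SKIP : Decision.INFER_FILTERED;
    }
}
```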

@Override