Skip to content

Commit 04e1d8e

Browse files
committed
fix IT failure
Signed-off-by: zane-neo <zaniu@amazon.com>
1 parent ab2d736 commit 04e1d8e

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,25 @@ public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOExc
930930
return parseResponseToMap(response);
931931
}
932932

933+
public Map predictTextEmbeddingModelIgnoreFunctionName(String modelId, MLInput mlInput) throws IOException {
934+
Response response = null;
935+
try {
936+
response = TestHelper
937+
.makeRequest(
938+
client(),
939+
"POST",
940+
"/_plugins/_ml/models/" + modelId + "/_predict",
941+
null,
942+
TestHelper.toJsonString(mlInput),
943+
null
944+
);
945+
} catch (ResponseException e) {
946+
log.error(e.getMessage(), e);
947+
response = e.getResponse();
948+
}
949+
return parseResponseToMap(response);
950+
}
951+
933952
public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
934953
return (modelProfile) -> {
935954
if (modelProfile.containsKey("model_state")) {

plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import org.junit.Before;
1717
import org.opensearch.ml.common.FunctionName;
18+
import org.opensearch.ml.common.connector.ConnectorAction;
1819
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
1920
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
2021
import org.opensearch.ml.common.input.MLInput;
@@ -327,9 +328,10 @@ public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exc
327328
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
328329
.builder()
329330
.parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512"))
331+
.actionType(ConnectorAction.ActionType.PREDICT)
330332
.build();
331333
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build();
332-
Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
334+
Map inferenceResult = predictTextEmbeddingModelIgnoreFunctionName(modelId, mlInput);
333335
String errorMsg = String
334336
.format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
335337
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));

0 commit comments

Comments
 (0)