diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 0f8f22197b..29c863ed44 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -898,10 +898,12 @@ public Map predictTextEmbedding(String modelId) throws IOException { return result; } - public ModelTensorOutput predictRemoteModel(String modelId, MLInput input) throws IOException { + public Map predictRemoteModel(String modelId, MLInput input) throws IOException { + String requestBody = TestHelper.toJsonString(input); + System.out.println("############################## request body is:" + requestBody); Response response = TestHelper - .makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, TestHelper.toJsonString(input), null); - return new ModelTensorOutput(StreamInput.wrap(response.getEntity().getContent().readAllBytes())); + .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + return parseResponseToMap(response); } public Consumer> verifyTextEmbeddingModelDeployed() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 2811e85e6c..4d831ef032 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -5,80 +5,31 @@ package org.opensearch.ml.rest; -import com.google.common.collect.ImmutableList; -import com.jayway.jsonpath.JsonPath; -import org.apache.hc.core5.http.HttpHeaders; -import org.apache.hc.core5.http.message.BasicHeader; -import org.junit.Assert; +import lombok.SneakyThrows; import org.junit.Before; -import org.opensearch.client.Request; -import org.opensearch.client.Response; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.utils.TestHelper; -import org.w3c.dom.Text; +import org.opensearch.ml.common.utils.StringUtils; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.List; +import java.util.Locale; import java.util.Map; -import static org.opensearch.ml.utils.TestHelper.makeRequest; - public class RestBedRockInferenceIT extends MLCommonsRestTestCase { - private String bedrockEmbeddingModelId; private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); private static final String GITHUB_CI_AWS_REGION = "us-west-2"; - private final String bedrockEmbeddingModelConnectorEntity = "{\n" - + " \"name\": \"Amazon Bedrock Connector: embedding\",\n" - + " \"description\": \"The connector to bedrock Titan embedding model\",\n" - + " \"version\": 1,\n" - + " \"protocol\": \"aws_sigv4\",\n" - + " \"parameters\": {\n" - + " \"region\": \"" - + GITHUB_CI_AWS_REGION - + "\",\n" - + " \"service_name\": \"bedrock\",\n" - + " \"model_name\": \"amazon.titan-embed-text-v1\"\n" - + " },\n" - + " \"credential\": {\n" - + " \"access_key\": \"" - + AWS_ACCESS_KEY_ID - + "\",\n" - + " \"secret_key\": \"" - + AWS_SECRET_ACCESS_KEY - + "\",\n" - + " \"session_token\": \"" - + AWS_SESSION_TOKEN - + "\"\n" - + " },\n" - + " \"actions\": [\n" - + " {\n" - + " \"action_type\": \"predict\",\n" - + " \"method\": \"POST\",\n" - + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n" - + " \"headers\": {\n" - + " \"content-type\": \"application/json\",\n" - + " \"x-amz-content-sha256\": \"required\"\n" - + " },\n" - + " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n" - + " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n" - + " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n" - + " }\n" - + " ]\n" - + "}"; - + @SneakyThrows @Before public void setup() throws IOException, InterruptedException { RestMLRemoteInferenceIT.disableClusterConnectorAccessControl(); Thread.sleep(20000); - String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); - this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); } @@ -87,10 +38,28 @@ public void test_bedrock_embedding_model() throws Exception { if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { return; } - TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); - MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - ModelTensorOutput output = predictRemoteModel(bedrockEmbeddingModelId, mlInput); - assertEquals(2, output.getMlModelOutputs().size()); - assertEquals(1536, output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length); + String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI())); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictRemoteModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 2, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json new file mode 100644 index 0000000000..9dc1ed69fd --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockConnectorBodies.json @@ -0,0 +1,63 @@ +{ + "without_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1" + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + }, + "with_step_size": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "%s", + "service_name": "bedrock", + "model_name": "amazon.titan-embed-text-v1", + "input_docs_processed_step_size": 1 + }, + "credential": { + "access_key": "%s", + "secret_key": "%s", + "session_token": "%s" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\" }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +}