Skip to content

Commit

Permalink
add ITs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jun 5, 2024
1 parent caabd56 commit 7fdbffc
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -898,10 +898,12 @@ public Map predictTextEmbedding(String modelId) throws IOException {
return result;
}

public ModelTensorOutput predictRemoteModel(String modelId, MLInput input) throws IOException {
public Map predictRemoteModel(String modelId, MLInput input) throws IOException {
String requestBody = TestHelper.toJsonString(input);
System.out.println("############################## request body is:" + requestBody);
Response response = TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_predict", null, TestHelper.toJsonString(input), null);
return new ModelTensorOutput(StreamInput.wrap(response.getEntity().getContent().readAllBytes()));
.makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null);
return parseResponseToMap(response);
}

public Consumer<Map<String, Object>> verifyTextEmbeddingModelDeployed() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,80 +5,31 @@

package org.opensearch.ml.rest;

import com.google.common.collect.ImmutableList;
import com.jayway.jsonpath.JsonPath;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Assert;
import lombok.SneakyThrows;
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.utils.TestHelper;
import org.w3c.dom.Text;
import org.opensearch.ml.common.utils.StringUtils;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.ml.utils.TestHelper.makeRequest;

public class RestBedRockInferenceIT extends MLCommonsRestTestCase {
private String bedrockEmbeddingModelId;
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
private static final String GITHUB_CI_AWS_REGION = "us-west-2";

private final String bedrockEmbeddingModelConnectorEntity = "{\n"
+ " \"name\": \"Amazon Bedrock Connector: embedding\",\n"
+ " \"description\": \"The connector to bedrock Titan embedding model\",\n"
+ " \"version\": 1,\n"
+ " \"protocol\": \"aws_sigv4\",\n"
+ " \"parameters\": {\n"
+ " \"region\": \""
+ GITHUB_CI_AWS_REGION
+ "\",\n"
+ " \"service_name\": \"bedrock\",\n"
+ " \"model_name\": \"amazon.titan-embed-text-v1\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"access_key\": \""
+ AWS_ACCESS_KEY_ID
+ "\",\n"
+ " \"secret_key\": \""
+ AWS_SECRET_ACCESS_KEY
+ "\",\n"
+ " \"session_token\": \""
+ AWS_SESSION_TOKEN
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {\n"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n"
+ " \"headers\": {\n"
+ " \"content-type\": \"application/json\",\n"
+ " \"x-amz-content-sha256\": \"required\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n"
+ " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n"
+ " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n"
+ " }\n"
+ " ]\n"
+ "}";

@SneakyThrows
@Before
public void setup() throws IOException, InterruptedException {
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
Thread.sleep(20000);
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true);
}


Expand All @@ -87,10 +38,28 @@ public void test_bedrock_embedding_model() throws Exception {
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
return;
}
TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
ModelTensorOutput output = predictRemoteModel(bedrockEmbeddingModelId, mlInput);
assertEquals(2, output.getMlModelOutputs().size());
assertEquals(1536, output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData().length);
String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI()));
Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
for (Map.Entry<String, Object> templateEntry : templateMap.entrySet()) {
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
String testCaseName = templateEntry.getKey();
String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName);
String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true);

TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build();
MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
Map inferenceResult = predictRemoteModel(modelId, mlInput);
assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
List output = (List) inferenceResult.get("inference_results");
assertEquals(errorMsg, 2, output.size());
assertTrue(errorMsg, output.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
assertEquals(errorMsg, 1, outputList.size());
assertTrue(errorMsg, outputList.get(0) instanceof Map);
assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
assertEquals(errorMsg, 1536, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{
"without_step_size": {
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v1"
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "connector.post_process.bedrock.embedding"
}
]
},
"with_step_size": {
"name": "Amazon Bedrock Connector: embedding",
"description": "The connector to bedrock Titan embedding model",
"version": 1,
"protocol": "aws_sigv4",
"parameters": {
"region": "%s",
"service_name": "bedrock",
"model_name": "amazon.titan-embed-text-v1",
"input_docs_processed_step_size": 1
},
"credential": {
"access_key": "%s",
"secret_key": "%s",
"session_token": "%s"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
"headers": {
"content-type": "application/json",
"x-amz-content-sha256": "required"
},
"request_body": "{ \"inputText\": \"${parameters.inputText}\" }",
"pre_process_function": "connector.pre_process.bedrock.embedding",
"post_process_function": "connector.post_process.bedrock.embedding"
}
]
}
}

0 comments on commit 7fdbffc

Please sign in to comment.