Skip to content

Commit

Permalink
Add UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jan 30, 2024
1 parent 9b04a11 commit cf4409e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 317 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
throw exception;
actionListener.onFailure(exception);
} catch (Throwable e) {
log.error("Failed to execute predict in aws connector", e);
throw new MLException("Fail to execute predict in aws connector", e);
actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,6 @@ public static SdkHttpFullRequest buildSdkRequest(Connector connector, Map<String
} else {
requestBody = RequestBody.empty();
}
if (requestBody.optionalContentLength().isEmpty()) {
log.error("Content length is empty. Aborting request to remote model");
actionListener.onFailure(new IllegalArgumentException("Content length is empty. Aborting request to remote model"));
}
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(method)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,13 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
SdkHttpFullRequest request;
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
case "POST":
try {
log.debug("original payload to remote model: " + payload);
request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST, actionListener);
MLHttpClientFactory.validateIp(request.getUri().getHost());
} catch (Exception e) {
throw new MLException("Failed to create http request for remote model", e);
}
log.debug("original payload to remote model: " + payload);
request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST, actionListener);
MLHttpClientFactory.validateIp(request.getUri().getHost());
break;
case "GET":
try {
request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET, actionListener);
MLHttpClientFactory.validateIp(request.getUri().getHost());
} catch (Exception e) {
throw new MLException("Failed to create http request for remote model", e);
}
request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET, actionListener);
MLHttpClientFactory.validateIp(request.getUri().getHost());
break;
default:
throw new IllegalArgumentException("unsupported http method");
Expand All @@ -96,10 +88,10 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
} catch (RuntimeException e) {
log.error("Fail to execute http connector", e);
throw e;
actionListener.onFailure(e);
} catch (Throwable e) {
log.error("Fail to execute http connector", e);
throw new MLException("Fail to execute http connector", e);
actionListener.onFailure(new MLException("Fail to execute http connector", e));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ public MLOutput predict(MLInput mlInput, MLModel model) {
@Override
public void predict(MLInput mlInput, MLTask mlTask, ActionListener<MLTaskResponse> actionListener) {
if (!isModelReady()) {
throw new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy");
actionListener.onFailure(new IllegalArgumentException("Model not ready yet. Please run this first: POST /_plugins/_ml/models/<model_id>/_deploy"));
}
try {
connectorExecutor.executePredict(mlInput, actionListener);
} catch (RuntimeException e) {
log.error("Failed to call remote model.", e);
throw e;
actionListener.onFailure(e);
} catch (Throwable e) {
log.error("Failed to call remote model.", e);
throw new MLException(e);
actionListener.onFailure(new MLException(e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
Expand All @@ -19,6 +21,7 @@
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
Expand Down Expand Up @@ -94,8 +97,6 @@ public void executePredict_RemoteInferenceInput_MissingCredential() {

@Test
public void executePredict_RemoteInferenceInput_invalidIp() {
exceptionRule.expect(MLException.class);

ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
Expand Down Expand Up @@ -125,11 +126,11 @@ public void executePredict_RemoteInferenceInput_invalidIp() {

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
Mockito.verify(actionListener, times(1)).onFailure(any(MLException.class));
}

@Test
public void executePredict_RemoteInferenceInput_illegalIpAddress() {
exceptionRule.expect(IllegalArgumentException.class);
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
Expand Down
Loading

0 comments on commit cf4409e

Please sign in to comment.