Skip to content

Commit

Permalink
Add UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jan 30, 2024
1 parent 48418e7 commit 40b7da4
Show file tree
Hide file tree
Showing 5 changed files with 433 additions and 340 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

package org.opensearch.ml.engine.algorithms.remote;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand All @@ -33,7 +35,9 @@

@Log4j2
public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler {
@Getter
private Integer statusCode;
@Getter
private final StringBuilder responseBody = new StringBuilder();

private WrappedCountDownLatch countDownLatch;
Expand Down Expand Up @@ -75,13 +79,52 @@ public void onHeaders(SdkHttpResponse response) {

@Override
public void onStream(Publisher<ByteBuffer> stream) {
stream.subscribe(new Subscriber<>() {
stream.subscribe(new MLResponseSubscriber());
}

@Override
public void onError(Throwable error) {
log.error(error.getMessage(), error);
actionListener.onFailure(new OpenSearchStatusException("Error on communication with remote model: " + error.getMessage(), RestStatus.INTERNAL_SERVER_ERROR));
}

private void processResponse(Integer statusCode, String body, Map<String, String> parameters, Map<Integer, ModelTensors> tensorOutputs)
throws IOException {
if (Strings.isBlank(body)) {
log.error("Remote model response body is empty!");
throw new OpenSearchStatusException("Remote model response is empty", RestStatus.fromCode(statusCode));
} else {
if (statusCode < 200 || statusCode > 300) {
log.error("Remote server returned error code: {}", statusCode);
throw new OpenSearchStatusException("Remote server returned error code: " + statusCode, RestStatus.fromCode(statusCode));
} else {
ModelTensors tensors = processOutput(body, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.put(countDownLatch.getSequence(), tensors);
}
}
}

private List<ModelTensors> reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
List<ModelTensors> modelTensors = new ArrayList<>();
TreeMap<Integer, ModelTensors> sortedMap = new TreeMap<>(tensorOutputs);
log.info("Reordered tensor outputs size is {}", sortedMap.size());
for (Map.Entry<Integer, ModelTensors> entry : sortedMap.entrySet()) {
modelTensors.add(entry.getKey(), entry.getValue());
}
return modelTensors;
}

protected class MLResponseSubscriber implements Subscriber<ByteBuffer> {
private Subscription subscription;
@Override public void onSubscribe(Subscription s) {
@Override
public void onSubscribe(Subscription s) {
this.subscription = s;
s.request(Long.MAX_VALUE);
}
@Override public void onNext(ByteBuffer byteBuffer) {

@Override
public void onNext(ByteBuffer byteBuffer) {
responseBody.append(StandardCharsets.UTF_8.decode(byteBuffer));
subscription.request(Long.MAX_VALUE);
}
Expand Down Expand Up @@ -117,41 +160,5 @@ public void onComplete() {
}
}
}
});
}

@Override
public void onError(Throwable error) {
log.error(error.getMessage(), error);
actionListener.onFailure(new OpenSearchStatusException("Error on communication with remote model: " + error.getMessage(), RestStatus.INTERNAL_SERVER_ERROR));
}

private void processResponse(Integer statusCode, String body, Map<String, String> parameters, Map<Integer, ModelTensors> tensorOutputs)
throws IOException {
if (body == null) {
log.error("Remote model response body is empty!");
throw new OpenSearchStatusException("Remote model response is empty", RestStatus.fromCode(statusCode));
} else {
if (statusCode < 200 || statusCode > 300) {
log.error("Remote server returned error code: {}", statusCode);
throw new OpenSearchStatusException("Remote server returned error code: " + statusCode, RestStatus.fromCode(statusCode));
} else {
ModelTensors tensors = processOutput(body, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
log.info("################## Starting to put tensors into tensorOutputs, sequence is {} ", countDownLatch.getSequence());
tensorOutputs.put(countDownLatch.getSequence(), tensors);
log.info("################## End to put tensors into tensorOutputs, sequence is {} ", countDownLatch.getSequence());
}
}
}

private List<ModelTensors> reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
List<ModelTensors> modelTensors = new ArrayList<>();
TreeMap<Integer, ModelTensors> sortedMap = new TreeMap<>(tensorOutputs);
log.info("Reordered tensor outputs size is {}", sortedMap.size());
for (Map.Entry<Integer, ModelTensors> entry : sortedMap.entrySet()) {
modelTensors.add(entry.getKey(), entry.getValue());
}
return modelTensors;
}
}
Loading

0 comments on commit 40b7da4

Please sign in to comment.