Skip to content

Commit

Permalink
Merge pull request opendatahub-io#48 from kserve/main
Browse files Browse the repository at this point in the history
[pull] main from kserve:main
  • Loading branch information
openshift-merge-bot[bot] committed Jan 10, 2024
2 parents c3d4fb8 + eaa2fde commit 92a8da8
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 47 deletions.
39 changes: 19 additions & 20 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ public void onHalfClose() {
String vModelId = null;
String requestId = null;
ModelResponse response = null;
ByteBuf responsePayload = null;
try (InterruptingListener cancelListener = newInterruptingListener()) {
if (logHeaders != null) {
logHeaders.addToMDC(headers); // MDC cleared in finally block
Expand Down Expand Up @@ -767,18 +768,20 @@ public void onHalfClose() {
} finally {
if (payloadProcessor != null) {
processPayload(reqMessage.readerIndex(reqReaderIndex),
requestId, resolvedModelId, methodName, headers, null, true);
requestId, resolvedModelId, vModelId, methodName, headers, null);
} else {
releaseReqMessage();
}
reqMessage = null; // ownership released or transferred
}

respReaderIndex = response.data.readerIndex();
respSize = response.data.readableBytes();
call.sendHeaders(response.metadata);
if (payloadProcessor != null) {
responsePayload = response.data.retainedSlice();
}
call.sendMessage(response.data);
// response is released via ReleaseAfterResponse.releaseAll()
// final response refcount is released via ReleaseAfterResponse.releaseAll()
status = OK;
} catch (Exception e) {
status = toStatus(e);
Expand All @@ -795,17 +798,13 @@ public void onHalfClose() {
evictMethodDescriptor(methodName);
}
} finally {
final boolean releaseResponse = status != OK;
if (payloadProcessor != null) {
ByteBuf data = null;
Metadata metadata = null;
if (response != null) {
data = response.data.readerIndex(respReaderIndex);
metadata = response.metadata;
}
processPayload(data, requestId, resolvedModelId, methodName, metadata, status, releaseResponse);
} else if (releaseResponse && response != null) {
response.release();
Metadata metadata = response != null ? response.metadata : null;
processPayload(responsePayload, requestId, resolvedModelId, vModelId, methodName, metadata, status);
}
if (status != OK && response != null) {
// An additional release is required if we call.sendMessage() wasn't sucessful
response.data.release();
}
ReleaseAfterResponse.releaseAll();
clearThreadLocals();
Expand All @@ -820,23 +819,22 @@ public void onHalfClose() {
}

/**
* Invoke PayloadProcessor on the request/response data
* Invoke PayloadProcessor on the request/response data. This method takes ownership
* of the passed-in {@code ByteBuf}.
*
* @param data the binary data
* @param payloadId the id of the request
* @param modelId the id of the model
* @param vModelId the id of the vModel
* @param methodName the name of the invoked method
* @param metadata the method name metadata
* @param status null for requests, non-null for responses
* @param takeOwnership whether the processor should take ownership
*/
private void processPayload(ByteBuf data, String payloadId, String modelId, String methodName,
Metadata metadata, io.grpc.Status status, boolean takeOwnership) {
private void processPayload(ByteBuf data, String payloadId, String modelId, String vModelId, String methodName,
Metadata metadata, io.grpc.Status status) {
Payload payload = null;
try {
assert payloadProcessor != null;
if (!takeOwnership) {
ReferenceCountUtil.retain(data);
}
payload = new Payload(payloadId, modelId, methodName, metadata, data, status);
if (payloadProcessor.process(payload)) {
data = null; // ownership transferred
Expand Down Expand Up @@ -1200,6 +1198,7 @@ public void getVModelStatus(GetVModelStatusRequest request, StreamObserver<VMode
} finally {
clearThreadLocals();
}

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.ibm.watson.modelmesh.payload;

import java.io.IOException;
import java.util.Objects;

/**
* A {@link PayloadProcessor} that processes {@link Payload}s only if they match with given model ID or method name.
Expand All @@ -29,10 +30,13 @@ public class MatchingPayloadProcessor implements PayloadProcessor {

private final String modelId;

MatchingPayloadProcessor(PayloadProcessor delegate, String methodName, String modelId) {
private final String vModelId;

MatchingPayloadProcessor(PayloadProcessor delegate, String methodName, String modelId, String vModelId) {
this.delegate = delegate;
this.methodName = methodName;
this.modelId = modelId;
this.vModelId = vModelId;
}

@Override
Expand All @@ -42,40 +46,49 @@ public String getName() {

@Override
public boolean process(Payload payload) {
boolean processed = false;
boolean methodMatches = true;
if (this.methodName != null) {
methodMatches = payload.getMethod() != null && this.methodName.equals(payload.getMethod());
}
boolean methodMatches = this.methodName == null || Objects.equals(this.methodName, payload.getMethod());
if (methodMatches) {
boolean modelIdMatches = true;
if (this.modelId != null) {
modelIdMatches = this.modelId.equals(payload.getModelId());
}
boolean modelIdMatches = this.modelId == null || this.modelId.equals(payload.getModelId());
if (modelIdMatches) {
processed = delegate.process(payload);
boolean vModelIdMatches = this.vModelId == null || this.vModelId.equals(payload.getVModelId());
if (vModelIdMatches) {
return delegate.process(payload);
}
}
}
return processed;
return false;
}

public static MatchingPayloadProcessor from(String modelId, String method, PayloadProcessor processor) {
return from(modelId, null, method, processor);
}

public static MatchingPayloadProcessor from(String modelId, String vModelId,
String method, PayloadProcessor processor) {
if (modelId != null) {
if (modelId.length() > 0) {
if (!modelId.isEmpty()) {
modelId = modelId.replaceFirst("/", "");
if (modelId.length() == 0 || modelId.equals("*")) {
if (modelId.isEmpty() || modelId.equals("*")) {
modelId = null;
}
} else {
modelId = null;
}
}
if (method != null) {
if (method.length() == 0 || method.equals("*")) {
method = null;
if (vModelId != null) {
if (!vModelId.isEmpty()) {
vModelId = vModelId.replaceFirst("/", "");
if (vModelId.isEmpty() || vModelId.equals("*")) {
vModelId = null;
}
} else {
vModelId = null;
}
}
return new MatchingPayloadProcessor(processor, method, modelId);
if (method != null && (method.isEmpty() || method.equals("*"))) {
method = null;
}
return new MatchingPayloadProcessor(processor, method, modelId, vModelId);
}

@Override
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/com/ibm/watson/modelmesh/payload/Payload.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public enum Kind {

private final String modelId;

private final String vModelId;

private final String method;

private final Metadata metadata;
Expand All @@ -48,10 +50,17 @@ public enum Kind {
// null for requests, non-null for responses
private final Status status;


public Payload(@Nonnull String id, @Nonnull String modelId, @Nullable String method, @Nullable Metadata metadata,
@Nullable ByteBuf data, @Nullable Status status) {
this(id, modelId, null, method, metadata, data, status);
}

public Payload(@Nonnull String id, @Nonnull String modelId, @Nullable String vModelId, @Nullable String method,
@Nullable Metadata metadata, @Nullable ByteBuf data, @Nullable Status status) {
this.id = id;
this.modelId = modelId;
this.vModelId = vModelId;
this.method = method;
this.metadata = metadata;
this.data = data;
Expand All @@ -68,6 +77,16 @@ public String getModelId() {
return modelId;
}

@CheckForNull
public String getVModelId() {
return vModelId;
}

@Nonnull
public String getVModelIdOrModelId() {
return vModelId != null ? vModelId : modelId;
}

@CheckForNull
public String getMethod() {
return method;
Expand Down Expand Up @@ -101,6 +120,7 @@ public void release() {
public String toString() {
return "Payload{" +
"id='" + id + '\'' +
", vModelId=" + (vModelId != null ? ('\'' + vModelId + '\'') : "null") +
", modelId='" + modelId + '\'' +
", method='" + method + '\'' +
", status=" + (status == null ? "request" : String.valueOf(status)) +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,10 @@ public boolean process(Payload payload) {
private static PayloadContent prepareContentBody(Payload payload) {
String id = payload.getId();
String modelId = payload.getModelId();
String vModelId = payload.getVModelId();
String kind = payload.getKind().toString().toLowerCase();
ByteBuf byteBuf = payload.getData();
String data;
if (byteBuf != null) {
data = encodeBinaryToString(byteBuf);
} else {
data = "";
}
String data = byteBuf != null ? encodeBinaryToString(byteBuf) : "";
Metadata metadata = payload.getMetadata();
Map<String, String> metadataMap = new HashMap<>();
if (metadata != null) {
Expand All @@ -79,7 +75,7 @@ private static PayloadContent prepareContentBody(Payload payload) {
}
}
String status = payload.getStatus() != null ? payload.getStatus().getCode().toString() : "";
return new PayloadContent(id, modelId, data, kind, status, metadataMap);
return new PayloadContent(id, modelId, vModelId, data, kind, status, metadataMap);
}

private static String encodeBinaryToString(ByteBuf byteBuf) {
Expand Down Expand Up @@ -116,15 +112,17 @@ private static class PayloadContent {

private final String id;
private final String modelid;
private final String vModelId;
private final String data;
private final String kind;
private final String status;
private final Map<String, String> metadata;

private PayloadContent(String id, String modelid, String data, String kind, String status,
Map<String, String> metadata) {
private PayloadContent(String id, String modelid, String vModelId, String data, String kind,
String status, Map<String, String> metadata) {
this.id = id;
this.modelid = modelid;
this.vModelId = vModelId;
this.data = data;
this.kind = kind;
this.status = status;
Expand All @@ -143,6 +141,10 @@ public String getModelid() {
return modelid;
}

public String getvModelId() {
return vModelId;
}

public String getData() {
return data;
}
Expand All @@ -160,6 +162,7 @@ public String toString() {
return "PayloadContent{" +
"id='" + id + '\'' +
", modelid='" + modelid + '\'' +
", vModelId=" + (vModelId != null ? ('\'' + vModelId + '\'') : "null") +
", data='" + data + '\'' +
", kind='" + kind + '\'' +
", status='" + status + '\'' +
Expand Down

0 comments on commit 92a8da8

Please sign in to comment.