Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exception handling in server for when gRPC client closes stream #3272

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 174 additions & 129 deletions frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,6 @@ public GRPCJob(
Arrays.asList("Host", ConfigManager.getInstance().getHostName());
}

/**
 * Emits a warning when the gRPC client has already cancelled the prediction call.
 *
 * <p>This helper only logs: it returns nothing, so the caller is not told whether the
 * call was cancelled and must decide on its own whether to keep writing to the observer.
 *
 * @param responseObserver server-side observer of the prediction call to inspect
 */
private void cancelHandler(ServerCallStreamObserver<PredictionResponse> responseObserver) {
    // Nothing to report while the call is still live.
    if (!responseObserver.isCancelled()) {
        return;
    }
    logger.warn(
            "grpc client call already cancelled, not able to send this response for requestId: {}",
            getPayload().getRequestId());
}

private void logQueueTime() {
logger.debug(
"Waiting time ns: {}, Backend time ns: {}",
Expand Down Expand Up @@ -118,82 +110,104 @@ public void response(
ByteString output = ByteString.copyFrom(body);
WorkerCommands cmd = this.getCmd();

switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case STREAMPREDICT2:
ServerCallStreamObserver<PredictionResponse> responseObserver =
(ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver;
cancelHandler(responseObserver);
PredictionResponse reply =
PredictionResponse.newBuilder().setPrediction(output).build();
responseObserver.onNext(reply);
if (cmd == WorkerCommands.PREDICT
|| (cmd == WorkerCommands.STREAMPREDICT
&& responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
responseObserver.onCompleted();
logQueueTime();
} else if (cmd == WorkerCommands.STREAMPREDICT2
&& (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null
|| responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
logQueueTime();
}
break;
case DESCRIBE:
try {
ArrayList<DescribeModelResponse> respList =
ApiUtils.getModelDescription(
this.getModelName(), this.getModelVersion());
if (!output.isEmpty() && respList != null && respList.size() == 1) {
respList.get(0).setCustomizedMetadata(body);
try {
switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case STREAMPREDICT2:
if (isStreamCancelled(
(ServerCallStreamObserver<PredictionResponse>)
predictionResponseObserver)) {
break;
}
String resp = JsonUtils.GSON_PRETTY.toJson(respList);
ManagementResponse mgmtReply =
ManagementResponse.newBuilder().setMsg(resp).build();
managementResponseObserver.onNext(mgmtReply);
managementResponseObserver.onCompleted();
} catch (ModelNotFoundException | ModelVersionNotFoundException e) {
ManagementImpl.sendErrorResponse(
managementResponseObserver, Status.NOT_FOUND, e);
}
break;
case OIPPREDICT:
Gson gson = new Gson();
String jsonResponse = output.toStringUtf8();
JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class);
if (((ServerCallStreamObserver<ModelInferResponse>) modelInferResponseObserver)
.isCancelled()) {
logger.warn(
"grpc client call already cancelled, not able to send this response for requestId: {}",
getPayload().getRequestId());
return;
}
ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder();
responseBuilder.setId(jsonObject.get("id").getAsString());
responseBuilder.setModelName(jsonObject.get("model_name").getAsString());
responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString());
JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray();
ServerCallStreamObserver<PredictionResponse> responseObserver =
(ServerCallStreamObserver<PredictionResponse>)
predictionResponseObserver;
PredictionResponse reply =
PredictionResponse.newBuilder().setPrediction(output).build();
responseObserver.onNext(reply);
if (cmd == WorkerCommands.PREDICT
|| (cmd == WorkerCommands.STREAMPREDICT
&& responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
responseObserver.onCompleted();
logQueueTime();
} else if (cmd == WorkerCommands.STREAMPREDICT2
&& (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null
|| responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
logQueueTime();
}
break;
case DESCRIBE:
if (isStreamCancelled(
(ServerCallStreamObserver<ManagementResponse>)
managementResponseObserver)) {
break;
}
try {
ArrayList<DescribeModelResponse> respList =
ApiUtils.getModelDescription(
this.getModelName(), this.getModelVersion());
if (!output.isEmpty() && respList != null && respList.size() == 1) {
respList.get(0).setCustomizedMetadata(body);
}
String resp = JsonUtils.GSON_PRETTY.toJson(respList);
ManagementResponse mgmtReply =
ManagementResponse.newBuilder().setMsg(resp).build();
managementResponseObserver.onNext(mgmtReply);
managementResponseObserver.onCompleted();
} catch (ModelNotFoundException | ModelVersionNotFoundException e) {
ManagementImpl.sendErrorResponse(
managementResponseObserver, Status.NOT_FOUND, e);
}
break;
case OIPPREDICT:
if (isStreamCancelled(
(ServerCallStreamObserver<ModelInferResponse>)
modelInferResponseObserver)) {
break;
}
Gson gson = new Gson();
String jsonResponse = output.toStringUtf8();
JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class);
if (((ServerCallStreamObserver<ModelInferResponse>) modelInferResponseObserver)
.isCancelled()) {
logger.warn(
"grpc client call already cancelled, not able to send this response for requestId: {}",
getPayload().getRequestId());
return;
}
ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder();
responseBuilder.setId(jsonObject.get("id").getAsString());
responseBuilder.setModelName(jsonObject.get("model_name").getAsString());
responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString());
JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray();

for (JsonElement element : jsonOutputs) {
InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder();
outputBuilder.setName(element.getAsJsonObject().get("name").getAsString());
outputBuilder.setDatatype(
element.getAsJsonObject().get("datatype").getAsString());
JsonArray shapeArray = element.getAsJsonObject().get("shape").getAsJsonArray();
shapeArray.forEach(
shapeElement -> outputBuilder.addShape(shapeElement.getAsLong()));
setOutputContents(element, outputBuilder);
responseBuilder.addOutputs(outputBuilder);
}
modelInferResponseObserver.onNext(responseBuilder.build());
modelInferResponseObserver.onCompleted();
break;
default:
break;
for (JsonElement element : jsonOutputs) {
InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder();
outputBuilder.setName(element.getAsJsonObject().get("name").getAsString());
outputBuilder.setDatatype(
element.getAsJsonObject().get("datatype").getAsString());
JsonArray shapeArray =
element.getAsJsonObject().get("shape").getAsJsonArray();
shapeArray.forEach(
shapeElement -> outputBuilder.addShape(shapeElement.getAsLong()));
setOutputContents(element, outputBuilder);
responseBuilder.addOutputs(outputBuilder);
}
modelInferResponseObserver.onNext(responseBuilder.build());
modelInferResponseObserver.onCompleted();
break;
default:
break;
}
} catch (IllegalStateException e) {
logger.error(
"grpc stream was terminated, not able to send response for requestId: {}",
getPayload().getRequestId());
}
}

Expand All @@ -202,66 +216,97 @@ public void sendError(int status, String error) {
Status responseStatus = GRPCUtils.getGRPCStatusCode(status);
WorkerCommands cmd = this.getCmd();

switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case STREAMPREDICT2:
ServerCallStreamObserver<PredictionResponse> responseObserver =
(ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver;
cancelHandler(responseObserver);
if (cmd == WorkerCommands.PREDICT || cmd == WorkerCommands.STREAMPREDICT) {
responseObserver.onError(
try {
switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case STREAMPREDICT2:
if (isStreamCancelled(
(ServerCallStreamObserver<PredictionResponse>)
predictionResponseObserver)) {
break;
}
ServerCallStreamObserver<PredictionResponse> responseObserver =
(ServerCallStreamObserver<PredictionResponse>)
predictionResponseObserver;
if (cmd == WorkerCommands.PREDICT || cmd == WorkerCommands.STREAMPREDICT) {
responseObserver.onError(
responseStatus
.withDescription(error)
.augmentDescription(
"org.pytorch.serve.http.InternalServerException")
.asRuntimeException());
} else if (cmd == WorkerCommands.STREAMPREDICT2) {
com.google.rpc.Status rpcStatus =
com.google.rpc.Status.newBuilder()
.setCode(responseStatus.getCode().value())
.setMessage(error)
.addDetails(
Any.pack(
ErrorInfo.newBuilder()
.setReason(
"org.pytorch.serve.http.InternalServerException")
.build()))
.build();
responseObserver.onNext(
PredictionResponse.newBuilder()
.setPrediction(null)
.setStatus(rpcStatus)
.build());
}
break;
case DESCRIBE:
if (isStreamCancelled(
(ServerCallStreamObserver<ManagementResponse>)
managementResponseObserver)) {
break;
}
managementResponseObserver.onError(
responseStatus
.withDescription(error)
.augmentDescription(
"org.pytorch.serve.http.InternalServerException")
.asRuntimeException());
} else if (cmd == WorkerCommands.STREAMPREDICT2) {
com.google.rpc.Status rpcStatus =
com.google.rpc.Status.newBuilder()
.setCode(responseStatus.getCode().value())
.setMessage(error)
.addDetails(
Any.pack(
ErrorInfo.newBuilder()
.setReason(
"org.pytorch.serve.http.InternalServerException")
.build()))
.build();
responseObserver.onNext(
PredictionResponse.newBuilder()
.setPrediction(null)
.setStatus(rpcStatus)
.build());
}
break;
case DESCRIBE:
managementResponseObserver.onError(
responseStatus
.withDescription(error)
.augmentDescription(
"org.pytorch.serve.http.InternalServerException")
.asRuntimeException());
break;
case OIPPREDICT:
modelInferResponseObserver.onError(
responseStatus
.withDescription(error)
.augmentDescription(
"org.pytorch.serve.http.InternalServerException")
.asRuntimeException());
break;
default:
break;
break;
case OIPPREDICT:
if (isStreamCancelled(
(ServerCallStreamObserver<ModelInferResponse>)
modelInferResponseObserver)) {
break;
}
modelInferResponseObserver.onError(
responseStatus
.withDescription(error)
.augmentDescription(
"org.pytorch.serve.http.InternalServerException")
.asRuntimeException());
break;
default:
break;
}
} catch (IllegalStateException e) {
logger.error(
"grpc stream was terminated, not able to send response for requestId: {}",
getPayload().getRequestId());
}
}

@Override
public boolean isOpen() {
return ((ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver)
return !((ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver)
.isCancelled();
}

/**
 * Reports whether the client has already cancelled the given gRPC call, logging a
 * warning with the current request id when it has.
 *
 * @param responseObserver server-side observer of the call to inspect
 * @return {@code true} if the observer reports the call as cancelled, {@code false} otherwise
 */
private boolean isStreamCancelled(ServerCallStreamObserver<?> responseObserver) {
    // Query the observer once and reuse the answer for both the log decision and the result.
    final boolean cancelled = responseObserver.isCancelled();
    if (cancelled) {
        logger.warn(
                "grpc client call already cancelled, not able to send response for requestId: {}",
                getPayload().getRequestId());
    }
    return cancelled;
}

private void setOutputContents(JsonElement element, InferOutputTensor.Builder outputBuilder) {
String dataType = element.getAsJsonObject().get("datatype").getAsString();
JsonArray jsonData = element.getAsJsonObject().get("data").getAsJsonArray();
Expand Down
Loading