diff --git a/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
index 70e2fe5705..4dd671e8a9 100644
--- a/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
+++ b/frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
@@ -82,14 +82,6 @@ public GRPCJob(
                 Arrays.asList("Host", ConfigManager.getInstance().getHostName());
     }
 
-    private void cancelHandler(ServerCallStreamObserver<PredictionResponse> responseObserver) {
-        if (responseObserver.isCancelled()) {
-            logger.warn(
-                    "grpc client call already cancelled, not able to send this response for requestId: {}",
-                    getPayload().getRequestId());
-        }
-    }
-
     private void logQueueTime() {
         logger.debug(
                 "Waiting time ns: {}, Backend time ns: {}",
@@ -118,82 +110,104 @@ public void response(
         ByteString output = ByteString.copyFrom(body);
         WorkerCommands cmd = this.getCmd();
 
-        switch (cmd) {
-            case PREDICT:
-            case STREAMPREDICT:
-            case STREAMPREDICT2:
-                ServerCallStreamObserver<PredictionResponse> responseObserver =
-                        (ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver;
-                cancelHandler(responseObserver);
-                PredictionResponse reply =
-                        PredictionResponse.newBuilder().setPrediction(output).build();
-                responseObserver.onNext(reply);
-                if (cmd == WorkerCommands.PREDICT
-                        || (cmd == WorkerCommands.STREAMPREDICT
-                                && responseHeaders
-                                        .get(RequestInput.TS_STREAM_NEXT)
-                                        .equals("false"))) {
-                    responseObserver.onCompleted();
-                    logQueueTime();
-                } else if (cmd == WorkerCommands.STREAMPREDICT2
-                        && (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null
-                                || responseHeaders
-                                        .get(RequestInput.TS_STREAM_NEXT)
-                                        .equals("false"))) {
-                    logQueueTime();
-                }
-                break;
-            case DESCRIBE:
-                try {
-                    ArrayList<DescribeModelResponse> respList =
-                            ApiUtils.getModelDescription(
-                                    this.getModelName(), this.getModelVersion());
-                    if (!output.isEmpty() && respList != null && respList.size() == 1) {
-                        respList.get(0).setCustomizedMetadata(body);
+        try {
+            switch (cmd) {
+                case PREDICT:
+                case STREAMPREDICT:
+                case STREAMPREDICT2:
+                    if (isStreamCancelled(
+                            (ServerCallStreamObserver<PredictionResponse>)
+                                    predictionResponseObserver)) {
+                        break;
                     }
-                    String resp = JsonUtils.GSON_PRETTY.toJson(respList);
-                    ManagementResponse mgmtReply =
-                            ManagementResponse.newBuilder().setMsg(resp).build();
-                    managementResponseObserver.onNext(mgmtReply);
-                    managementResponseObserver.onCompleted();
-                } catch (ModelNotFoundException | ModelVersionNotFoundException e) {
-                    ManagementImpl.sendErrorResponse(
-                            managementResponseObserver, Status.NOT_FOUND, e);
-                }
-                break;
-            case OIPPREDICT:
-                Gson gson = new Gson();
-                String jsonResponse = output.toStringUtf8();
-                JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class);
-                if (((ServerCallStreamObserver<ModelInferResponse>) modelInferResponseObserver)
-                        .isCancelled()) {
-                    logger.warn(
-                            "grpc client call already cancelled, not able to send this response for requestId: {}",
-                            getPayload().getRequestId());
-                    return;
-                }
-                ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder();
-                responseBuilder.setId(jsonObject.get("id").getAsString());
-                responseBuilder.setModelName(jsonObject.get("model_name").getAsString());
-                responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString());
-                JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray();
+                    ServerCallStreamObserver<PredictionResponse> responseObserver =
+                            (ServerCallStreamObserver<PredictionResponse>)
+                                    predictionResponseObserver;
+                    PredictionResponse reply =
+                            PredictionResponse.newBuilder().setPrediction(output).build();
+                    responseObserver.onNext(reply);
+                    if (cmd == WorkerCommands.PREDICT
+                            || (cmd == WorkerCommands.STREAMPREDICT
+                                    && responseHeaders
+                                            .get(RequestInput.TS_STREAM_NEXT)
+                                            .equals("false"))) {
+                        responseObserver.onCompleted();
+                        logQueueTime();
+                    } else if (cmd == WorkerCommands.STREAMPREDICT2
+                            && (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null
+                                    || responseHeaders
+                                            .get(RequestInput.TS_STREAM_NEXT)
+                                            .equals("false"))) {
+                        logQueueTime();
+                    }
+                    break;
+                case DESCRIBE:
+                    if (isStreamCancelled(
+                            (ServerCallStreamObserver<ManagementResponse>)
+                                    managementResponseObserver)) {
+                        break;
+                    }
+                    try {
+                        ArrayList<DescribeModelResponse> respList =
+                                ApiUtils.getModelDescription(
+                                        this.getModelName(), this.getModelVersion());
+                        if (!output.isEmpty() && respList != null && respList.size() == 1) {
+                            respList.get(0).setCustomizedMetadata(body);
+                        }
+                        String resp = JsonUtils.GSON_PRETTY.toJson(respList);
+                        ManagementResponse mgmtReply =
+                                ManagementResponse.newBuilder().setMsg(resp).build();
+                        managementResponseObserver.onNext(mgmtReply);
+                        managementResponseObserver.onCompleted();
+                    } catch (ModelNotFoundException | ModelVersionNotFoundException e) {
+                        ManagementImpl.sendErrorResponse(
+                                managementResponseObserver, Status.NOT_FOUND, e);
+                    }
+                    break;
+                case OIPPREDICT:
+                    if (isStreamCancelled(
+                            (ServerCallStreamObserver<ModelInferResponse>)
+                                    modelInferResponseObserver)) {
+                        break;
+                    }
+                    Gson gson = new Gson();
+                    String jsonResponse = output.toStringUtf8();
+                    JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class);
+                    if (((ServerCallStreamObserver<ModelInferResponse>) modelInferResponseObserver)
+                            .isCancelled()) {
+                        logger.warn(
+                                "grpc client call already cancelled, not able to send this response for requestId: {}",
+                                getPayload().getRequestId());
+                        return;
+                    }
+                    ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder();
+                    responseBuilder.setId(jsonObject.get("id").getAsString());
+                    responseBuilder.setModelName(jsonObject.get("model_name").getAsString());
+                    responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString());
+                    JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray();
 
-                for (JsonElement element : jsonOutputs) {
-                    InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder();
-                    outputBuilder.setName(element.getAsJsonObject().get("name").getAsString());
-                    outputBuilder.setDatatype(
-                            element.getAsJsonObject().get("datatype").getAsString());
-                    JsonArray shapeArray = element.getAsJsonObject().get("shape").getAsJsonArray();
-                    shapeArray.forEach(
-                            shapeElement -> outputBuilder.addShape(shapeElement.getAsLong()));
-                    setOutputContents(element, outputBuilder);
-                    responseBuilder.addOutputs(outputBuilder);
-                }
-                modelInferResponseObserver.onNext(responseBuilder.build());
-                modelInferResponseObserver.onCompleted();
-                break;
-            default:
-                break;
+                    for (JsonElement element : jsonOutputs) {
+                        InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder();
+                        outputBuilder.setName(element.getAsJsonObject().get("name").getAsString());
+                        outputBuilder.setDatatype(
+                                element.getAsJsonObject().get("datatype").getAsString());
+                        JsonArray shapeArray =
+                                element.getAsJsonObject().get("shape").getAsJsonArray();
+                        shapeArray.forEach(
+                                shapeElement -> outputBuilder.addShape(shapeElement.getAsLong()));
+                        setOutputContents(element, outputBuilder);
+                        responseBuilder.addOutputs(outputBuilder);
+                    }
+                    modelInferResponseObserver.onNext(responseBuilder.build());
+                    modelInferResponseObserver.onCompleted();
+                    break;
+                default:
+                    break;
+            }
+        } catch (IllegalStateException e) {
+            logger.error(
+                    "grpc stream was terminated, not able to send response for requestId: {}",
+                    getPayload().getRequestId());
         }
     }
 
@@ -202,66 +216,97 @@ public void sendError(int status, String error) {
         Status responseStatus = GRPCUtils.getGRPCStatusCode(status);
         WorkerCommands cmd = this.getCmd();
 
-        switch (cmd) {
-            case PREDICT:
-            case STREAMPREDICT:
-            case STREAMPREDICT2:
-                ServerCallStreamObserver<PredictionResponse> responseObserver =
-                        (ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver;
-                cancelHandler(responseObserver);
-                if (cmd == WorkerCommands.PREDICT || cmd == WorkerCommands.STREAMPREDICT) {
-                    responseObserver.onError(
+        try {
+            switch (cmd) {
+                case PREDICT:
+                case STREAMPREDICT:
+                case STREAMPREDICT2:
+                    if (isStreamCancelled(
+                            (ServerCallStreamObserver<PredictionResponse>)
+                                    predictionResponseObserver)) {
+                        break;
+                    }
+                    ServerCallStreamObserver<PredictionResponse> responseObserver =
+                            (ServerCallStreamObserver<PredictionResponse>)
+                                    predictionResponseObserver;
+                    if (cmd == WorkerCommands.PREDICT || cmd == WorkerCommands.STREAMPREDICT) {
+                        responseObserver.onError(
+                                responseStatus
+                                        .withDescription(error)
+                                        .augmentDescription(
+                                                "org.pytorch.serve.http.InternalServerException")
+                                        .asRuntimeException());
+                    } else if (cmd == WorkerCommands.STREAMPREDICT2) {
+                        com.google.rpc.Status rpcStatus =
+                                com.google.rpc.Status.newBuilder()
+                                        .setCode(responseStatus.getCode().value())
+                                        .setMessage(error)
+                                        .addDetails(
+                                                Any.pack(
+                                                        ErrorInfo.newBuilder()
+                                                                .setReason(
+                                                                        "org.pytorch.serve.http.InternalServerException")
+                                                                .build()))
+                                        .build();
+                        responseObserver.onNext(
+                                PredictionResponse.newBuilder()
+                                        .setPrediction(null)
+                                        .setStatus(rpcStatus)
+                                        .build());
+                    }
+                    break;
+                case DESCRIBE:
+                    if (isStreamCancelled(
+                            (ServerCallStreamObserver<ManagementResponse>)
+                                    managementResponseObserver)) {
+                        break;
+                    }
+                    managementResponseObserver.onError(
                             responseStatus
                                     .withDescription(error)
                                     .augmentDescription(
                                             "org.pytorch.serve.http.InternalServerException")
                                     .asRuntimeException());
-                } else if (cmd == WorkerCommands.STREAMPREDICT2) {
-                    com.google.rpc.Status rpcStatus =
-                            com.google.rpc.Status.newBuilder()
-                                    .setCode(responseStatus.getCode().value())
-                                    .setMessage(error)
-                                    .addDetails(
-                                            Any.pack(
-                                                    ErrorInfo.newBuilder()
-                                                            .setReason(
-                                                                    "org.pytorch.serve.http.InternalServerException")
-                                                            .build()))
-                                    .build();
-                    responseObserver.onNext(
-                            PredictionResponse.newBuilder()
-                                    .setPrediction(null)
-                                    .setStatus(rpcStatus)
-                                    .build());
-                }
-                break;
-            case DESCRIBE:
-                managementResponseObserver.onError(
-                        responseStatus
-                                .withDescription(error)
-                                .augmentDescription(
-                                        "org.pytorch.serve.http.InternalServerException")
-                                .asRuntimeException());
-                break;
-            case OIPPREDICT:
-                modelInferResponseObserver.onError(
-                        responseStatus
-                                .withDescription(error)
-                                .augmentDescription(
-                                        "org.pytorch.serve.http.InternalServerException")
-                                .asRuntimeException());
-                break;
-            default:
-                break;
+                    break;
+                case OIPPREDICT:
+                    if (isStreamCancelled(
+                            (ServerCallStreamObserver<ModelInferResponse>)
+                                    modelInferResponseObserver)) {
+                        break;
+                    }
+                    modelInferResponseObserver.onError(
+                            responseStatus
+                                    .withDescription(error)
+                                    .augmentDescription(
+                                            "org.pytorch.serve.http.InternalServerException")
+                                    .asRuntimeException());
+                    break;
+                default:
+                    break;
+            }
+        } catch (IllegalStateException e) {
+            logger.error(
+                    "grpc stream was terminated, not able to send response for requestId: {}",
+                    getPayload().getRequestId());
         }
     }
 
     @Override
     public boolean isOpen() {
-        return ((ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver)
+        return !((ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver)
                 .isCancelled();
     }
 
+    private boolean isStreamCancelled(ServerCallStreamObserver<?> responseObserver) {
+        if (responseObserver.isCancelled()) {
+            logger.warn(
+                    "grpc client call already cancelled, not able to send response for requestId: {}",
+                    getPayload().getRequestId());
+            return true;
+        }
+        return false;
+    }
+
     private void setOutputContents(JsonElement element, InferOutputTensor.Builder outputBuilder) {
         String dataType = element.getAsJsonObject().get("datatype").getAsString();
         JsonArray jsonData = element.getAsJsonObject().get("data").getAsJsonArray();