Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Client-side input shape/element validation #742

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions src/python/library/tritonclient/grpc/_infer_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -30,7 +30,7 @@
from tritonclient.grpc import service_pb2
from tritonclient.utils import *

from ._utils import raise_error
from ._utils import num_elements, raise_error
rmccorm4 marked this conversation as resolved.
Show resolved Hide resolved


class InferInput:
Expand All @@ -54,6 +54,7 @@ def __init__(self, name, shape, datatype):
self._input.ClearField("shape")
self._input.shape.extend(shape)
self._input.datatype = datatype
self._data_shape = None
self._raw_content = None

def name(self):
Expand Down Expand Up @@ -86,6 +87,36 @@ def shape(self):
"""
return self._input.shape

def validate_data(self):
"""Validate input has data and input shape matches input data.

Returns
-------
None
"""
# Input must set only one of the following fields: '_raw_content',
# 'shared_memory_region' in '_input.parameters'
cnt = 0
cnt += self._raw_content != None
cnt += "shared_memory_region" in self._input.parameters
if cnt != 1:
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we return an error when more that one fields are specified in the inputs?

Copy link
Contributor Author

@yinggeh yinggeh Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error was handled by the server.


# Skip due to trt reformat free tensor
if "shared_memory_region" in self._input.parameters:
return

# Not using shared memory
expected_num_elements = num_elements(self._input.shape)
data_num_elements = num_elements(self._data_shape)
if expected_num_elements != data_num_elements:
raise_error(
"input '{}' got unexpected elements count {}, expected {}".format(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also include the respective shapes in the error message as well?

something like:

input 'XYZ' got unexpected elements count 8 (shape: 8,1), expected 16 (shape: 16,1)

I think you are trying to keep supporting the case where a user might just want to call set_shape() with a different shape but same underlying data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are shapes (8,2), (4,4) also valid?

self._input.name, data_num_elements, expected_num_elements
)
)
return

def set_shape(self, shape):
"""Set the shape of input.

Expand Down Expand Up @@ -171,6 +202,7 @@ def set_data_from_numpy(self, input_tensor):
self._raw_content = b""
else:
self._raw_content = input_tensor.tobytes()
self._data_shape = input_tensor.shape
return self

def set_shared_memory(self, region_name, byte_size, offset=0):
Expand All @@ -193,6 +225,7 @@ def set_shared_memory(self, region_name, byte_size, offset=0):
"""
self._input.ClearField("contents")
self._raw_content = None
self._data_shape = None

self._input.parameters["shared_memory_region"].string_param = region_name
self._input.parameters["shared_memory_byte_size"].int64_param = byte_size
Expand Down
3 changes: 2 additions & 1 deletion src/python/library/tritonclient/grpc/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -96,6 +96,7 @@ def _get_inference_request(
if request_id != "":
request.id = request_id
for infer_input in inputs:
infer_input.validate_data()
request.inputs.extend([infer_input._get_tensor()])
if infer_input._get_content() is not None:
request.raw_input_contents.extend([infer_input._get_content()])
Expand Down
37 changes: 36 additions & 1 deletion src/python/library/tritonclient/http/_infer_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -28,6 +28,7 @@
import numpy as np
from tritonclient.utils import (
np_to_triton_dtype,
num_elements,
raise_error,
serialize_bf16_tensor,
serialize_byte_tensor,
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(self, name, shape, datatype):
self._datatype = datatype
self._parameters = {}
self._data = None
self._data_shape = None
self._raw_data = None

def name(self):
Expand Down Expand Up @@ -87,6 +89,37 @@ def shape(self):
"""
return self._shape

def validate_data(self):
"""Validate input has data and input shape matches input data.

Returns
-------
None
"""
# Input must set only one of the following fields: 'data', 'binary_data_size',
# 'shared_memory_region' in 'parameters'
cnt = 0
cnt += self._data != None
cnt += "binary_data_size" in self._parameters
cnt += "shared_memory_region" in self._parameters
if cnt != 1:
return

# Skip due to trt reformat free tensor
if "shared_memory_region" in self._parameters:
return

# Not using shared memory
expected_num_elements = num_elements(self._shape)
data_num_elements = num_elements(self._data_shape)
if expected_num_elements != data_num_elements:
raise_error(
"input '{}' got unexpected elements count {}, expected {}".format(
self._name, data_num_elements, expected_num_elements
)
)
return

def set_shape(self, shape):
"""Set the shape of input.

Expand Down Expand Up @@ -211,6 +244,7 @@ def set_data_from_numpy(self, input_tensor, binary_data=True):
else:
self._raw_data = input_tensor.tobytes()
self._parameters["binary_data_size"] = len(self._raw_data)
self._data_shape = input_tensor.shape
return self

def set_shared_memory(self, region_name, byte_size, offset=0):
Expand All @@ -232,6 +266,7 @@ def set_shared_memory(self, region_name, byte_size, offset=0):
The updated input
"""
self._data = None
self._data_shape = None
self._raw_data = None
self._parameters.pop("binary_data_size", None)

Expand Down
8 changes: 6 additions & 2 deletions src/python/library/tritonclient/http/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -106,7 +106,11 @@ def _get_inference_request(
if timeout is not None:
parameters["timeout"] = timeout

infer_request["inputs"] = [this_input._get_tensor() for this_input in inputs]
infer_request["inputs"] = []
for infer_input in inputs:
infer_input.validate_data()
infer_request["inputs"].append(infer_input._get_tensor())

if outputs:
infer_request["outputs"] = [
this_output._get_tensor() for this_output in outputs
Expand Down
23 changes: 22 additions & 1 deletion src/python/library/tritonclient/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -40,6 +40,27 @@ def raise_error(msg):
raise InferenceServerException(msg=msg) from None


def num_elements(shape):
"""
Calculate the number of elements in an array given its shape.

Parameters
----------
shape : list or tuple
Shape of the array.

Returns
-------
int
Number of elements in the array.
"""

num_elements = 1
for dim in shape:
num_elements *= dim
return num_elements


def serialized_byte_size(tensor_value):
"""
Get the underlying number of bytes for a numpy ndarray.
Expand Down
Loading