Fix failing unit tests (#439)
drobison00 authored Feb 12, 2025
1 parent 4affc6d commit 0560e81
Showing 6 changed files with 444 additions and 882 deletions.
5 changes: 4 additions & 1 deletion src/nv_ingest/util/nim/paddle.py
@@ -157,7 +157,10 @@ def chunk_list(lst, chunk_size):

elif protocol == "http":
logger.debug("Formatting input for HTTP PaddleOCR model (batched).")
base64_list = [self.image_array_to_base64(img) for img in images]
if "base64_images" in data:
base64_list = data["base64_images"]
else:
base64_list = [data["base64_image"]]

if self._is_version_early_access_legacy_api():
content_list: List[Dict[str, Any]] = []
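The paddle.py fix above stops re-encoding image arrays and instead reads base64 strings already present in the prepared data, accepting either a batched "base64_images" list or a single "base64_image". A minimal sketch of that key-selection logic; the helper name build_base64_list is hypothetical and only illustrates the branch shown in the diff:

from typing import Any, Dict, List


def build_base64_list(data: Dict[str, Any]) -> List[str]:
    # Prefer a pre-batched list of base64 strings; otherwise wrap the single
    # base64 image in a one-element list so downstream chunking still works.
    if "base64_images" in data:
        return data["base64_images"]
    return [data["base64_image"]]


# Both prepared-data shapes yield a list the HTTP branch can chunk.
assert build_base64_list({"base64_images": ["aaa", "bbb"]}) == ["aaa", "bbb"]
assert build_base64_list({"base64_image": "ccc"}) == ["ccc"]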
146 changes: 93 additions & 53 deletions tests/nv_ingest/util/nim/test_cached.py
@@ -91,80 +91,120 @@ def test_prepare_data_for_inference_missing_base64_image(model_interface):

def test_format_input_grpc_with_ndim_3(model_interface):
"""
Test format_input for 'grpc' protocol with a 3-dimensional image array.
Expects the image array to be expanded and cast to float32.
Test format_input for the 'grpc' protocol when given a 3-dimensional image array.
The test verifies that the image is expanded along a new batch dimension and cast to float32.
It also confirms that the accompanying batch data reflects the original image and its dimensions.
"""
# Assume create_base64_image() returns a base64-encoded image that decodes to a (64, 64, 3) array.
base64_img = create_base64_image()
data = model_interface.prepare_data_for_inference({"base64_image": base64_img})

formatted_input = model_interface.format_input(data, "grpc", max_batch_size=1)[0]

assert isinstance(formatted_input, np.ndarray)
assert formatted_input.dtype == np.float32
assert formatted_input.shape == (1, 64, 64, 3) # Expanded along axis 0
# format_input returns a tuple: (batched_inputs, formatted_batch_data)
formatted_batches, batch_data = model_interface.format_input(data, "grpc", max_batch_size=1)

# Check that the batched input is a single numpy array with a new batch dimension.
assert isinstance(formatted_batches, list)
assert len(formatted_batches) == 1
batched_input = formatted_batches[0]
assert isinstance(batched_input, np.ndarray)
assert batched_input.dtype == np.float32
# The original image shape (64,64,3) should have been expanded to (1,64,64,3).
assert batched_input.shape == (1, 64, 64, 3)

# Verify that batch data contains the original image and its dimensions.
assert isinstance(batch_data, list)
assert len(batch_data) == 1
bd = batch_data[0]
assert "image_arrays" in bd and "image_dims" in bd
# The original image should be unmodified (still 3D) in batch_data.
assert len(bd["image_arrays"]) == 1
# Expect dimensions to be (H, W) i.e. (64, 64).
assert bd["image_dims"] == [(64, 64)]


def test_format_input_grpc_with_ndim_other(model_interface):
"""
Test format_input for 'grpc' protocol with a non-3-dimensional image array.
Expects the image array to be cast to float32 without expansion.
Test format_input for the 'grpc' protocol when given a non-3-dimensional image array.
This test uses a grayscale image which decodes to a 2D array.
The expected behavior is that the image is cast to float32 without being expanded.
Batch data is also checked for correct original dimensions.
"""
# Create a grayscale image (2D array)
# Create a grayscale (L mode) image of size 64x64.
with BytesIO() as buffer:
image = Image.new("L", (64, 64), 128) # 'L' mode for grayscale
image = Image.new("L", (64, 64), 128)
image.save(buffer, format="PNG")
base64_img = base64.b64encode(buffer.getvalue()).decode("utf-8")

data = model_interface.prepare_data_for_inference({"base64_image": base64_img})

formatted_input = model_interface.format_input(data, "grpc", max_batch_size=1)[0]

assert isinstance(formatted_input, np.ndarray)
assert formatted_input.dtype == np.float32
assert formatted_input.shape == (64, 64) # No expansion
formatted_batches, batch_data = model_interface.format_input(data, "grpc", max_batch_size=1)

# Check that the batched input is a numpy array without expansion.
assert isinstance(formatted_batches, list)
assert len(formatted_batches) == 1
batched_input = formatted_batches[0]
assert isinstance(batched_input, np.ndarray)
assert batched_input.dtype == np.float32
# For a 2D image (64,64), no extra batch dimension is added when max_batch_size=1.
assert batched_input.shape == (64, 64)

# Verify that batch data correctly reports the original image dimensions.
assert isinstance(batch_data, list)
assert len(batch_data) == 1
bd = batch_data[0]
assert "image_arrays" in bd and "image_dims" in bd
assert len(bd["image_arrays"]) == 1
# The image dimensions should reflect a 2D image: (64, 64)
assert bd["image_dims"] == [(64, 64)]


def test_format_input_http(model_interface):
"""
Test format_input for 'http' protocol under the new approach:
- The code expects 'image_arrays' in data
- Each array is re-encoded as PNG
- A single Nim message is built with multiple images in the 'content' array
Test format_input for the 'http' protocol.
This test ensures that given data with key "image_arrays", the images are re-encoded as PNG,
and a single payload is built with a proper Nim message containing the image content.
Additionally, it verifies that the accompanying batch data contains the original images and their dimensions.
"""
# 1) Create a small in-memory base64 image
# This is just a placeholder function to generate or load some base64 data
# Generate a base64-encoded image and decode it into a numpy array.
base64_img = create_base64_image()

# 2) Decode it into a NumPy array (mimicking prepare_data_for_inference)
arr = base64_to_numpy(base64_img) # or however your code does this

# 3) Build the data dict with "image_arrays"
data = {"image_arrays": [arr]} # single array for a single test image

# 4) Call format_input
formatted_input = model_interface.format_input(data, "http", max_batch_size=1)[0]

# 5) Verify the structure of the HTTP payload

assert "messages" in formatted_input, "Expected 'messages' key in output"
assert len(formatted_input["messages"]) == 1, "Expected exactly 1 message"

message = formatted_input["messages"][0]
assert "content" in message, "Expected 'content' key in the message"
assert len(message["content"]) == 1, "Expected exactly 1 image in content for this test"

content_item = message["content"][0]
assert content_item["type"] == "image_url", "Expected 'type' == 'image_url'"
assert "image_url" in content_item, "Expected 'image_url' key in content item"
assert "url" in content_item["image_url"], "Expected 'url' key in 'image_url' dict"

# 6) Optionally, check the prefix of the base64 URL
arr = base64_to_numpy(base64_img)

# Build the data dictionary directly with the "image_arrays" key.
data = {"image_arrays": [arr]}

payload_batches, batch_data = model_interface.format_input(data, "http", max_batch_size=1)

# Verify the HTTP payload structure.
assert isinstance(payload_batches, list)
assert len(payload_batches) == 1
payload = payload_batches[0]
assert "messages" in payload
messages = payload["messages"]
assert isinstance(messages, list)
assert len(messages) == 1
message = messages[0]
assert "content" in message
content_list = message["content"]
assert isinstance(content_list, list)
assert len(content_list) == 1
content_item = content_list[0]
assert content_item["type"] == "image_url"
assert "image_url" in content_item and "url" in content_item["image_url"]

# Check that the URL starts with the expected PNG base64 prefix.
url_value = content_item["image_url"]["url"]
assert url_value.startswith(
"data:image/png;base64,"
), f"URL should start with data:image/png;base64, but got {url_value[:30]}..."
# And check that there's something after the prefix
assert len(url_value) > len("data:image/png;base64,"), "Base64 string seems empty"
expected_prefix = "data:image/png;base64,"
assert url_value.startswith(expected_prefix)
assert len(url_value) > len(expected_prefix)

# Verify that the batch data is correctly built.
assert isinstance(batch_data, list)
assert len(batch_data) == 1
bd = batch_data[0]
assert "image_arrays" in bd and "image_dims" in bd
assert len(bd["image_arrays"]) == 1
# The expected dimensions should match the original array's height and width.
expected_dims = [(arr.shape[0], arr.shape[1])]
assert bd["image_dims"] == expected_dims


def test_format_input_invalid_protocol(model_interface):
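The rewritten test_cached.py tests above all exercise the same two-element return contract for format_input: a list of inference-ready batches paired with a list of batch-data dicts carrying the original arrays and their (height, width) dimensions. A minimal, self-contained illustration of that contract using stand-in values rather than a call into the real interface:

import numpy as np

# Stand-in for what the tests assert format_input returns: each batch is a
# float32 ndarray ready for inference, and the matching batch_data entry keeps
# the original arrays plus their (height, width) dims.
image = np.zeros((64, 64, 3), dtype=np.uint8)

batches = [np.expand_dims(image.astype(np.float32), axis=0)]        # shape (1, 64, 64, 3)
batch_data = [{"image_arrays": [image], "image_dims": [(64, 64)]}]  # originals kept alongside

for batch, bd in zip(batches, batch_data):
    # Callers can pair each inference batch with the original dims, e.g. to map
    # results back to the source image size.
    assert batch.dtype == np.float32 and batch.shape[0] == len(bd["image_arrays"])
    assert bd["image_dims"] == [a.shape[:2] for a in bd["image_arrays"]]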
73 changes: 69 additions & 4 deletions tests/nv_ingest/util/nim/test_deplot.py
@@ -61,44 +61,109 @@ def test_prepare_data_for_inference_invalid_base64_image(model_interface):


def test_format_input_grpc(model_interface):
"""
Test that for the gRPC protocol:
- The image (decoded from a base64 string) is normalized and batched.
- The returned formatted batch is a NumPy array of shape (B, H, W, C) with dtype float32.
- The accompanying batch data contains the original image and its dimensions.
"""
base64_img = create_base64_image()
prepared = model_interface.prepare_data_for_inference({"base64_image": base64_img})
formatted = model_interface.format_input(prepared, "grpc", max_batch_size=1)[0]
# format_input returns a tuple: (formatted_batches, formatted_batch_data)
batches, batch_data = model_interface.format_input(prepared, "grpc", max_batch_size=1)

formatted = batches[0]
# Check the formatted batch
assert isinstance(formatted, np.ndarray)
assert formatted.dtype == np.float32
# Since prepare_data_for_inference decodes to (256,256,3), the grpc branch expands it to (1,256,256,3)
assert formatted.ndim == 4
assert formatted.shape == (1, 256, 256, 3)
# Ensure normalization to [0, 1]
assert 0.0 <= formatted.min() and formatted.max() <= 1.0

# Verify accompanying batch data
assert isinstance(batch_data, list)
assert len(batch_data) == 1
bd = batch_data[0]
assert "image_arrays" in bd and "image_dims" in bd
assert isinstance(bd["image_arrays"], list)
assert len(bd["image_arrays"]) == 1
# The original image should have shape (256,256,3)
assert bd["image_arrays"][0].shape == (256, 256, 3)
# Dimensions should be recorded as (height, width)
assert bd["image_dims"] == [(256, 256)]


def test_format_input_http(model_interface):
"""
Test that for the HTTP protocol:
- The formatted payload is a JSON-serializable dict built via _prepare_deplot_payload.
- The payload includes the expected keys (model, messages, max_tokens, stream, temperature, top_p)
- And the accompanying batch data reflects the original image and its dimensions.
"""
base64_img = create_base64_image()
prepared = model_interface.prepare_data_for_inference({"base64_image": base64_img})
formatted = model_interface.format_input(
batches, batch_data = model_interface.format_input(
prepared, "http", max_batch_size=1, max_tokens=600, temperature=0.7, top_p=0.95
)[0]
)
formatted = batches[0]

# Check the payload structure from _prepare_deplot_payload
assert isinstance(formatted, dict)
assert formatted["model"] == "google/deplot"
assert "messages" in formatted
assert isinstance(formatted["messages"], list)
assert len(formatted["messages"]) == 1
message = formatted["messages"][0]
assert message["role"] == "user"
# The content should start with the fixed prompt text
assert message["content"].startswith("Generate the underlying data table")
# Check that the payload parameters match the supplied arguments
assert formatted["max_tokens"] == 600
assert formatted["temperature"] == 0.7
assert formatted["top_p"] == 0.95
assert formatted["stream"] is False

# Verify accompanying batch data
assert isinstance(batch_data, list)
assert len(batch_data) == 1
bd = batch_data[0]
assert "image_arrays" in bd and "image_dims" in bd
assert isinstance(bd["image_arrays"], list)
assert len(bd["image_arrays"]) == 1
assert bd["image_arrays"][0].shape == (256, 256, 3)
assert bd["image_dims"] == [(256, 256)]


def test_format_input_http_defaults(model_interface):
"""
Test the HTTP branch when default parameters are used.
- The default max_tokens, temperature, and top_p values should be applied.
- The stream flag should be False.
- Also verify that batch data is correctly returned.
"""
base64_img = create_base64_image()
prepared = model_interface.prepare_data_for_inference({"base64_image": base64_img})
formatted = model_interface.format_input(prepared, "http", max_batch_size=1)[0]
batches, batch_data = model_interface.format_input(prepared, "http", max_batch_size=1)
formatted = batches[0]

# Check that default values are set
assert formatted["max_tokens"] == 500
assert formatted["temperature"] == 0.5
assert formatted["top_p"] == 0.9
assert formatted["stream"] is False

# Verify accompanying batch data
assert isinstance(batch_data, list)
assert len(batch_data) == 1
bd = batch_data[0]
assert "image_arrays" in bd and "image_dims" in bd
assert isinstance(bd["image_arrays"], list)
assert len(bd["image_arrays"]) == 1
assert bd["image_arrays"][0].shape == (256, 256, 3)
assert bd["image_dims"] == [(256, 256)]


def test_format_input_invalid_protocol(model_interface):
base64_img = create_base64_image()
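For reference, a sketch of the Deplot HTTP payload shape these tests assert. The defaults (500, 0.5, 0.9) come from test_format_input_http_defaults; the helper name sketch_deplot_payload, the image-embedding markup, and any prompt wording beyond the asserted "Generate the underlying data table" prefix are assumptions, not the actual _prepare_deplot_payload implementation:

from typing import Any, Dict


def sketch_deplot_payload(
    base64_image: str,
    max_tokens: int = 500,
    temperature: float = 0.5,
    top_p: float = 0.9,
) -> Dict[str, Any]:
    # Illustrative payload matching only the fields the tests check.
    content = (
        "Generate the underlying data table of the figure below: "
        f'<img src="data:image/png;base64,{base64_image}" />'
    )
    return {
        "model": "google/deplot",
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    }


payload = sketch_deplot_payload("<BASE64_PNG>", max_tokens=600, temperature=0.7, top_p=0.95)
assert payload["messages"][0]["content"].startswith("Generate the underlying data table")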
