Skip to content

Commit

Permalink
Remove requirement that all images within a gRPC batch be of the same…
Browse files Browse the repository at this point in the history
… dimensions -- shift responsibility to NimClient (#435)
  • Loading branch information
drobison00 authored Feb 12, 2025
1 parent 84eb689 commit ec347aa
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 16 deletions.
5 changes: 0 additions & 5 deletions src/nv_ingest/util/nim/cached.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,6 @@ def chunk_list(lst: list, chunk_size: int) -> List[list]:
if not batched_images:
raise ValueError("No valid images found for gRPC formatting.")

# Ensure all images have the same shape (excluding batch dimension)
shapes = [img.shape[1:] for img in batched_images] # each is (H, W, C)
if any(s != shapes[0] for s in shapes[1:]):
raise ValueError(f"All images must have the same dimensions for gRPC batching. Found: {shapes}")

# Chunk the images into groups of size up to max_batch_size
batched_image_chunks = chunk_list(batched_images, max_batch_size)
batched_inputs = []
Expand Down
5 changes: 0 additions & 5 deletions src/nv_ingest/util/nim/deplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,6 @@ def chunk_list(lst: list, chunk_size: int) -> List[list]:
if not processed:
raise ValueError("No valid images found for gRPC formatting.")

# Ensure all images have the same dimensions (excluding the batch dimension)
shapes = [p.shape[1:] for p in processed]
if any(s != shapes[0] for s in shapes[1:]):
raise ValueError(f"All images must have the same dimensions for gRPC batching. Found: {shapes}")

# Split processed images into chunks of size at most max_batch_size
batched_inputs = []
for chunk in chunk_list(processed, max_batch_size):
Expand Down
2 changes: 1 addition & 1 deletion src/nv_ingest/util/nim/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
client = self.client if self.client else grpcclient.InferenceServerClient(url=self._grpc_endpoint)
model_config = client.get_model_config(model_name=model_name, model_version=model_version)
self._max_batch_sizes[model_name] = model_config.config.max_batch_size
logger.info(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
except Exception as e:
self._max_batch_sizes[model_name] = 1
logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")
Expand Down
5 changes: 0 additions & 5 deletions src/nv_ingest/util/nim/paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,6 @@ def chunk_list(lst, chunk_size):
arr = np.expand_dims(arr, axis=0) # => shape (1, H, W, C)
processed.append(arr)

# Check that all images have the same shape (excluding batch dimension)
shapes = [p.shape[1:] for p in processed] # List of (H, W, C) shapes
if not all(s == shapes[0] for s in shapes[1:]):
raise ValueError(f"All images must have the same dimensions for gRPC batching. Found: {shapes}")

batches = []
for chunk in chunk_list(processed, max_batch_size):
# Concatenate arrays in the chunk along the batch dimension => shape (B, H, W, C)
Expand Down

0 comments on commit ec347aa

Please sign in to comment.