diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
index 0eabc6aa..ad4de2d6 100644
--- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -34,7 +34,6 @@
 from nv_ingest.util.image_processing.transforms import crop_image
 from nv_ingest.util.image_processing.transforms import numpy_to_base64
 from nv_ingest.util.nim.helpers import create_inference_client
-from nv_ingest.util.nim.helpers import get_version
 from nv_ingest.util.pdf.metadata_aggregators import Base64Image
 from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
 from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
@@ -64,22 +63,8 @@ def extract_tables_and_charts_using_image_ensemble(
 ) -> List[Tuple[int, object]]:  # List[Tuple[int, CroppedImageWithContent]]
     tables_and_charts = []

-    # Obtain yolox_version
-    # Assuming that the grpc endpoint is at index 0
-    yolox_http_endpoint = config.yolox_endpoints[1]
     try:
-        yolox_version = get_version(yolox_http_endpoint)
-        if not yolox_version:
-            logger.warning(
-                "Failed to obtain yolox-page-elements version from the endpoint. Falling back to the latest version."
-            )
-            yolox_version = None  # Default to the latest version
-    except Exception:
-        logger.waring("Failed to get yolox-page-elements version after 30 seconds. Falling back to the latest version.")
-        yolox_version = None  # Default to the latest version
-
-    try:
-        model_interface = yolox_utils.YoloxPageElementsModelInterface(yolox_version=yolox_version)
+        model_interface = yolox_utils.YoloxPageElementsModelInterface()
         yolox_client = create_inference_client(
             config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol
         )
diff --git a/src/nv_ingest/util/nim/cached.py b/src/nv_ingest/util/nim/cached.py
index 56513d08..1a7bf0c9 100644
--- a/src/nv_ingest/util/nim/cached.py
+++ b/src/nv_ingest/util/nim/cached.py
@@ -119,7 +119,7 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None) -> Any:
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

-    def process_inference_results(self, output: Any, **kwargs) -> Any:
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
         """
         Process inference results for the Cached model.
diff --git a/src/nv_ingest/util/nim/deplot.py b/src/nv_ingest/util/nim/deplot.py
index 9cf6175d..63f16a3b 100644
--- a/src/nv_ingest/util/nim/deplot.py
+++ b/src/nv_ingest/util/nim/deplot.py
@@ -133,7 +133,7 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None) -> Any:
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

-    def process_inference_results(self, output: Any, **kwargs) -> Any:
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
         """
         Process inference results for the Deplot model.
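Reviewer note: with the early-access version probe gone, building the page-elements client no longer requires an HTTP round trip before inference. A minimal sketch of the new construction path, assuming a stage `config` object with the same fields used in the extractor above (endpoint pair, auth token, protocol); `config` here is a stand-in, not code from this diff:

```python
from nv_ingest.util.nim import yolox as yolox_utils
from nv_ingest.util.nim.helpers import create_inference_client

# No get_version() round trip: the GA interface is version-agnostic.
model_interface = yolox_utils.YoloxPageElementsModelInterface()

# config.yolox_endpoints is assumed to be a (grpc, http) endpoint pair, as in
# extract_tables_and_charts_using_image_ensemble above.
yolox_client = create_inference_client(
    config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol
)
```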
diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py
index db7e0fdd..bfbb37a0 100644
--- a/src/nv_ingest/util/nim/helpers.py
+++ b/src/nv_ingest/util/nim/helpers.py
@@ -75,7 +75,7 @@ def prepare_data_for_inference(self, data: dict):
         """
         raise NotImplementedError("Subclasses should implement this method")

-    def process_inference_results(self, output_array, **kwargs):
+    def process_inference_results(self, output_array, protocol: str, **kwargs):
         """
         Process the inference results from the model.
@@ -206,7 +206,7 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any:
             response, protocol=self.protocol, data=prepared_data, **kwargs
         )
         results = self.model_interface.process_inference_results(
-            parsed_output, original_image_shapes=data.get("original_image_shapes"), **kwargs
+            parsed_output, original_image_shapes=data.get("original_image_shapes"), protocol=self.protocol, **kwargs
         )

         return results
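This signature change is the heart of the PR: `parse_output` already received the active protocol, and `process_inference_results` now does too, so an interface can decide how much client-side postprocessing is still required. A hypothetical subclass illustrating the new contract (the class and its `_postprocess` helper are illustrative only, assuming `ModelInterface` is the base class patched above):

```python
from typing import Any

from nv_ingest.util.nim.helpers import ModelInterface


class ExampleInterface(ModelInterface):
    """Hypothetical subclass showing the protocol-aware contract."""

    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
        if protocol == "http":
            # HTTP responses may already be postprocessed server-side.
            return output
        elif protocol == "grpc":
            # Raw tensors from gRPC still need client-side decoding.
            return self._postprocess(output, **kwargs)
        raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

    def _postprocess(self, output: Any, **kwargs) -> Any:
        ...  # model-specific decoding, out of scope for this sketch
```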
diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py
index d07f184e..831c4e62 100644
--- a/src/nv_ingest/util/nim/yolox.py
+++ b/src/nv_ingest/util/nim/yolox.py
@@ -16,7 +16,6 @@
 import numpy as np
 import torch
 import torchvision
-from packaging import version as pkgversion
 from PIL import Image

 from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
@@ -44,20 +43,6 @@ class YoloxPageElementsModelInterface(ModelInterface):
     An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols.
     """

-    def __init__(
-        self,
-        yolox_version: Optional[str] = None,
-    ):
-        """
-        Initialize the YOLOX model interface.
-
-        Parameters
-        ----------
-        yolox_version : str, optional
-            The version of the YOLOX model (default: None).
-        """
-        self.yolox_version = yolox_version
-
     def name(
         self,
     ) -> str:
@@ -70,7 +55,7 @@ def name(
             The name of the model interface.
         """

-        return f"yolox-page-elements (version {self.yolox_version})"
+        return "yolox-page-elements"

     def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
@@ -86,16 +71,16 @@ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
         dict
             The updated data dictionary with resized images and original image shapes.
         """
+        if (not isinstance(data, dict)) or ("images" not in data):
+            raise KeyError("Input data must be a dictionary containing an 'images' key with a list of images.")
+
+        if not all(isinstance(x, np.ndarray) for x in data["images"]):
+            raise ValueError("All elements in the 'images' list must be numpy.ndarray objects.")

         original_images = data["images"]
-        # Our yolox model expects images to be resized to 1024x1024
-        resized_images = [
-            resize_image(image, (YOLOX_IMAGE_PREPROC_WIDTH, YOLOX_IMAGE_PREPROC_HEIGHT)) for image in original_images
-        ]

         data["original_image_shapes"] = [image.shape for image in original_images]
-        data["resized_images"] = resized_images

-        return data  # Return data with added 'resized_images' key
+        return data

     def format_input(self, data: Dict[str, Any], protocol: str) -> Any:
         """
@@ -121,16 +106,18 @@ def format_input(self, data: Dict[str, Any], protocol: str) -> Any:

         if protocol == "grpc":
             logger.debug("Formatting input for gRPC Yolox model")
+            # Our yolox-page-elements model (gRPC) expects images to be resized to 1024x1024
+            resized_images = [
+                resize_image(image, (YOLOX_IMAGE_PREPROC_WIDTH, YOLOX_IMAGE_PREPROC_HEIGHT)) for image in data["images"]
+            ]
             # Reorder axes to match model input (batch, channels, height, width)
-            input_array = np.einsum("bijk->bkij", data["resized_images"]).astype(np.float32)
+            input_array = np.einsum("bijk->bkij", resized_images).astype(np.float32)
             return input_array
         elif protocol == "http":
             logger.debug("Formatting input for HTTP Yolox model")
-            # Additional lists to keep track of scaling factors and new sizes
-            scaling_factors = []
             content_list = []
-            for image in data["resized_images"]:
+            for image in data["images"]:
                 # Convert numpy array to PIL Image
                 image_pil = Image.fromarray((image * 255).astype(np.uint8))
                 original_size = image_pil.size  # Should be (1024, 1024)
@@ -148,26 +135,12 @@ def format_input(self, data: Dict[str, Any], protocol: str) -> Any:
                 if new_size != original_size:
                     logger.warning(f"Image was scaled from {original_size} to {new_size} to meet size constraints.")

-                # Compute scaling factor
-                scaling_factor_x = new_size[0] / YOLOX_IMAGE_PREPROC_WIDTH
-                scaling_factor_y = new_size[1] / YOLOX_IMAGE_PREPROC_HEIGHT
-                scaling_factors.append((scaling_factor_x, scaling_factor_y))
-
                 # Add to content_list
-                if self._is_version_early_access_legacy_api():
-                    content = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{scaled_image_b64}"}}
-                else:
-                    content = {"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"}
+                content = {"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"}

                 content_list.append(content)

-            # Store scaling factors in data
-            data["scaling_factors"] = scaling_factors
-
-            if self._is_version_early_access_legacy_api():
-                payload = {"messages": [{"content": content_list}]}
-            else:
-                payload = {"input": content_list}
+            payload = {"input": content_list}

             return payload
         else:
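The HTTP branch now emits the GA request shape directly. A sketch of the two payload formats for comparison; the base64 strings are truncated placeholders, not real image data:

```python
# GA API: a flat "input" list, one entry per image in the batch.
payload = {
    "input": [
        {"type": "image_url", "url": "data:image/png;base64,iVBORw0KGgo..."},
        {"type": "image_url", "url": "data:image/png;base64,iVBORw0KGgo..."},
    ]
}

# The removed early-access API wrapped the same entries in a chat-style
# "messages" envelope and nested the URL one level deeper:
# {"messages": [{"content": [{"type": "image_url", "image_url": {"url": ...}}]}]}
```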
@@ -203,108 +176,30 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None) -> Any:
         elif protocol == "http":
             logger.debug("Parsing output from HTTP Yolox model")
-            is_legacy_version = self._is_version_early_access_legacy_api()
-
-            # Convert JSON response to numpy array similar to gRPC response
-            if is_legacy_version:
-                # Convert response data to GA API format.
-                response_data = response.get("data", [])
-                batch_results = []
-                for idx, detections in enumerate(response_data):
-                    curr_batch = {"index": idx, "bounding_boxes": {}}
-                    for obj in detections:
-                        obj_type = obj.get("type", "")
-                        bboxes = obj.get("bboxes", [])
-                        if not obj_type:
-                            continue
-                        if obj_type not in curr_batch:
-                            curr_batch["bounding_boxes"][obj_type] = []
-                        curr_batch["bounding_boxes"][obj_type].extend(bboxes)
-                    batch_results.append(curr_batch)
-            else:
-                batch_results = response.get("data", [])
-
-            batch_size = len(batch_results)
             processed_outputs = []
-            scaling_factors = data.get("scaling_factors", [(1.0, 1.0)] * batch_size)
-
-            x_min_label = "xmin" if is_legacy_version else "x_min"
-            y_min_label = "ymin" if is_legacy_version else "y_min"
-            x_max_label = "xmax" if is_legacy_version else "x_max"
-            y_max_label = "ymax" if is_legacy_version else "y_max"
-            confidence_label = "confidence"
-
+            batch_results = response.get("data", [])
             for detections in batch_results:
-                idx = int(detections["index"])
-                scale_factor_x, scale_factor_y = scaling_factors[idx]
-                image_width = YOLOX_IMAGE_PREPROC_WIDTH
-                image_height = YOLOX_IMAGE_PREPROC_HEIGHT
+                new_bounding_boxes = {"table": [], "chart": [], "title": []}

-                # Initialize an empty tensor for detections
-                max_detections = 100
-                detection_tensor = np.zeros((max_detections, 85), dtype=np.float32)
-
-                index = 0
                 bounding_boxes = detections.get("bounding_boxes", [])
                 for obj_type, bboxes in bounding_boxes.items():
                     for bbox in bboxes:
-                        if index >= max_detections:
-                            break
-                        xmin_norm = bbox[x_min_label]
-                        ymin_norm = bbox[y_min_label]
-                        xmax_norm = bbox[x_max_label]
-                        ymax_norm = bbox[y_max_label]
-                        confidence = bbox[confidence_label]
-
-                        # Convert normalized coordinates to absolute pixel values in scaled image
-                        xmin_scaled = xmin_norm * image_width * scale_factor_x
-                        ymin_scaled = ymin_norm * image_height * scale_factor_y
-                        xmax_scaled = xmax_norm * image_width * scale_factor_x
-                        ymax_scaled = ymax_norm * image_height * scale_factor_y
-
-                        # Adjust coordinates back to 1024x1024 image space
-                        xmin = xmin_scaled / scale_factor_x
-                        ymin = ymin_scaled / scale_factor_y
-                        xmax = xmax_scaled / scale_factor_x
-                        ymax = ymax_scaled / scale_factor_y
-
-                        # YOLOX expects bbox format: center_x, center_y, width, height
-                        center_x = (xmin + xmax) / 2
-                        center_y = (ymin + ymax) / 2
-                        width = xmax - xmin
-                        height = ymax - ymin
-
-                        # Set the bbox coordinates
-                        detection_tensor[index, 0] = center_x
-                        detection_tensor[index, 1] = center_y
-                        detection_tensor[index, 2] = width
-                        detection_tensor[index, 3] = height
-
-                        # Objectness score
-                        detection_tensor[index, 4] = confidence
-
-                        class_index = {"table": 0, "chart": 1, "title": 2}.get(obj_type, -1)
-                        if class_index >= 0:
-                            detection_tensor[index, 5 + class_index] = 1.0
-
-                        index += 1
-
-                # Trim the detection tensor to the actual number of detections
-                detection_tensor = detection_tensor[:index, :]
-                processed_outputs.append(detection_tensor)
-
-            # Pad batch if necessary
-            max_detections_in_batch = max([output.shape[0] for output in processed_outputs]) if processed_outputs else 0
-            batch_output_array = np.zeros((batch_size, max_detections_in_batch, 85), dtype=np.float32)
-            for i, output in enumerate(processed_outputs):
-                batch_output_array[i, : output.shape[0], :] = output
-
-            return batch_output_array
+                        xmin = bbox["x_min"]
+                        ymin = bbox["y_min"]
+                        xmax = bbox["x_max"]
+                        ymax = bbox["y_max"]
+                        confidence = bbox["confidence"]
+
+                        new_bounding_boxes[obj_type].append([xmin, ymin, xmax, ymax, confidence])
+
+                processed_outputs.append(new_bounding_boxes)
+
+            return processed_outputs
         else:
             raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")

-    def process_inference_results(self, output_array: np.ndarray, **kwargs) -> List[Dict[str, Any]]:
+    def process_inference_results(self, output: Any, protocol: str, **kwargs) -> List[Dict[str, Any]]:
         """
         Process the results of the Yolox model inference and return the final annotations.
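With the legacy-to-tensor conversion removed, `parse_output` no longer reconstructs an 85-column YOLOX tensor for HTTP responses; it returns one dict per image keyed by label. A sketch of the new return value for a two-image batch (values illustrative, mirroring the updated tests below):

```python
# One dict per image; each label maps to rows of
# [x_min, y_min, x_max, y_max, confidence] with coordinates normalized to [0, 1].
parsed = [
    {
        "table": [[0.10, 0.10, 0.20, 0.20, 0.90]],
        "chart": [[0.30, 0.30, 0.40, 0.40, 0.80]],
        "title": [],
    },
    {
        "table": [],
        "chart": [[0.35, 0.35, 0.45, 0.45, 0.75]],
        "title": [[0.55, 0.55, 0.65, 0.65, 0.92]],
    },
]
```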
@@ -320,7 +215,6 @@ def process_inference_results(self, output_array: np.ndarray, **kwargs) -> List[Dict[str, Any]]:
         list[dict]
             A list of annotation dictionaries for each image in the batch.
         """
-
         original_image_shapes = kwargs.get("original_image_shapes", [])
         num_classes = kwargs.get("num_classes", YOLOX_NUM_CLASSES)
         conf_thresh = kwargs.get("conf_thresh", YOLOX_CONF_THRESHOLD)
@@ -328,14 +222,22 @@ def process_inference_results(self, output_array: np.ndarray, **kwargs) -> List[Dict[str, Any]]:
         min_score = kwargs.get("min_score", YOLOX_MIN_SCORE)
         final_thresh = kwargs.get("final_thresh", YOLOX_FINAL_SCORE)

-        pred = postprocess_model_prediction(output_array, num_classes, conf_thresh, iou_thresh, class_agnostic=True)
+        if protocol == "http":
+            # For http, the output already has postprocessing applied. Skip to table/chart expansion.
+            results = output

-        results = postprocess_results(pred, original_image_shapes, min_score=min_score)
+        elif protocol == "grpc":
+            # For grpc, apply the same NIM postprocessing.
+            pred = postprocess_model_prediction(output, num_classes, conf_thresh, iou_thresh, class_agnostic=True)
+            results = postprocess_results(pred, original_image_shapes, min_score=min_score)

-        annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in results]
+        # Table/chart expansion is "business logic" specific to nv-ingest
+        annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in results]
+        annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts]

         inference_results = []

         # Filter out bounding boxes below the final threshold
+        # This final thresholding is "business logic" specific to nv-ingest
         for annotation_dict in annotation_dicts:
             new_dict = {}
             if "table" in annotation_dict:
@@ -348,9 +250,6 @@ def process_inference_results(self, output_array: np.ndarray, **kwargs) -> List[Dict[str, Any]]:

         return inference_results

-    def _is_version_early_access_legacy_api(self):
-        return self.yolox_version and (pkgversion.parse(self.yolox_version) < pkgversion.parse("1.0.0-rc0"))
-

 def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
     # Convert numpy array to torch tensor
@@ -423,12 +322,14 @@ def postprocess_results(results, original_image_shapes, min_score=0.0):
     Keep only bboxes with high enough confidence.
     """
-    labels = ["table", "chart", "title"]
+    class_labels = ["table", "chart", "title"]
     out = []

     for original_image_shape, result in zip(original_image_shapes, results):
+        annotation_dict = {label: [] for label in class_labels}
+
         if result is None:
-            out.append({})
+            out.append(annotation_dict)
             continue

         try:
@@ -447,29 +348,17 @@ def postprocess_results(results, original_image_shapes, min_score=0.0):
             bboxes[:, [1, 3]] /= original_image_shape[0]
             bboxes = np.clip(bboxes, 0.0, 1.0)

-            label_idxs = result[:, 6]
+            labels = result[:, 6]
             scores = scores[scores > min_score]
         except Exception as e:
             raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")

-        annotation_dict = {label: [] for label in labels}
-
-        # bboxes are in format [x_min, y_min, x_max, y_max]
-        for j in range(len(bboxes)):
-            label = labels[int(label_idxs[j])]
-            bbox = bboxes[j]
-            score = scores[j]
-
-            # additional preprocessing for tables: extend the upper bounds to capture titles if any.
-            if label == "table":
-                height = bbox[3] - bbox[1]
-                bbox[1] = (bbox[1] - height * 0.2).clip(0.0, 1.0)
-
-            annotation_dict[label].append([round(float(x), 4) for x in np.concatenate((bbox, [score]))])
+        for box, score, label in zip(bboxes, scores, labels):
+            class_name = class_labels[int(label)]
+            annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])

         out.append(annotation_dict)

-    # {label: [[x1, y1, x2, y2, confidence], ...], ...}
     return out
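Because of the protocol branch above, the same public call now accepts either a raw gRPC tensor batch or the already-parsed HTTP dicts. A usage sketch under the new contract (shapes and values illustrative):

```python
import numpy as np

from nv_ingest.util.nim.yolox import YoloxPageElementsModelInterface

interface = YoloxPageElementsModelInterface()

# gRPC: raw (batch, detections, 85) tensors still go through NIM postprocessing.
grpc_output = np.random.rand(2, 100, 85).astype(np.float32)
grpc_results = interface.process_inference_results(
    grpc_output,
    "grpc",
    original_image_shapes=[(800, 600, 3), (640, 480, 3)],
)

# HTTP: parse_output already produced per-image label->bbox dicts, so only the
# expansion and final-threshold steps run.
http_output = [{"table": [[0.1, 0.1, 0.2, 0.2, 0.9]], "chart": [], "title": []}]
http_results = interface.process_inference_results(http_output, "http")
```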
@@ -493,6 +382,37 @@ def resize_image(image, target_img_size):
     return image


+def expand_table_bboxes(annotation_dict, labels=None):
+    """
+    Additional preprocessing for tables: extend the upper bounds to capture titles if any.
+
+    Args:
+        annotation_dict: output of postprocess_results, a dictionary with keys "table", "chart", "title"
+
+    Returns:
+        annotation_dict: same as input, with expanded bboxes for tables
+
+    """
+    if not labels:
+        labels = ["table", "chart", "title"]
+
+    if not annotation_dict or len(annotation_dict["table"]) == 0:
+        return annotation_dict
+
+    new_annotation_dict = {label: [] for label in labels}
+
+    for label, bboxes in annotation_dict.items():
+        for bbox_and_score in bboxes:
+            bbox, score = bbox_and_score[:4], bbox_and_score[4]
+
+            if label == "table":
+                height = bbox[3] - bbox[1]
+                bbox[1] = max(0.0, min(1.0, bbox[1] - height * 0.2))
+
+            new_annotation_dict[label].append([round(float(x), 4) for x in bbox + [score]])
+
+    return new_annotation_dict
+
+
 def expand_chart_bboxes(annotation_dict, labels=None):
     """
     Expand bounding boxes of charts and titles based on the bounding boxes of the other class.
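A quick worked example of the new `expand_table_bboxes` helper, assuming normalized coordinates: a table spanning y in [0.30, 0.50] has height 0.20, so its top edge is raised by 20% of that height (0.04), clamped to [0, 1]:

```python
from nv_ingest.util.nim.yolox import expand_table_bboxes

annotations = {
    "table": [[0.10, 0.30, 0.90, 0.50, 0.88]],  # [x_min, y_min, x_max, y_max, score]
    "chart": [],
    "title": [],
}
expanded = expand_table_bboxes(annotations)

# y_min drops by 0.2 * (0.50 - 0.30) = 0.04; other coordinates are unchanged.
assert expanded["table"][0] == [0.1, 0.26, 0.9, 0.5, 0.88]
```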
diff --git a/tests/nv_ingest/util/nim/test_cached.py b/tests/nv_ingest/util/nim/test_cached.py
index 8463f25d..c3871926 100644
--- a/tests/nv_ingest/util/nim/test_cached.py
+++ b/tests/nv_ingest/util/nim/test_cached.py
@@ -216,7 +216,7 @@ def test_process_inference_results(model_interface):
     """
     output = "Processed Output"

-    result = model_interface.process_inference_results(output)
+    result = model_interface.process_inference_results(output, "http")

     assert result == output
diff --git a/tests/nv_ingest/util/nim/test_yolox.py b/tests/nv_ingest/util/nim/test_yolox.py
index f42bf778..b3a84fd3 100644
--- a/tests/nv_ingest/util/nim/test_yolox.py
+++ b/tests/nv_ingest/util/nim/test_yolox.py
@@ -1,25 +1,17 @@
-import pytest
-import numpy as np
-from io import BytesIO
 import base64
+import random
+from io import BytesIO
+
+import numpy as np
+import pytest
 from PIL import Image

 from nv_ingest.util.nim.yolox import YoloxPageElementsModelInterface


-@pytest.fixture(params=["0.2.0", "1.0.0"])
-def model_interface(request):
-    return YoloxPageElementsModelInterface(yolox_version=request.param)
-
-
-@pytest.fixture
-def legacy_model_interface():
-    return YoloxPageElementsModelInterface(yolox_version="0.2.0")
-
-
 @pytest.fixture
-def ga_model_interface():
-    return YoloxPageElementsModelInterface(yolox_version="1.0.0")
+def model_interface():
+    return YoloxPageElementsModelInterface()


 def create_test_image(width=800, height=600, color=(255, 0, 0)):
@@ -68,25 +60,18 @@ def create_base64_image(width=1024, height=1024, color=(255, 0, 0)):
     return base64.b64encode(buffer.getvalue()).decode("utf-8")


-def test_name_returns_yolox_legacy(legacy_model_interface):
-    assert legacy_model_interface.name() == "yolox-page-elements (version 0.2.0)"
-
-
-def test_name_returns_yolox(ga_model_interface):
-    ga_model_interface = YoloxPageElementsModelInterface(yolox_version="1.0.0")
-    assert ga_model_interface.name() == "yolox-page-elements (version 1.0.0)"
+def test_name_returns_yolox(model_interface):
+    model_interface = YoloxPageElementsModelInterface()
+    assert model_interface.name() == "yolox-page-elements"


 def test_prepare_data_for_inference_valid(model_interface):
     images = [create_test_image(), create_test_image(width=640, height=480)]
     input_data = {"images": images}
     result = model_interface.prepare_data_for_inference(input_data)
-    assert "resized_images" in result
     assert "original_image_shapes" in result
-    assert len(result["resized_images"]) == len(images)
     assert len(result["original_image_shapes"]) == len(images)
-    for original_shape, resized_image, image in zip(result["original_image_shapes"], result["resized_images"], images):
-        assert resized_image.shape == (1024, 1024, 3)
+    for original_shape, image in zip(result["original_image_shapes"], images):
         assert original_shape[:2] == image.shape[:2]
@@ -118,28 +103,11 @@ def test_format_input_grpc(model_interface):
     assert formatted_input.shape[1:] == (3, 1024, 1024)


-def test_format_input_legacy(legacy_model_interface):
-    images = [create_test_image(), create_test_image()]
-    input_data = {"images": images}
-    prepared_data = legacy_model_interface.prepare_data_for_inference(input_data)
-    formatted_input = legacy_model_interface.format_input(prepared_data, "http")
-    assert "messages" in formatted_input
-    assert isinstance(formatted_input["messages"], list)
-    for message in formatted_input["messages"]:
-        assert "content" in message
-        for content in message["content"]:
-            assert "type" in content
-            assert content["type"] == "image_url"
-            assert "image_url" in content
-            assert "url" in content["image_url"]
-            assert content["image_url"]["url"].startswith("data:image/png;base64,")
-
-
-def test_format_input(ga_model_interface):
+def test_format_input_http(model_interface):
     images = [create_test_image(), create_test_image()]
     input_data = {"images": images}
-    prepared_data = ga_model_interface.prepare_data_for_inference(input_data)
-    formatted_input = ga_model_interface.format_input(prepared_data, "http")
+    prepared_data = model_interface.prepare_data_for_inference(input_data)
+    formatted_input = model_interface.format_input(prepared_data, "http")
     assert "input" in formatted_input
     assert isinstance(formatted_input["input"], list)
     for content in formatted_input["input"]:
@@ -165,45 +133,7 @@ def test_parse_output_grpc(model_interface):
     assert parsed_output.dtype == np.float32


-def test_parse_output_http_valid_legacy(legacy_model_interface):
-    response = {
-        "data": [
-            [
-                {
-                    "type": "table",
-                    "bboxes": [{"xmin": 0.1, "ymin": 0.1, "xmax": 0.2, "ymax": 0.2, "confidence": 0.9}],
-                },
-                {
-                    "type": "chart",
-                    "bboxes": [{"xmin": 0.3, "ymin": 0.3, "xmax": 0.4, "ymax": 0.4, "confidence": 0.8}],
-                },
-                {"type": "title", "bboxes": [{"xmin": 0.5, "ymin": 0.5, "xmax": 0.6, "ymax": 0.6, "confidence": 0.95}]},
-            ],
-            [
-                {
-                    "type": "table",
-                    "bboxes": [{"xmin": 0.15, "ymin": 0.15, "xmax": 0.25, "ymax": 0.25, "confidence": 0.85}],
-                },
-                {
-                    "type": "chart",
-                    "bboxes": [{"xmin": 0.35, "ymin": 0.35, "xmax": 0.45, "ymax": 0.45, "confidence": 0.75}],
-                },
-                {
-                    "type": "title",
-                    "bboxes": [{"xmin": 0.55, "ymin": 0.55, "xmax": 0.65, "ymax": 0.65, "confidence": 0.92}],
-                },
-            ],
-        ]
-    }
-    scaling_factors = [(1.0, 1.0), (1.0, 1.0)]
-    data = {"scaling_factors": scaling_factors}
-    parsed_output = legacy_model_interface.parse_output(response, "http", data)
-    assert isinstance(parsed_output, np.ndarray)
-    assert parsed_output.shape == (2, 3, 85)
-    assert parsed_output.dtype == np.float32
-
-
-def test_parse_output_http_valid(ga_model_interface):
+def test_parse_output_http_valid(model_interface):
     response = {
         "data": [
             {
@@ -224,12 +154,19 @@ def test_parse_output_http_valid(model_interface):
             },
         ]
     }
-    scaling_factors = [(1.0, 1.0), (1.0, 1.0)]
-    data = {"scaling_factors": scaling_factors}
-    parsed_output = ga_model_interface.parse_output(response, "http", data)
-    assert isinstance(parsed_output, np.ndarray)
-    assert parsed_output.shape == (2, 3, 85)
-    assert parsed_output.dtype == np.float32
+    parsed_output = model_interface.parse_output(response, "http")
+    assert parsed_output == [
+        {
+            "table": [[0.1, 0.1, 0.2, 0.2, 0.9]],
+            "chart": [[0.3, 0.3, 0.4, 0.4, 0.8]],
+            "title": [[0.5, 0.5, 0.6, 0.6, 0.95]],
+        },
+        {
+            "table": [[0.15, 0.15, 0.25, 0.25, 0.85]],
+            "chart": [[0.35, 0.35, 0.45, 0.45, 0.75]],
+            "title": [[0.55, 0.55, 0.65, 0.65, 0.92]],
+        },
+    ]


 def test_parse_output_invalid_protocol(model_interface):
@@ -238,11 +175,12 @@ def test_parse_output_invalid_protocol(model_interface):
     response = "Some response"
     with pytest.raises(ValueError):
         model_interface.parse_output(response, "invalid_protocol")


-def test_process_inference_results(model_interface):
+def test_process_inference_results_grpc(model_interface):
     output_array = np.random.rand(2, 100, 85).astype(np.float32)
     original_image_shapes = [(800, 600, 3), (640, 480, 3)]
     inference_results = model_interface.process_inference_results(
         output_array,
+        "grpc",
         original_image_shapes=original_image_shapes,
         num_classes=3,
         conf_thresh=0.5,
@@ -262,3 +200,35 @@ def test_process_inference_results(model_interface):
             assert isinstance(result["title"], list)
+
+
+def test_process_inference_results_http(model_interface):
+    output = [
+        {
+            "table": [[random.random() for _ in range(5)] for _ in range(10)],
+            "chart": [[random.random() for _ in range(5)] for _ in range(10)],
+            "title": [[random.random() for _ in range(5)] for _ in range(10)],
+        }
+        for _ in range(10)
+    ]
+    inference_results = model_interface.process_inference_results(
+        output,
+        "http",
+        num_classes=3,
+        conf_thresh=0.5,
+        iou_thresh=0.4,
+        min_score=0.3,
+        final_thresh=0.6,
+    )
+    assert isinstance(inference_results, list)
+    assert len(inference_results) == 10
+    for result in inference_results:
+        assert isinstance(result, dict)
+        if "table" in result:
+            for bbox in result["table"]:
+                assert bbox[4] >= 0.6
+        if "chart" in result:
+            for bbox in result["chart"]:
+                assert bbox[4] >= 0.6
+        if "title" in result:
+            assert isinstance(result["title"], list)