
Commit 53111bb

Remove unused yolox functions and fix nim client shut down. (#299)
1 parent 2cdbeb8 commit 53111bb

File tree

5 files changed (+3, -287 lines)


src/nv_ingest/extraction_workflows/image/image_handlers.py

Lines changed: 0 additions & 73 deletions
@@ -107,79 +107,6 @@ def convert_svg_to_bitmap(image_stream: io.BytesIO) -> np.ndarray:
     return image_array


-# TODO(Devin): Move to common file
-def process_inference_results(
-    output_array: np.ndarray,
-    original_image_shapes: List[Tuple[int, int]],
-    num_classes: int,
-    conf_thresh: float,
-    iou_thresh: float,
-    min_score: float,
-    final_thresh: float,
-):
-    """
-    Process the model output to generate detection results and expand bounding boxes.
-
-    Parameters
-    ----------
-    output_array : np.ndarray
-        The raw output from the model inference.
-    original_image_shapes : List[Tuple[int, int]]
-        The shapes of the original images before resizing, used for scaling bounding boxes.
-    num_classes : int
-        The number of classes the model can detect.
-    conf_thresh : float
-        The confidence threshold for detecting objects.
-    iou_thresh : float
-        The Intersection Over Union (IoU) threshold for non-maximum suppression.
-    min_score : float
-        The minimum score for keeping a detection.
-    final_thresh: float
-        Threshold for keeping a bounding box applied after postprocessing.
-
-    Returns
-    -------
-    List[dict]
-        A list of dictionaries, each containing processed detection results including expanded bounding boxes.
-
-    Notes
-    -----
-    This function applies non-maximum suppression to the model's output and scales the bounding boxes back to the
-    original image size.
-
-    Examples
-    --------
-    >>> output_array = np.random.rand(2, 100, 85)
-    >>> original_image_shapes = [(1536, 1536), (1536, 1536)]
-    >>> results = process_inference_results(output_array, original_image_shapes, 80, 0.5, 0.5, 0.1)
-    >>> len(results)
-    2
-    """
-    pred = yolox_utils.postprocess_model_prediction(
-        output_array, num_classes, conf_thresh, iou_thresh, class_agnostic=True
-    )
-    results = yolox_utils.postprocess_results(pred, original_image_shapes, min_score=min_score)
-    logger.debug(f"Number of results: {len(results)}")
-    logger.debug(f"Results: {results}")
-
-    annotation_dicts = [yolox_utils.expand_chart_bboxes(annotation_dict) for annotation_dict in results]
-    inference_results = []
-
-    # Filter out bounding boxes below the final threshold
-    for annotation_dict in annotation_dicts:
-        new_dict = {}
-        if "table" in annotation_dict:
-            new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= final_thresh]
-        if "chart" in annotation_dict:
-            new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= final_thresh]
-        if "title" in annotation_dict:
-            new_dict["title"] = annotation_dict["title"]
-        inference_results.append(new_dict)
-
-    return inference_results
-
-
 def extract_table_and_chart_images(
     annotation_dict: Dict[str, List[List[float]]],
     original_image: np.ndarray,

src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py

Lines changed: 0 additions & 70 deletions
@@ -142,76 +142,6 @@ def extract_tables_and_charts_using_image_ensemble(
     return tables_and_charts


-def process_inference_results(
-    output_array: np.ndarray,
-    original_image_shapes: List[Tuple[int, int]],
-    num_classes: int,
-    conf_thresh: float,
-    iou_thresh: float,
-    min_score: float,
-    final_thresh: float,
-):
-    """
-    Process the model output to generate detection results and expand bounding boxes.
-
-    Parameters
-    ----------
-    output_array : np.ndarray
-        The raw output from the model inference.
-    original_image_shapes : List[Tuple[int, int]]
-        The shapes of the original images before resizing, used for scaling bounding boxes.
-    num_classes : int
-        The number of classes the model can detect.
-    conf_thresh : float
-        The confidence threshold for detecting objects.
-    iou_thresh : float
-        The Intersection Over Union (IoU) threshold for non-maximum suppression.
-    min_score : float
-        The minimum score for keeping a detection.
-    final_thresh: float
-        Threshold for keeping a bounding box applied after postprocessing.
-
-    Returns
-    -------
-    List[dict]
-        A list of dictionaries, each containing processed detection results including expanded bounding boxes.
-
-    Notes
-    -----
-    This function applies non-maximum suppression to the model's output and scales the bounding boxes back to the
-    original image size.
-
-    Examples
-    --------
-    >>> output_array = np.random.rand(2, 100, 85)
-    >>> original_image_shapes = [(1536, 1536), (1536, 1536)]
-    >>> results = process_inference_results(output_array, original_image_shapes, 80, 0.5, 0.5, 0.1)
-    >>> len(results)
-    2
-    """
-    pred = yolox_utils.postprocess_model_prediction(
-        output_array, num_classes, conf_thresh, iou_thresh, class_agnostic=True
-    )
-    results = yolox_utils.postprocess_results(pred, original_image_shapes, min_score=min_score)
-
-    annotation_dicts = [yolox_utils.expand_chart_bboxes(annotation_dict) for annotation_dict in results]
-    inference_results = []
-
-    # Filter out bounding boxes below the final threshold
-    for annotation_dict in annotation_dicts:
-        new_dict = {}
-        if "table" in annotation_dict:
-            new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= final_thresh]
-        if "chart" in annotation_dict:
-            new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= final_thresh]
-        if "title" in annotation_dict:
-            new_dict["title"] = annotation_dict["title"]
-        inference_results.append(new_dict)
-
-    return inference_results
-
-
 # Handle individual table/chart extraction and model inference
 def extract_table_and_chart_images(
     annotation_dict,

src/nv_ingest/stages/nim/chart_extraction.py

Lines changed: 2 additions & 5 deletions
@@ -10,7 +10,6 @@
 from typing import Tuple

 import pandas as pd
-import tritonclient.grpc as grpcclient
 from morpheus.config import Config

 from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema
@@ -190,10 +189,8 @@ def _extract_chart_data(
         logger.error("Error occurred while extracting chart data.", exc_info=True)
         raise
     finally:
-        if isinstance(cached_client, grpcclient.InferenceServerClient):
-            cached_client.close()
-        if isinstance(deplot_client, grpcclient.InferenceServerClient):
-            deplot_client.close()
+        cached_client.close()
+        deplot_client.close()


 def generate_chart_extractor_stage(
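
The shutdown fix in this stage drops the tritonclient import and the isinstance guards, closing both clients directly in the finally block. The sketch below illustrates the resulting cleanup flow; it is a minimal approximation with hypothetical names (create_clients, df), not the actual signature of _extract_chart_data, and it assumes both client objects always exist and expose close():

import logging

logger = logging.getLogger(__name__)


def _extract_chart_data_sketch(df, create_clients):
    """Minimal sketch of the cleanup flow; not the real stage function."""
    cached_client, deplot_client = create_clients()  # assumed to always return two clients
    try:
        # ... run inference with each client and enrich the DataFrame ...
        return df
    except Exception:
        logger.error("Error occurred while extracting chart data.", exc_info=True)
        raise
    finally:
        # The clients are always created above, so close them unconditionally
        # instead of guarding on the concrete client type.
        cached_client.close()
        deplot_client.close()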

src/nv_ingest/stages/nim/table_extraction.py

Lines changed: 1 addition & 2 deletions
@@ -172,8 +172,7 @@ def _extract_table_data(
         logger.error("Error occurred while extracting table data.", exc_info=True)
         raise
     finally:
-        if isinstance(paddle_client, NimClient):
-            paddle_client.close()
+        paddle_client.close()


 def generate_table_extractor_stage(
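
The table extraction stage applies the same simplification: the isinstance(paddle_client, NimClient) guard is removed and paddle_client.close() is called directly in the finally block, so the client is released even when the extraction body raises.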

tests/nv_ingest/extraction_workflows/image/test_image_handlers.py

Lines changed: 0 additions & 137 deletions
@@ -7,7 +7,6 @@
 from nv_ingest.extraction_workflows.image.image_handlers import convert_svg_to_bitmap
 from nv_ingest.extraction_workflows.image.image_handlers import extract_table_and_chart_images
 from nv_ingest.extraction_workflows.image.image_handlers import load_and_preprocess_image
-from nv_ingest.extraction_workflows.image.image_handlers import process_inference_results
 from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent


@@ -119,142 +118,6 @@ def test_convert_svg_to_bitmap_large_svg():
     assert np.all(result[:, :, 2] == 255)  # Blue channel fully on


-def test_process_inference_results_basic_case():
-    """Test process_inference_results with a typical valid input."""
-
-    # Simulated model output array for a single image with several detections.
-    # Array format is (batch_size, num_detections, 85) - 80 classes + 5 box coordinates
-    # For simplicity, use random values for the boxes and class predictions.
-    output_array = np.zeros((1, 3, 85), dtype=np.float32)
-
-    # Mock bounding box coordinates
-    output_array[0, 0, :4] = [0.5, 0.5, 0.2, 0.2]  # x_center, y_center, width, height
-    output_array[0, 1, :4] = [0.6, 0.6, 0.2, 0.2]
-    output_array[0, 2, :4] = [0.7, 0.7, 0.2, 0.2]
-
-    # Mock object confidence scores
-    output_array[0, :, 4] = [0.8, 0.9, 0.85]
-
-    # Mock class scores (set class 1 with highest confidence for simplicity)
-    output_array[0, 0, 5 + 1] = 0.7
-    output_array[0, 1, 5 + 1] = 0.75
-    output_array[0, 2, 5 + 1] = 0.72
-
-    original_image_shapes = [(640, 640)]  # Original shape of the image before resizing
-
-    # Process inference results with thresholds that should retain all mock detections
-    results = process_inference_results(
-        output_array,
-        original_image_shapes,
-        num_classes=80,
-        conf_thresh=0.5,
-        iou_thresh=0.5,
-        min_score=0.1,
-        final_thresh=0.3,
-    )
-
-    # Check output structure
-    assert isinstance(results, list)
-    assert len(results) == 1
-    assert isinstance(results[0], dict)
-
-    # Validate bounding box scaling and structure
-    assert "chart" in results[0] or "table" in results[0]
-    if "chart" in results[0]:
-        assert isinstance(results[0]["chart"], list)
-        assert len(results[0]["chart"]) > 0
-        # Check bounding box format for each detected "chart" item (5 values per box)
-        for bbox in results[0]["chart"]:
-            assert len(bbox) == 5  # [x1, y1, x2, y2, score]
-            assert bbox[4] >= 0.3  # score meets final threshold
-
-    print("Processed inference results:", results)
-
-
-def test_process_inference_results_multiple_images():
-    """Test with multiple images to verify batch processing."""
-    # Simulate model output with 2 images and 3 detections each
-    output_array = np.zeros((2, 3, 85), dtype=np.float32)
-    # Set bounding boxes and confidence for the mock detections
-    output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8]
-    output_array[0, 1, :5] = [0.6, 0.6, 0.2, 0.2, 0.7]
-    output_array[1, 0, :5] = [0.4, 0.4, 0.1, 0.1, 0.9]
-    # Assign class confidences for classes 0 and 1
-    output_array[0, 0, 5 + 1] = 0.75
-    output_array[0, 1, 5 + 1] = 0.65
-    output_array[1, 0, 5 + 0] = 0.8
-
-    original_image_shapes = [(640, 640), (800, 800)]
-
-    results = process_inference_results(
-        output_array,
-        original_image_shapes,
-        num_classes=80,
-        conf_thresh=0.5,
-        iou_thresh=0.5,
-        min_score=0.1,
-        final_thresh=0.3,
-    )
-
-    assert isinstance(results, list)
-    assert len(results) == 2
-    for result in results:
-        assert isinstance(result, dict)
-        if "chart" in result:
-            assert all(len(bbox) == 5 and bbox[4] >= 0.3 for bbox in result["chart"])
-
-
-def test_process_inference_results_high_confidence_threshold():
-    """Test with a high confidence threshold to verify filtering."""
-    output_array = np.zeros((1, 5, 85), dtype=np.float32)
-    # Set low confidence scores below the threshold
-    output_array[0, :, 4] = [0.2, 0.3, 0.4, 0.4, 0.2]
-    output_array[0, :, 5] = [0.5] * 5  # Class confidence
-
-    original_image_shapes = [(640, 640)]
-
-    results = process_inference_results(
-        output_array,
-        original_image_shapes,
-        num_classes=80,
-        conf_thresh=0.9,  # High confidence threshold
-        iou_thresh=0.5,
-        min_score=0.1,
-        final_thresh=0.3,
-    )
-
-    assert isinstance(results, list)
-    assert len(results) == 1
-    assert results[0] == {}  # No detections should pass the high confidence threshold
-
-
-def test_process_inference_results_varied_num_classes():
-    """Test compatibility with different model class counts."""
-    output_array = np.zeros((1, 3, 25), dtype=np.float32)  # 20 classes + 5 box coords
-    # Assign box, object confidence, and class scores
-    output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8]
-    output_array[0, 1, :5] = [0.6, 0.6, 0.3, 0.3, 0.7]
-    output_array[0, 0, 5 + 1] = 0.9  # Assign highest confidence to class 1
-
-    original_image_shapes = [(640, 640)]
-
-    results = process_inference_results(
-        output_array,
-        original_image_shapes,
-        num_classes=20,  # Different class count
-        conf_thresh=0.5,
-        iou_thresh=0.5,
-        min_score=0.1,
-        final_thresh=0.3,
-    )
-
-    assert isinstance(results, list)
-    assert len(results) == 1
-    assert isinstance(results[0], dict)
-    assert "chart" in results[0]
-    assert len(results[0]["chart"]) > 0  # Verify detections processed correctly with 20 classes
-
-
 def crop_image(image: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray:
     """Mock function to simulate cropping an image."""
     h1, w1, h2, w2 = bbox