|
7 | 7 | from nv_ingest.extraction_workflows.image.image_handlers import convert_svg_to_bitmap
|
8 | 8 | from nv_ingest.extraction_workflows.image.image_handlers import extract_table_and_chart_images
|
9 | 9 | from nv_ingest.extraction_workflows.image.image_handlers import load_and_preprocess_image
|
10 |
| -from nv_ingest.extraction_workflows.image.image_handlers import process_inference_results |
11 | 10 | from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
|
12 | 11 |
|
13 | 12 |
|
@@ -119,142 +118,6 @@ def test_convert_svg_to_bitmap_large_svg():
|
119 | 118 | assert np.all(result[:, :, 2] == 255) # Blue channel fully on
|
120 | 119 |
|
121 | 120 |
|
122 |
| -def test_process_inference_results_basic_case(): |
123 |
| - """Test process_inference_results with a typical valid input.""" |
124 |
| - |
125 |
| - # Simulated model output array for a single image with several detections. |
126 |
| - # Array format is (batch_size, num_detections, 85) - 80 classes + 5 box coordinates |
127 |
| - # For simplicity, use random values for the boxes and class predictions. |
128 |
| - output_array = np.zeros((1, 3, 85), dtype=np.float32) |
129 |
| - |
130 |
| - # Mock bounding box coordinates |
131 |
| - output_array[0, 0, :4] = [0.5, 0.5, 0.2, 0.2] # x_center, y_center, width, height |
132 |
| - output_array[0, 1, :4] = [0.6, 0.6, 0.2, 0.2] |
133 |
| - output_array[0, 2, :4] = [0.7, 0.7, 0.2, 0.2] |
134 |
| - |
135 |
| - # Mock object confidence scores |
136 |
| - output_array[0, :, 4] = [0.8, 0.9, 0.85] |
137 |
| - |
138 |
| - # Mock class scores (set class 1 with highest confidence for simplicity) |
139 |
| - output_array[0, 0, 5 + 1] = 0.7 |
140 |
| - output_array[0, 1, 5 + 1] = 0.75 |
141 |
| - output_array[0, 2, 5 + 1] = 0.72 |
142 |
| - |
143 |
| - original_image_shapes = [(640, 640)] # Original shape of the image before resizing |
144 |
| - |
145 |
| - # Process inference results with thresholds that should retain all mock detections |
146 |
| - results = process_inference_results( |
147 |
| - output_array, |
148 |
| - original_image_shapes, |
149 |
| - num_classes=80, |
150 |
| - conf_thresh=0.5, |
151 |
| - iou_thresh=0.5, |
152 |
| - min_score=0.1, |
153 |
| - final_thresh=0.3, |
154 |
| - ) |
155 |
| - |
156 |
| - # Check output structure |
157 |
| - assert isinstance(results, list) |
158 |
| - assert len(results) == 1 |
159 |
| - assert isinstance(results[0], dict) |
160 |
| - |
161 |
| - # Validate bounding box scaling and structure |
162 |
| - assert "chart" in results[0] or "table" in results[0] |
163 |
| - if "chart" in results[0]: |
164 |
| - assert isinstance(results[0]["chart"], list) |
165 |
| - assert len(results[0]["chart"]) > 0 |
166 |
| - # Check bounding box format for each detected "chart" item (5 values per box) |
167 |
| - for bbox in results[0]["chart"]: |
168 |
| - assert len(bbox) == 5 # [x1, y1, x2, y2, score] |
169 |
| - assert bbox[4] >= 0.3 # score meets final threshold |
170 |
| - |
171 |
| - print("Processed inference results:", results) |
172 |
| - |
173 |
| - |
174 |
| -def test_process_inference_results_multiple_images(): |
175 |
| - """Test with multiple images to verify batch processing.""" |
176 |
| - # Simulate model output with 2 images and 3 detections each |
177 |
| - output_array = np.zeros((2, 3, 85), dtype=np.float32) |
178 |
| - # Set bounding boxes and confidence for the mock detections |
179 |
| - output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8] |
180 |
| - output_array[0, 1, :5] = [0.6, 0.6, 0.2, 0.2, 0.7] |
181 |
| - output_array[1, 0, :5] = [0.4, 0.4, 0.1, 0.1, 0.9] |
182 |
| - # Assign class confidences for classes 0 and 1 |
183 |
| - output_array[0, 0, 5 + 1] = 0.75 |
184 |
| - output_array[0, 1, 5 + 1] = 0.65 |
185 |
| - output_array[1, 0, 5 + 0] = 0.8 |
186 |
| - |
187 |
| - original_image_shapes = [(640, 640), (800, 800)] |
188 |
| - |
189 |
| - results = process_inference_results( |
190 |
| - output_array, |
191 |
| - original_image_shapes, |
192 |
| - num_classes=80, |
193 |
| - conf_thresh=0.5, |
194 |
| - iou_thresh=0.5, |
195 |
| - min_score=0.1, |
196 |
| - final_thresh=0.3, |
197 |
| - ) |
198 |
| - |
199 |
| - assert isinstance(results, list) |
200 |
| - assert len(results) == 2 |
201 |
| - for result in results: |
202 |
| - assert isinstance(result, dict) |
203 |
| - if "chart" in result: |
204 |
| - assert all(len(bbox) == 5 and bbox[4] >= 0.3 for bbox in result["chart"]) |
205 |
| - |
206 |
| - |
207 |
| -def test_process_inference_results_high_confidence_threshold(): |
208 |
| - """Test with a high confidence threshold to verify filtering.""" |
209 |
| - output_array = np.zeros((1, 5, 85), dtype=np.float32) |
210 |
| - # Set low confidence scores below the threshold |
211 |
| - output_array[0, :, 4] = [0.2, 0.3, 0.4, 0.4, 0.2] |
212 |
| - output_array[0, :, 5] = [0.5] * 5 # Class confidence |
213 |
| - |
214 |
| - original_image_shapes = [(640, 640)] |
215 |
| - |
216 |
| - results = process_inference_results( |
217 |
| - output_array, |
218 |
| - original_image_shapes, |
219 |
| - num_classes=80, |
220 |
| - conf_thresh=0.9, # High confidence threshold |
221 |
| - iou_thresh=0.5, |
222 |
| - min_score=0.1, |
223 |
| - final_thresh=0.3, |
224 |
| - ) |
225 |
| - |
226 |
| - assert isinstance(results, list) |
227 |
| - assert len(results) == 1 |
228 |
| - assert results[0] == {} # No detections should pass the high confidence threshold |
229 |
| - |
230 |
| - |
231 |
| -def test_process_inference_results_varied_num_classes(): |
232 |
| - """Test compatibility with different model class counts.""" |
233 |
| - output_array = np.zeros((1, 3, 25), dtype=np.float32) # 20 classes + 5 box coords |
234 |
| - # Assign box, object confidence, and class scores |
235 |
| - output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8] |
236 |
| - output_array[0, 1, :5] = [0.6, 0.6, 0.3, 0.3, 0.7] |
237 |
| - output_array[0, 0, 5 + 1] = 0.9 # Assign highest confidence to class 1 |
238 |
| - |
239 |
| - original_image_shapes = [(640, 640)] |
240 |
| - |
241 |
| - results = process_inference_results( |
242 |
| - output_array, |
243 |
| - original_image_shapes, |
244 |
| - num_classes=20, # Different class count |
245 |
| - conf_thresh=0.5, |
246 |
| - iou_thresh=0.5, |
247 |
| - min_score=0.1, |
248 |
| - final_thresh=0.3, |
249 |
| - ) |
250 |
| - |
251 |
| - assert isinstance(results, list) |
252 |
| - assert len(results) == 1 |
253 |
| - assert isinstance(results[0], dict) |
254 |
| - assert "chart" in results[0] |
255 |
| - assert len(results[0]["chart"]) > 0 # Verify detections processed correctly with 20 classes |
256 |
| - |
257 |
| - |
258 | 121 | def crop_image(image: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray:
|
259 | 122 | """Mock function to simulate cropping an image."""
|
260 | 123 | h1, w1, h2, w2 = bbox
|
|
0 commit comments