Skip to content

Commit e72145c

Browse files
reworld223pre-commit-ci[bot]KumoLiu
authored
Fix box_iou returning 0 for floating-point results less than 1. #8369 (#8553)
Fixes #8369. ### Description Fixes an issue where the result of `box_iou` was 0 when the first argument had an integer dtype, as described in #8369. IoU values range from 0 to 1, and the computed result was cast back to the dtype of the first argument; with an integer dtype, any value less than 1 was therefore truncated to 0. This pull request changes the return dtype of `box_iou` (and likewise `box_giou` and `box_pair_giou`) to `float32` when the first argument has an integer dtype, and otherwise keeps the floating-point dtype of the first argument. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. --------- Signed-off-by: reworld223 <sjiez369@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent 579cec5 commit e72145c

File tree

2 files changed

+75
-7
lines changed

2 files changed

+75
-7
lines changed

monai/data/box_utils.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,10 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
826826
boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
827827
828828
Returns:
829-
IoU, with size of (N,M) and same data type as ``boxes1``
829+
An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
830+
floating-point with size ``(N, M)``:
831+
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
832+
- if ``boxes1`` has an integer dtype, the result is returned as 32-bit float (``torch.float32``, or presumably its NumPy equivalent when ``boxes1`` is an ndarray).
830833
831834
"""
832835

@@ -842,16 +845,18 @@ def box_iou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTensor
842845

843846
inter, union = _box_inter_union(boxes1_t, boxes2_t, compute_dtype=COMPUTE_DTYPE)
844847

845-
# compute IoU and convert back to original box_dtype
848+
# compute IoU and convert back to original box_dtype or torch.float32
846849
iou_t = inter / (union + torch.finfo(COMPUTE_DTYPE).eps) # (N,M)
850+
if not box_dtype.is_floating_point:
851+
box_dtype = COMPUTE_DTYPE
847852
iou_t = iou_t.to(dtype=box_dtype)
848853

849854
# check if NaN or Inf
850855
if torch.isnan(iou_t).any() or torch.isinf(iou_t).any():
851856
raise ValueError("Box IoU is NaN or Inf.")
852857

853858
# convert tensor back to numpy if needed
854-
iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1)
859+
iou, *_ = convert_to_dst_type(src=iou_t, dst=boxes1, dtype=box_dtype)
855860
return iou
856861

857862

@@ -867,7 +872,10 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
867872
boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
868873
869874
Returns:
870-
GIoU, with size of (N,M) and same data type as ``boxes1``
875+
An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
876+
floating-point with size ``(N, M)``:
877+
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
878+
- if ``boxes1`` has an integer dtype, the result is returned as 32-bit float (``torch.float32``, or presumably its NumPy equivalent when ``boxes1`` is an ndarray).
871879
872880
Reference:
873881
https://giou.stanford.edu/GIoU.pdf
@@ -904,12 +912,15 @@ def box_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOrTenso
904912

905913
# GIoU
906914
giou_t = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps)
915+
if not box_dtype.is_floating_point:
916+
box_dtype = COMPUTE_DTYPE
907917
giou_t = giou_t.to(dtype=box_dtype)
918+
908919
if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
909920
raise ValueError("Box GIoU is NaN or Inf.")
910921

911922
# convert tensor back to numpy if needed
912-
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
923+
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype)
913924
return giou
914925

915926

@@ -925,7 +936,10 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
925936
boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``
926937
927938
Returns:
928-
paired GIoU, with size of (N,) and same data type as ``boxes1``
939+
An array/tensor matching the container type of ``boxes1`` (NumPy ndarray or Torch tensor), always
940+
floating-point with size ``(N, )``:
941+
- if ``boxes1`` has a floating-point dtype, the same dtype is used.
942+
- if ``boxes1`` has an integer dtype, the result is returned as 32-bit float (``torch.float32``, or presumably its NumPy equivalent when ``boxes1`` is an ndarray).
929943
930944
Reference:
931945
https://giou.stanford.edu/GIoU.pdf
@@ -982,12 +996,15 @@ def box_pair_giou(boxes1: NdarrayOrTensor, boxes2: NdarrayOrTensor) -> NdarrayOr
982996
enclosure = torch.prod(wh, dim=-1, keepdim=False) # (N,)
983997

984998
giou_t: torch.Tensor = iou - (enclosure - union) / (enclosure + torch.finfo(COMPUTE_DTYPE).eps) # type: ignore
999+
if not box_dtype.is_floating_point:
1000+
box_dtype = COMPUTE_DTYPE
9851001
giou_t = giou_t.to(dtype=box_dtype) # (N,spatial_dims)
1002+
9861003
if torch.isnan(giou_t).any() or torch.isinf(giou_t).any():
9871004
raise ValueError("Box GIoU is NaN or Inf.")
9881005

9891006
# convert tensor back to numpy if needed
990-
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1)
1007+
giou, *_ = convert_to_dst_type(src=giou_t, dst=boxes1, dtype=box_dtype)
9911008
return giou
9921009

9931010

tests/data/test_box_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import unittest
1515

1616
import numpy as np
17+
import torch
1718
from parameterized import parameterized
1819

1920
from monai.data.box_utils import (
@@ -218,5 +219,55 @@ def test_value(self, input_data, mode2, expected_box, expected_area):
218219
assert_allclose(nms_box, [1], type_test=False)
219220

220221

222+
class TestBoxUtilsDtype(unittest.TestCase):
223+
@parameterized.expand(
224+
[
225+
# numpy dtypes
226+
(np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32)),
227+
(np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32)),
228+
# torch dtypes
229+
(
230+
torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64),
231+
torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64),
232+
),
233+
(
234+
torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32),
235+
torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32),
236+
),
237+
# mixed numpy (int + float)
238+
(np.array([[1, 1, 1, 2, 2, 2]], dtype=np.int32), np.array([[1, 1, 1, 2, 2, 2]], dtype=np.float32)),
239+
# mixed torch (int + float)
240+
(
241+
torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.int64),
242+
torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.float32),
243+
),
244+
]
245+
)
246+
def test_dtype_behavior(self, boxes1, boxes2):
247+
funcs = [box_iou, box_giou, box_pair_giou]
248+
for func in funcs:
249+
result = func(boxes1, boxes2)
250+
251+
if isinstance(result, np.ndarray):
252+
self.assertTrue(
253+
np.issubdtype(result.dtype, np.floating), f"{func.__name__} expected float, got {result.dtype}"
254+
)
255+
elif torch.is_tensor(result):
256+
self.assertTrue(
257+
torch.is_floating_point(result), f"{func.__name__} expected float tensor, got {result.dtype}"
258+
)
259+
else:
260+
self.fail(f"Unexpected return type {type(result)}")
261+
262+
def test_integer_truncation_bug(self):
263+
# Verify fix for #8369 (PR #8553): IoU < 1.0 with integer inputs should not truncate to 0
264+
boxes1 = np.array([[0, 0, 0, 2, 2, 2]], dtype=np.int32)
265+
boxes2 = np.array([[1, 1, 1, 3, 3, 3]], dtype=np.int32)
266+
267+
iou = box_iou(boxes1, boxes2)
268+
self.assertTrue(np.issubdtype(iou.dtype, np.floating))
269+
self.assertGreater(iou[0, 0], 0.0, "IoU should not be truncated to 0")
270+
271+
221272
if __name__ == "__main__":
222273
unittest.main()

0 commit comments

Comments
 (0)