Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clip width and height in BBoxAnnotation #192

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions luxonis_ml/data/datasets/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,24 @@ def validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]:
logger.warning(
"BBox annotation has values outside of [0, 1] range. Clipping them to [0, 1]."
)

# cliping done in function instead of separate model validator so
# order of execution is explicitly defined
values = cls.clip_sum(values)
return values

@classmethod
def clip_sum(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values["x"] + values["w"] > 1:
values["w"] = 1 - values["x"]
logger.warning(
"BBox annotation has x + width > 1. Clipping width so the sum is 1."
)
if values["y"] + values["h"] > 1:
values["h"] = 1 - values["y"]
logger.warning(
"BBox annotation has y + height > 1. Clipping height so the sum is 1."
)
return values

def to_numpy(self, class_mapping: Dict[str, int]) -> np.ndarray:
Expand Down
82 changes: 82 additions & 0 deletions tests/test_data/test_ann_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pydantic
import pytest

from luxonis_ml.data.datasets.annotation import (
BBoxAnnotation,
KeypointAnnotation,
PolylineSegmentationAnnotation,
)


def test_bbox_no_auto_clip():
base_dict = {"x": 0, "y": 0, "w": 0, "h": 0}
for k in ["x", "y", "w", "h"]:
for v in [-2.1, 2.3, -3.3, 3]:
with pytest.raises(pydantic.ValidationError):
curr_dict = base_dict.copy()
curr_dict[k] = v
BBoxAnnotation(**curr_dict)


def test_bbox_auto_clip():
base_dict = {"x": 0, "y": 0, "w": 0, "h": 0}
for k in ["x", "y", "w", "h"]:
for v in [-1.1, 1.3, -1.3, 2]:
curr_dict = base_dict.copy()
curr_dict[k] = v
bbox_ann = BBoxAnnotation(**curr_dict)
assert 0 <= bbox_ann.x <= 1
assert 0 <= bbox_ann.y <= 1
assert 0 <= bbox_ann.w <= 1
assert 0 <= bbox_ann.h <= 1


def test_bbox_clip_sum():
bbox_ann = BBoxAnnotation(**{"x": 0.9, "y": 0, "w": 0.2, "h": 0})
assert bbox_ann.x + bbox_ann.w <= 1
bbox_ann = BBoxAnnotation(**{"x": 1.2, "y": 0, "w": 0.2, "h": 0})
assert bbox_ann.x + bbox_ann.w <= 1
bbox_ann = BBoxAnnotation(**{"x": 0, "y": 0.9, "w": 0, "h": 0.2})
assert bbox_ann.y + bbox_ann.h <= 1
bbox_ann = BBoxAnnotation(**{"x": 0, "y": 1.2, "w": 0, "h": 0.2})
assert bbox_ann.y + bbox_ann.h <= 1


def test_kpt_no_auto_clip():
with pytest.raises(pydantic.ValidationError):
KeypointAnnotation(**{"keypoints": [(-2.1, 1.1, 0)]})
with pytest.raises(pydantic.ValidationError):
KeypointAnnotation(**{"keypoints": [(0.1, 2.1, 1)]})
with pytest.raises(pydantic.ValidationError):
KeypointAnnotation(**{"keypoints": [(0.1, 1.1, 2), (0.1, 2.1, 1)]})


def test_kpt_auto_clip():
kpt_ann = KeypointAnnotation(**{"keypoints": [(-1.1, 1.1, 0)]})
assert (
0 <= kpt_ann.keypoints[0][0] <= 1 and 0 <= kpt_ann.keypoints[0][1] <= 1
)
kpt_ann = KeypointAnnotation(**{"keypoints": [(0.1, 1.1, 1)]})
assert (
0 <= kpt_ann.keypoints[0][0] <= 1 and 0 <= kpt_ann.keypoints[0][1] <= 1
)
kpt_ann = KeypointAnnotation(**{"keypoints": [(-2, 2, 2)]})
assert (
0 <= kpt_ann.keypoints[0][0] <= 1 and 0 <= kpt_ann.keypoints[0][1] <= 1
)


def test_poly_no_auto_clip():
with pytest.raises(pydantic.ValidationError):
PolylineSegmentationAnnotation(
**{"points": [(-2.1, 1.1), (-2.1, 2.1), (-0.1, -2.1)]}
)


def test_poly_auto_clip():
poly_ann = PolylineSegmentationAnnotation(
**{"points": [(-0.1, 1.1), (-2, 2), (-0.1, -1.1)]}
)
assert 0 <= poly_ann.points[0][0] <= 1 and 0 <= poly_ann.points[0][1] <= 1
assert 0 <= poly_ann.points[1][0] <= 1 and 0 <= poly_ann.points[1][1] <= 1
assert 0 <= poly_ann.points[2][0] <= 1 and 0 <= poly_ann.points[2][1] <= 1
Loading