Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unittests for creators #34

Merged
merged 2 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
pre-commit>=3.2.1
pydoctor
pydoctor
pytest-dependency>=0.6.0
pytest-cov>=4.1.0
pytest-subtests>=0.12.1
pytest-md>=0.2.0
coverage-badge>=1.1.0
Empty file added tests/unittests/__init__.py
Empty file.
168 changes: 168 additions & 0 deletions tests/unittests/test_creators/test_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import pytest

from depthai_nodes.ml.messages import Classifications
from depthai_nodes.ml.messages.creators.classification import (
create_classification_message,
)


def test_none_classe():
with pytest.raises(ValueError):
create_classification_message(None, [0.5, 0.2, 0.3])


def test_none_scores():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], None)


def test_none_both():
with pytest.raises(ValueError):
create_classification_message(None, None)


def test_empty_classes():
with pytest.raises(ValueError):
create_classification_message([], [0.5, 0.2, 0.3])


def test_empty_scores():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [])


def test_non_list_classes():
with pytest.raises(ValueError):
create_classification_message("cat", [0.5, 0.2, 0.3])


def test_non_np_scores():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [int(1), int(2), int(3)])


def test_tuple_scores():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], (0.5, 0.2, 0.3))


def test_nd_scores():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [[0.5, 0.2, 0.3]])


def test_mixed_scores():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [0.5, 0.2, "30"])


def test_non_probability_scores():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [0.2, 0.3, 0.4])


def test_non_probability_scores_2():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [0.5, 0.5, 0.5])


def test_sum_above_upper_thr(): # upper thr is 1.01001001
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [0.5, 0.11001001, 0.4])


def test_sum_below_upper_thr():
create_classification_message(["cat", "dog", "bird"], [0.5, 0.11001000, 0.4])


def test_sum_above_lower_thr(): # lower thr is 0.98999001
create_classification_message(["cat", "dog", "bird"], [0.5, 0.18999001, 0.3])


def test_sum_below_bottom_thr():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [0.5, 0.18999, 0.3])


def test_mismatch_lengths():
with pytest.raises(ValueError):
create_classification_message(["cat", "dog", "bird"], [0.5, 0.2, 0.3, 0.4])


def test_correct_input():
classes = ["cat", "dog", "bird"]
scores = [0.2, 0.5, 0.3]

correct_classes = ["dog", "bird", "cat"]
correct_scores = [0.5, 0.3, 0.2]

message = create_classification_message(classes, scores)

assert isinstance(message, Classifications)
assert message.classes == correct_classes
assert message.scores == correct_scores
assert isinstance(message.classes, list)
assert isinstance(message.scores, list)


def test_single_class_and_score():
classes = ["cat"]
scores = [1.0]

message = create_classification_message(classes, scores)
assert message.classes == ["cat"]
assert message.scores == [1.0]


def test_correct_input_with_mixed_classes():
classes = ["cat", 1, None]
scores = [0.2, 0.5, 0.3]

correct_classes = [1, None, "cat"]
correct_scores = [0.5, 0.3, 0.2]

message = create_classification_message(classes, scores)

assert isinstance(message, Classifications)
assert message.classes == correct_classes
assert message.scores == correct_scores
assert isinstance(message.classes, list)
assert isinstance(message.scores, list)


def test_very_small_scores():
classes = ["cat", "dog", "bird"]
scores = [1e-10, 1e-10, 1 - 2e-10]

message = create_classification_message(classes, scores)

assert isinstance(message, Classifications)
assert message.classes == ["bird", "cat", "dog"]
assert message.scores == [1 - 2e-10, 1e-10, 1e-10]


def test_identical_scores():
classes = ["cat", "dog", "bird"]
scores = [1 / 3, 1 / 3, 1 / 3]

message = create_classification_message(classes, scores)

assert message.classes == classes
assert message.scores == scores


def test_duplicate_scores():
classes = ["cat", "dog", "bird"]
scores = [0.4, 0.2, 0.4]

correct_classes = ["cat", "bird", "dog"]
correct_scores = [0.4, 0.4, 0.2]

message = create_classification_message(classes, scores)

assert message.classes == correct_classes
assert message.scores == correct_scores


if __name__ == "__main__":
pytest.main()
81 changes: 81 additions & 0 deletions tests/unittests/test_creators/test_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import re

import depthai as dai
import numpy as np
import pytest

from depthai_nodes.ml.messages.creators.depth import create_depth_message

UINT16_MAX_VALUE = 65535
np.random.seed(0)


def test_not_numpy_array():
with pytest.raises(ValueError, match="Expected numpy array, got <class 'list'>."):
create_depth_message([1, 2, 3], "relative")


def test_wrong_literal_type():
with pytest.raises(ValueError):
create_depth_message(np.array([1, 2, 3]), "wrong")


def test_not_3d_input():
with pytest.raises(ValueError, match="Expected 3D input, got 1D input."):
create_depth_message(np.array([1, 2, 3]), "relative")


def test_wrong_input_shape():
with pytest.raises(
ValueError,
match=re.escape("Unexpected image shape. Expected CHW or HWC, got (3, 1, 3)."),
):
create_depth_message(
np.array([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]), "relative"
)


def test_metric_not_implemented():
with pytest.raises(
NotImplementedError,
match="The message for 'metric' depth type is not yet implemented.",
):
create_depth_message(np.array([[[1, 2, 3]]]), "metric")


def test_depth_map():
depth_map = np.random.rand(320, 640, 1)

message = create_depth_message(depth_map, "relative")
depth_map = depth_map[:, :, 0]

assert isinstance(message, dai.ImgFrame)
assert message.getType() == dai.ImgFrame.Type.RAW16
assert message.getWidth() == 640
assert message.getHeight() == 320

frame = message.getFrame()
assert frame.shape == depth_map.shape
scaled_depth_map = (
(depth_map - depth_map.min())
/ (depth_map.max() - depth_map.min())
* UINT16_MAX_VALUE
)
scaled_depth_map = scaled_depth_map.astype(np.uint16)
assert np.all(np.isclose(frame, scaled_depth_map))


def test_same_depth():
depth_map = np.ones((320, 640, 1))
message = create_depth_message(depth_map, "relative")

assert isinstance(message, dai.ImgFrame)
assert message.getType() == dai.ImgFrame.Type.RAW16
assert message.getWidth() == 640
assert message.getHeight() == 320
frame = message.getFrame()
assert np.all(np.isclose(frame, np.zeros((320, 640))))


if __name__ == "__main__":
pytest.main()
Loading