Skip to content

Commit

Permalink
Types fix. (#38)
Browse files Browse the repository at this point in the history
* Types fix.

* Fix tests.
  • Loading branch information
kkeroo authored Aug 30, 2024
1 parent a19b209 commit 2bbe0bb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
8 changes: 4 additions & 4 deletions depthai_nodes/ml/parsers/utils/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def non_max_suppression(
prediction: np.ndarray,
conf_thres: float = 0.5,
iou_thres: float = 0.45,
classes: list = None,
classes: List = None,
num_classes: int = 1,
agnostic: bool = False,
multi_label: bool = False,
Expand All @@ -51,7 +51,7 @@ def non_max_suppression(
max_nms: int = 30000,
max_wh: int = 7680,
kpts_mode: bool = False,
) -> list[np.ndarray]:
) -> List[np.ndarray]:
"""Performs Non-Maximum Suppression (NMS) on inference results.
@param prediction: Prediction from the model, shape = (batch_size, boxes, xy+wh+...)
Expand All @@ -61,7 +61,7 @@ def non_max_suppression(
@param iou_thres: Intersection over union threshold.
@type iou_thres: float
@param classes: For filtering by classes.
@type classes: list
@type classes: List
@param num_classes: Number of classes.
@type num_classes: int
@param agnostic: Runs NMS on all boxes together rather than per class if True.
Expand All @@ -79,7 +79,7 @@ def non_max_suppression(
@param kpts_mode: Keypoints mode.
@type kpts_mode: bool
@return: An array of detections with either kpts or segmentation outputs.
@rtype: list[np.ndarray]
@rtype: List[np.ndarray]
"""
bs = prediction.shape[0] # batch size
# Keypoints: 4 (bbox) + 1 (objectness) + 51 (kpts) = 56
Expand Down
9 changes: 4 additions & 5 deletions tests/unittests/test_creators/test_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest

from depthai_nodes.ml.messages import (
ImgDetectionsWithKeypoints,
ImgDetectionWithKeypoints,
ImgDetectionExtended,
ImgDetectionsExtended,
)
from depthai_nodes.ml.messages.creators.detection import create_detection_message

Expand Down Expand Up @@ -172,10 +172,9 @@ def test_bboxes_scores_keypoints():

message = create_detection_message(bboxes, scores, None, keypoints)

assert isinstance(message, ImgDetectionsWithKeypoints)
assert isinstance(message, ImgDetectionsExtended)
assert all(
isinstance(detection, ImgDetectionWithKeypoints)
for detection in message.detections
isinstance(detection, ImgDetectionExtended) for detection in message.detections
)

for i, detection in enumerate(message.detections):
Expand Down

0 comments on commit 2bbe0bb

Please sign in to comment.