From 3e4ca49abc8ef08192fc192105f1031ab98b360e Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:25:27 +0200 Subject: [PATCH 1/2] Types fix. --- depthai_nodes/ml/parsers/utils/yolo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/depthai_nodes/ml/parsers/utils/yolo.py b/depthai_nodes/ml/parsers/utils/yolo.py index 56011a1..a8ca060 100644 --- a/depthai_nodes/ml/parsers/utils/yolo.py +++ b/depthai_nodes/ml/parsers/utils/yolo.py @@ -42,7 +42,7 @@ def non_max_suppression( prediction: np.ndarray, conf_thres: float = 0.5, iou_thres: float = 0.45, - classes: list = None, + classes: List = None, num_classes: int = 1, agnostic: bool = False, multi_label: bool = False, @@ -51,7 +51,7 @@ def non_max_suppression( max_nms: int = 30000, max_wh: int = 7680, kpts_mode: bool = False, -) -> list[np.ndarray]: +) -> List[np.ndarray]: """Performs Non-Maximum Suppression (NMS) on inference results. @param prediction: Prediction from the model, shape = (batch_size, boxes, xy+wh+...) @@ -61,7 +61,7 @@ def non_max_suppression( @param iou_thres: Intersection over union threshold. @type iou_thres: float @param classes: For filtering by classes. - @type classes: list + @type classes: List @param num_classes: Number of classes. @type num_classes: int @param agnostic: Runs NMS on all boxes together rather than per class if True. @@ -79,7 +79,7 @@ def non_max_suppression( @param kpts_mode: Keypoints mode. @type kpts_mode: bool @return: An array of detections with either kpts or segmentation outputs. - @rtype: list[np.ndarray] + @rtype: List[np.ndarray] """ bs = prediction.shape[0] # batch size # Keypoints: 4 (bbox) + 1 (objectness) + 51 (kpts) = 56 From 1bb0d37767406d12881490e2003cf50f6bf1238a Mon Sep 17 00:00:00 2001 From: kkeroo <61207502+kkeroo@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:32:08 +0200 Subject: [PATCH 2/2] Fix tests. --- tests/unittests/test_creators/test_detections.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/unittests/test_creators/test_detections.py b/tests/unittests/test_creators/test_detections.py index 8747a0f..761b49f 100644 --- a/tests/unittests/test_creators/test_detections.py +++ b/tests/unittests/test_creators/test_detections.py @@ -5,8 +5,8 @@ import pytest from depthai_nodes.ml.messages import ( - ImgDetectionsWithKeypoints, - ImgDetectionWithKeypoints, + ImgDetectionExtended, + ImgDetectionsExtended, ) from depthai_nodes.ml.messages.creators.detection import create_detection_message @@ -172,10 +172,9 @@ def test_bboxes_scores_keypoints(): message = create_detection_message(bboxes, scores, None, keypoints) - assert isinstance(message, ImgDetectionsWithKeypoints) + assert isinstance(message, ImgDetectionsExtended) assert all( - isinstance(detection, ImgDetectionWithKeypoints) - for detection in message.detections + isinstance(detection, ImgDetectionExtended) for detection in message.detections ) for i, detection in enumerate(message.detections):