From e0b9d7cc59bb66991c949ca8840461e0b127d800 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Tue, 10 Sep 2024 15:14:19 +0200 Subject: [PATCH 01/16] refactor: separate visualizers and their utils --- .../visualization/visualizers/__init__.py | 19 +++++++++++++++++++ .../{ => visualizers}/classification.py | 2 +- .../{ => visualizers}/detection.py | 4 ++-- .../visualization/{ => visualizers}/image.py | 2 +- .../{ => visualizers}/keypoints.py | 2 +- .../{ => visualizers}/segmentation.py | 4 ++-- .../visualizers/utils/__init__.py | 0 .../{ => visualizers/utils}/colors.py | 0 .../utils/message_parsers.py} | 7 +++++++ 9 files changed, 33 insertions(+), 7 deletions(-) create mode 100644 examples/visualization/visualizers/__init__.py rename examples/visualization/{ => visualizers}/classification.py (94%) rename examples/visualization/{ => visualizers}/detection.py (98%) rename examples/visualization/{ => visualizers}/image.py (86%) rename examples/visualization/{ => visualizers}/keypoints.py (90%) rename examples/visualization/{ => visualizers}/segmentation.py (94%) create mode 100644 examples/visualization/visualizers/utils/__init__.py rename examples/visualization/{ => visualizers/utils}/colors.py (100%) rename examples/visualization/{messages.py => visualizers/utils/message_parsers.py} (93%) diff --git a/examples/visualization/visualizers/__init__.py b/examples/visualization/visualizers/__init__.py new file mode 100644 index 0000000..0ed8632 --- /dev/null +++ b/examples/visualization/visualizers/__init__.py @@ -0,0 +1,19 @@ +from .classification import visualize_age_gender, visualize_classification +from .detection import visualize_detections, visualize_line_detections, visualize_yolo_extended +from .image import visualize_image +from .keypoints import visualize_keypoints +from .segmentation import visualize_fastsam, visualize_segmentation +from .map import visualize_map + +__all__ = [ + "visualize_image", + "visualize_segmentation", + "visualize_keypoints", + "visualize_classification", + "visualize_map", + "visualize_age_gender", + "visualize_yolo_extended", + "visualize_detections", + "visualize_line_detections", + "visualize_fastsam", +] \ No newline at end of file diff --git a/examples/visualization/classification.py b/examples/visualization/visualizers/classification.py similarity index 94% rename from examples/visualization/classification.py rename to examples/visualization/visualizers/classification.py index 90c3bb0..5e8025a 100644 --- a/examples/visualization/classification.py +++ b/examples/visualization/visualizers/classification.py @@ -3,7 +3,7 @@ from depthai_nodes.ml.messages import AgeGender, Classifications -from .messages import parse_classification_message, parser_age_gender_message +from utils.message_parsers import parse_classification_message, parser_age_gender_message def visualize_classification( diff --git a/examples/visualization/detection.py b/examples/visualization/visualizers/detection.py similarity index 98% rename from examples/visualization/detection.py rename to examples/visualization/visualizers/detection.py index bfc1c6c..4acc7de 100644 --- a/examples/visualization/detection.py +++ b/examples/visualization/visualizers/detection.py @@ -4,8 +4,8 @@ from depthai_nodes.ml.messages import ImgDetectionsExtended, Lines -from .colors import get_yolo_colors -from .messages import ( +from utils.colors import get_yolo_colors +from utils.message_parsers import ( parse_detection_message, parse_line_detection_message, parse_yolo_kpts_message, diff --git 
a/examples/visualization/image.py b/examples/visualization/visualizers/image.py similarity index 86% rename from examples/visualization/image.py rename to examples/visualization/visualizers/image.py index 19255cb..24aa5b5 100644 --- a/examples/visualization/image.py +++ b/examples/visualization/visualizers/image.py @@ -1,7 +1,7 @@ import cv2 import depthai as dai -from .messages import parse_image_message +from utils.message_parsers import parse_image_message def visualize_image(frame: dai.ImgFrame, message: dai.ImgFrame, extraParams: dict): diff --git a/examples/visualization/keypoints.py b/examples/visualization/visualizers/keypoints.py similarity index 90% rename from examples/visualization/keypoints.py rename to examples/visualization/visualizers/keypoints.py index a587c2e..74f87df 100644 --- a/examples/visualization/keypoints.py +++ b/examples/visualization/visualizers/keypoints.py @@ -3,7 +3,7 @@ from depthai_nodes.ml.messages import Keypoints -from .messages import parse_keypoints_message +from utils.message_parsers import parse_keypoints_message def visualize_keypoints(frame: dai.ImgFrame, message: Keypoints, extraParams: dict): diff --git a/examples/visualization/segmentation.py b/examples/visualization/visualizers/segmentation.py similarity index 94% rename from examples/visualization/segmentation.py rename to examples/visualization/visualizers/segmentation.py index 4ebbd82..f6734a7 100644 --- a/examples/visualization/segmentation.py +++ b/examples/visualization/visualizers/segmentation.py @@ -4,8 +4,8 @@ from depthai_nodes.ml.messages import SegmentationMasks -from .colors import get_adas_colors, get_ewasr_colors, get_selfie_colors -from .messages import parse_fast_sam_message, parse_segmentation_message +from utils.colors import get_adas_colors, get_ewasr_colors, get_selfie_colors +from utils.message_parsers import parse_fast_sam_message, parse_segmentation_message def visualize_segmentation( diff --git a/examples/visualization/visualizers/utils/__init__.py b/examples/visualization/visualizers/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/visualization/colors.py b/examples/visualization/visualizers/utils/colors.py similarity index 100% rename from examples/visualization/colors.py rename to examples/visualization/visualizers/utils/colors.py diff --git a/examples/visualization/messages.py b/examples/visualization/visualizers/utils/message_parsers.py similarity index 93% rename from examples/visualization/messages.py rename to examples/visualization/visualizers/utils/message_parsers.py index e870e90..baf4ed5 100644 --- a/examples/visualization/messages.py +++ b/examples/visualization/visualizers/utils/message_parsers.py @@ -6,6 +6,7 @@ ImgDetectionsExtended, Keypoints, Lines, + Map2D, SegmentationMasks, ) @@ -69,3 +70,9 @@ def parse_fast_sam_message(message: SegmentationMasks): """Parses the fast sam message and returns the masks.""" masks = message.masks return masks + + +def parse_map_message(message: Map2D): + """Parses the map message and returns the map.""" + map = message.map + return map From 6abaaed9cfa3aa1a876565a7a23a68449e56aa80 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Tue, 10 Sep 2024 15:15:00 +0200 Subject: [PATCH 02/16] refactor: remove mapping.py --- examples/visualization/mapping.py | 29 ----------------------------- examples/visualization/visualize.py | 23 +++++++++++++++++++++-- 2 files changed, 21 insertions(+), 31 deletions(-) delete mode 100644 examples/visualization/mapping.py diff --git 
a/examples/visualization/mapping.py b/examples/visualization/mapping.py deleted file mode 100644 index 5ec9e84..0000000 --- a/examples/visualization/mapping.py +++ /dev/null @@ -1,29 +0,0 @@ -from .classification import visualize_age_gender, visualize_classification -from .detection import ( - visualize_detections, - visualize_line_detections, - visualize_yolo_extended, -) -from .image import visualize_image -from .keypoints import visualize_keypoints -from .segmentation import visualize_fastsam, visualize_segmentation - -parser_mapping = { - "YuNetParser": visualize_detections, - "SCRFDParser": visualize_detections, - "MPPalmDetectionParser": visualize_detections, - "YOLO": visualize_detections, - "SSD": visualize_detections, - "SegmentationParser": visualize_segmentation, - "MLSDParser": visualize_line_detections, - "KeypointParser": visualize_keypoints, - "HRNetParser": visualize_keypoints, - "SuperAnimalParser": visualize_keypoints, - "MPHandLandmarkParser": visualize_keypoints, - "ClassificationParser": visualize_classification, - "ImageOutputParser": visualize_image, - "MonocularDepthParser": visualize_image, - "AgeGenderParser": visualize_age_gender, - "YOLOExtendedParser": visualize_yolo_extended, - "FastSAMParser": visualize_fastsam, -} diff --git a/examples/visualization/visualize.py b/examples/visualization/visualize.py index c4ce279..adb7118 100644 --- a/examples/visualization/visualize.py +++ b/examples/visualization/visualize.py @@ -1,12 +1,31 @@ import depthai as dai -from .mapping import parser_mapping +from visualizers import * +visualizers = { + "YuNetParser": visualize_detections, + "SCRFDParser": visualize_detections, + "MPPalmDetectionParser": visualize_detections, + "YOLO": visualize_detections, + "SSD": visualize_detections, + "SegmentationParser": visualize_segmentation, + "MLSDParser": visualize_line_detections, + "KeypointParser": visualize_keypoints, + "HRNetParser": visualize_keypoints, + "SuperAnimalParser": visualize_keypoints, + "MPHandLandmarkParser": visualize_keypoints, + "ClassificationParser": visualize_classification, + "ImageOutputParser": visualize_image, + "MapOutputParser": visualize_map, + "AgeGenderParser": visualize_age_gender, + "YOLOExtendedParser": visualize_yolo_extended, + "FastSAMParser": visualize_fastsam, +} def visualize( frame: dai.ImgFrame, message: dai.Buffer, parser_name: str, extraParams: dict ): """Calls the appropriate visualizer based on the parser name and returns True if the pipeline should be stopped.""" - visualizer = parser_mapping[parser_name] + visualizer = visualizers[parser_name] return visualizer(frame, message, extraParams) From aa2889631862651746a4d99e5e7378585f92bccc Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Tue, 10 Sep 2024 15:20:42 +0200 Subject: [PATCH 03/16] feat: add support for models with 2D map output --- examples/utils/parser.py | 20 ++++++--------- examples/visualization/visualizers/map.py | 30 +++++++++++++++++++++++ 2 files changed, 38 insertions(+), 12 deletions(-) create mode 100644 examples/visualization/visualizers/map.py diff --git a/examples/utils/parser.py b/examples/utils/parser.py index 360116e..3907cdb 100644 --- a/examples/utils/parser.py +++ b/examples/utils/parser.py @@ -4,7 +4,7 @@ ClassificationParser, FastSAMParser, KeypointParser, - MonocularDepthParser, + MapOutputParser, MPPalmDetectionParser, SCRFDParser, SegmentationParser, @@ -63,19 +63,15 @@ def setup_classification_parser(parser: ClassificationParser, params: dict): ) -def setup_monocular_depth_parser(parser: 
MonocularDepthParser, params: dict): +def setup_map_output_parser(parser: MapOutputParser, params: dict): """Setup the monocular depth parser with the required metadata.""" try: - depth_type = params["depth_type"] - depth_limit = params["depth_limit"] - if depth_type == "relative": - parser.setRelativeDepthType() - else: - parser.setMetricDepthType() - parser.setDepthLimit(depth_limit) + min_max_scaling = params["min_max_scaling"] + if min_max_scaling: + parser.setMinMaxScaling(True) except Exception: print( - "This NN archive does not have required metadata for MonocularDepthParser. Skipping setup..." + "This NN archive does not have required metadata for MapOutputParser. Skipping setup..." ) @@ -145,8 +141,8 @@ def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_nam setup_keypoint_parser(parser, extraParams) elif parser_name == "ClassificationParser": setup_classification_parser(parser, extraParams) - elif parser_name == "MonocularDepthParser": - setup_monocular_depth_parser(parser, extraParams) + elif parser_name == "MapOutputParser": + setup_map_output_parser(parser, extraParams) elif parser_name == "XFeatParser": setup_xfeat_parser(parser, extraParams) elif parser_name == "YOLOExtendedParser": diff --git a/examples/visualization/visualizers/map.py b/examples/visualization/visualizers/map.py new file mode 100644 index 0000000..bfa310a --- /dev/null +++ b/examples/visualization/visualizers/map.py @@ -0,0 +1,30 @@ +import cv2 +import depthai as dai +import numpy as np + +from depthai_nodes.ml.messages import Map2D + +from utils.message_parsers import parse_map_message + + +def visualize_map(frame: dai.ImgFrame, message: Map2D, extraParams: dict): + """Visualizes the map on the frame.""" + + map = parse_map_message(message) + + # make color representation of the map + map_normalized = cv2.normalize(map, None, 0, 255, cv2.NORM_MINMAX) + map_normalized = map_normalized.astype(np.uint8) + colored_map = cv2.applyColorMap(map_normalized, cv2.COLORMAP_INFERNO) + frame_height, frame_width, _ = frame.shape + colored_map = cv2.resize(colored_map, (frame_width, frame_height), interpolation = cv2.INTER_LINEAR) + + alpha = 0.6 + overlay = cv2.addWeighted(colored_map, alpha, frame, 1 - alpha, 0) + + cv2.imshow("Map Overlay", overlay) + if cv2.waitKey(1) == ord("q"): + cv2.destroyAllWindows() + return True + + return False \ No newline at end of file From c1f11eeeff281d2611ff733080405b150fbd11c3 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Tue, 10 Sep 2024 15:24:55 +0200 Subject: [PATCH 04/16] docs: add installation step to the instructions --- examples/README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index 8c284c7..c839a96 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,8 +1,18 @@ # DepthAI Nodes examples -The `main.py` script lets you run fully-automated pipeline with the model of your choice. To run the script you need the model slug and then move to `examples` folder and run: +The `main.py` script lets you run fully-automated pipeline with the model of your choice. To run the script: + +Make sure you have the `depthai-nodes` package installed: + +``` +cd depthai-nodes +pip install -e .
+``` + +Prepare the model slug and run: ``` +cd examples python main.py -s ``` From 82e0ba32852d4169baf27bdff11c61b814e8cd2a Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Tue, 10 Sep 2024 15:26:46 +0200 Subject: [PATCH 05/16] fix: pre-commit --- examples/README.md | 2 +- examples/visualization/visualize.py | 15 +++++++++++++-- examples/visualization/visualizers/__init__.py | 10 +++++++--- .../visualization/visualizers/classification.py | 6 ++++-- examples/visualization/visualizers/detection.py | 5 ++--- examples/visualization/visualizers/image.py | 1 - examples/visualization/visualizers/keypoints.py | 3 +-- examples/visualization/visualizers/map.py | 13 +++++++------ .../visualization/visualizers/segmentation.py | 5 ++--- 9 files changed, 37 insertions(+), 23 deletions(-) diff --git a/examples/README.md b/examples/README.md index c839a96..49d6027 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,6 +1,6 @@ # DepthAI Nodes examples -The `main.py` script lets you run fully-automated pipeline with the model of your choice. To run the script: +The `main.py` script lets you run fully-automated pipeline with the model of your choice. To run the script: Make sure you have the `depthai-nodes` package installed: diff --git a/examples/visualization/visualize.py b/examples/visualization/visualize.py index adb7118..a3ddbf6 100644 --- a/examples/visualization/visualize.py +++ b/examples/visualization/visualize.py @@ -1,6 +1,16 @@ import depthai as dai - -from visualizers import * +from visualizers import ( + visualize_age_gender, + visualize_classification, + visualize_detections, + visualize_fastsam, + visualize_image, + visualize_keypoints, + visualize_line_detections, + visualize_map, + visualize_segmentation, + visualize_yolo_extended, +) visualizers = { "YuNetParser": visualize_detections, @@ -22,6 +32,7 @@ "FastSAMParser": visualize_fastsam, } + def visualize( frame: dai.ImgFrame, message: dai.Buffer, parser_name: str, extraParams: dict ): diff --git a/examples/visualization/visualizers/__init__.py b/examples/visualization/visualizers/__init__.py index 0ed8632..08b4e15 100644 --- a/examples/visualization/visualizers/__init__.py +++ b/examples/visualization/visualizers/__init__.py @@ -1,9 +1,13 @@ from .classification import visualize_age_gender, visualize_classification -from .detection import visualize_detections, visualize_line_detections, visualize_yolo_extended +from .detection import ( + visualize_detections, + visualize_line_detections, + visualize_yolo_extended, +) from .image import visualize_image from .keypoints import visualize_keypoints -from .segmentation import visualize_fastsam, visualize_segmentation from .map import visualize_map +from .segmentation import visualize_fastsam, visualize_segmentation __all__ = [ "visualize_image", @@ -16,4 +20,4 @@ "visualize_detections", "visualize_line_detections", "visualize_fastsam", -] \ No newline at end of file +] diff --git a/examples/visualization/visualizers/classification.py b/examples/visualization/visualizers/classification.py index 5e8025a..99d3fc3 100644 --- a/examples/visualization/visualizers/classification.py +++ b/examples/visualization/visualizers/classification.py @@ -1,10 +1,12 @@ import cv2 import depthai as dai +from utils.message_parsers import ( + parse_classification_message, + parser_age_gender_message, +) from depthai_nodes.ml.messages import AgeGender, Classifications -from utils.message_parsers import parse_classification_message, parser_age_gender_message - def visualize_classification( frame: dai.ImgFrame, 
message: Classifications, extraParams: dict diff --git a/examples/visualization/visualizers/detection.py b/examples/visualization/visualizers/detection.py index 4acc7de..ad1ecd2 100644 --- a/examples/visualization/visualizers/detection.py +++ b/examples/visualization/visualizers/detection.py @@ -1,9 +1,6 @@ import cv2 import depthai as dai import numpy as np - -from depthai_nodes.ml.messages import ImgDetectionsExtended, Lines - from utils.colors import get_yolo_colors from utils.message_parsers import ( parse_detection_message, @@ -11,6 +8,8 @@ parse_yolo_kpts_message, ) +from depthai_nodes.ml.messages import ImgDetectionsExtended, Lines + def visualize_detections( frame: dai.ImgFrame, message: dai.ImgDetections, extraParams: dict diff --git a/examples/visualization/visualizers/image.py b/examples/visualization/visualizers/image.py index 24aa5b5..e030d73 100644 --- a/examples/visualization/visualizers/image.py +++ b/examples/visualization/visualizers/image.py @@ -1,6 +1,5 @@ import cv2 import depthai as dai - from utils.message_parsers import parse_image_message diff --git a/examples/visualization/visualizers/keypoints.py b/examples/visualization/visualizers/keypoints.py index 74f87df..89f5452 100644 --- a/examples/visualization/visualizers/keypoints.py +++ b/examples/visualization/visualizers/keypoints.py @@ -1,10 +1,9 @@ import cv2 import depthai as dai +from utils.message_parsers import parse_keypoints_message from depthai_nodes.ml.messages import Keypoints -from utils.message_parsers import parse_keypoints_message - def visualize_keypoints(frame: dai.ImgFrame, message: Keypoints, extraParams: dict): """Visualizes the keypoints on the frame.""" diff --git a/examples/visualization/visualizers/map.py b/examples/visualization/visualizers/map.py index bfa310a..40f628e 100644 --- a/examples/visualization/visualizers/map.py +++ b/examples/visualization/visualizers/map.py @@ -1,15 +1,14 @@ import cv2 import depthai as dai import numpy as np +from utils.message_parsers import parse_map_message from depthai_nodes.ml.messages import Map2D -from utils.message_parsers import parse_map_message - def visualize_map(frame: dai.ImgFrame, message: Map2D, extraParams: dict): """Visualizes the map on the frame.""" - + map = parse_map_message(message) # make color representation of the map @@ -17,14 +16,16 @@ def visualize_map(frame: dai.ImgFrame, message: Map2D, extraParams: dict): map_normalized = map_normalized.astype(np.uint8) colored_map = cv2.applyColorMap(map_normalized, cv2.COLORMAP_INFERNO) frame_height, frame_width, _ = frame.shape - colored_map = cv2.resize(colored_map, (frame_width, frame_height), interpolation = cv2.INTER_LINEAR) + colored_map = cv2.resize( + colored_map, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR + ) alpha = 0.6 overlay = cv2.addWeighted(colored_map, alpha, frame, 1 - alpha, 0) - + cv2.imshow("Map Overlay", overlay) if cv2.waitKey(1) == ord("q"): cv2.destroyAllWindows() return True - return False \ No newline at end of file + return False diff --git a/examples/visualization/visualizers/segmentation.py b/examples/visualization/visualizers/segmentation.py index f6734a7..0998475 100644 --- a/examples/visualization/visualizers/segmentation.py +++ b/examples/visualization/visualizers/segmentation.py @@ -1,12 +1,11 @@ import cv2 import depthai as dai import numpy as np - -from depthai_nodes.ml.messages import SegmentationMasks - from utils.colors import get_adas_colors, get_ewasr_colors, get_selfie_colors from utils.message_parsers import parse_fast_sam_message, 
parse_segmentation_message +from depthai_nodes.ml.messages import SegmentationMasks + def visualize_segmentation( frame: dai.ImgFrame, message: dai.ImgFrame, extraParams: dict From af1686da5a5dc8f0f39302f64edc8a785b66a1e3 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 09:32:34 +0200 Subject: [PATCH 06/16] fix: visualizers' imports --- examples/visualization/visualize.py | 2 +- examples/visualization/visualizers/classification.py | 2 +- examples/visualization/visualizers/detection.py | 4 ++-- examples/visualization/visualizers/image.py | 2 +- examples/visualization/visualizers/keypoints.py | 2 +- examples/visualization/visualizers/map.py | 2 +- examples/visualization/visualizers/segmentation.py | 4 ++-- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/visualization/visualize.py b/examples/visualization/visualize.py index a3ddbf6..51e38fd 100644 --- a/examples/visualization/visualize.py +++ b/examples/visualization/visualize.py @@ -1,5 +1,5 @@ import depthai as dai -from visualizers import ( +from .visualizers import ( visualize_age_gender, visualize_classification, visualize_detections, diff --git a/examples/visualization/visualizers/classification.py b/examples/visualization/visualizers/classification.py index 99d3fc3..da41601 100644 --- a/examples/visualization/visualizers/classification.py +++ b/examples/visualization/visualizers/classification.py @@ -1,6 +1,6 @@ import cv2 import depthai as dai -from utils.message_parsers import ( +from .utils.message_parsers import ( parse_classification_message, parser_age_gender_message, ) diff --git a/examples/visualization/visualizers/detection.py b/examples/visualization/visualizers/detection.py index ad1ecd2..fbfa5e4 100644 --- a/examples/visualization/visualizers/detection.py +++ b/examples/visualization/visualizers/detection.py @@ -1,8 +1,8 @@ import cv2 import depthai as dai import numpy as np -from utils.colors import get_yolo_colors -from utils.message_parsers import ( +from .utils.colors import get_yolo_colors +from .utils.message_parsers import ( parse_detection_message, parse_line_detection_message, parse_yolo_kpts_message, diff --git a/examples/visualization/visualizers/image.py b/examples/visualization/visualizers/image.py index e030d73..c7ab9c8 100644 --- a/examples/visualization/visualizers/image.py +++ b/examples/visualization/visualizers/image.py @@ -1,6 +1,6 @@ import cv2 import depthai as dai -from utils.message_parsers import parse_image_message +from .utils.message_parsers import parse_image_message def visualize_image(frame: dai.ImgFrame, message: dai.ImgFrame, extraParams: dict): diff --git a/examples/visualization/visualizers/keypoints.py b/examples/visualization/visualizers/keypoints.py index 89f5452..9ebc14c 100644 --- a/examples/visualization/visualizers/keypoints.py +++ b/examples/visualization/visualizers/keypoints.py @@ -1,6 +1,6 @@ import cv2 import depthai as dai -from utils.message_parsers import parse_keypoints_message +from .utils.message_parsers import parse_keypoints_message from depthai_nodes.ml.messages import Keypoints diff --git a/examples/visualization/visualizers/map.py b/examples/visualization/visualizers/map.py index 40f628e..aca29e6 100644 --- a/examples/visualization/visualizers/map.py +++ b/examples/visualization/visualizers/map.py @@ -1,7 +1,7 @@ import cv2 import depthai as dai import numpy as np -from utils.message_parsers import parse_map_message +from .utils.message_parsers import parse_map_message from depthai_nodes.ml.messages import Map2D diff --git 
a/examples/visualization/visualizers/segmentation.py b/examples/visualization/visualizers/segmentation.py index 0998475..90ba317 100644 --- a/examples/visualization/visualizers/segmentation.py +++ b/examples/visualization/visualizers/segmentation.py @@ -1,8 +1,8 @@ import cv2 import depthai as dai import numpy as np -from utils.colors import get_adas_colors, get_ewasr_colors, get_selfie_colors -from utils.message_parsers import parse_fast_sam_message, parse_segmentation_message +from .utils.colors import get_adas_colors, get_ewasr_colors, get_selfie_colors +from .utils.message_parsers import parse_fast_sam_message, parse_segmentation_message from depthai_nodes.ml.messages import SegmentationMasks From a9cf2688c4ebca1de410e3c4b3687961450251ed Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 09:33:25 +0200 Subject: [PATCH 07/16] fix: add FPS limit to avoid OAK-D Lite errors --- examples/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/main.py b/examples/main.py index 9954321..fb2f249 100644 --- a/examples/main.py +++ b/examples/main.py @@ -4,6 +4,8 @@ from utils.parser import setup_parser from visualization.visualize import visualize +FPS_LIMIT = 28 # adding a limit to the FPS to avoid errors on OAK-D Lite. TODO: remove once fixed + # Initialize the argument parser arg_parser, args = initialize_argparser() @@ -27,7 +29,7 @@ # YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser if parser_name == "YOLO" or parser_name == "SSD": network = pipeline.create(dai.node.DetectionNetwork).build( - cam.requestOutput(input_shape, type=dai.ImgFrame.Type.BGR888p), nn_archive + cam.requestOutput(input_shape, type=dai.ImgFrame.Type.BGR888p, fps = FPS_LIMIT), nn_archive ) parser_queue = network.out.createOutputQueue() else: @@ -45,13 +47,13 @@ manip = pipeline.create(dai.node.ImageManip) manip.initialConfig.setResize(input_shape) large_input_shape = (input_shape[0] * 4, input_shape[1] * 4) - cam.requestOutput(large_input_shape, type=image_type).link(manip.inputImage) + cam.requestOutput(large_input_shape, type=image_type, fps = FPS_LIMIT).link(manip.inputImage) network = pipeline.create(dai.node.NeuralNetwork).build( manip.out, nn_archive ) else: network = pipeline.create(dai.node.NeuralNetwork).build( - cam.requestOutput(input_shape, type=image_type), nn_archive + cam.requestOutput(input_shape, type=image_type, fps = FPS_LIMIT), nn_archive ) parser = pipeline.create(parser_class) From be64249dcb1c8c87e5ae5cf11ca9457491720bc3 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 09:35:34 +0200 Subject: [PATCH 08/16] fix: pre-commit --- examples/main.py | 14 ++++++++++---- examples/visualization/visualize.py | 1 + .../visualization/visualizers/classification.py | 5 +++-- examples/visualization/visualizers/detection.py | 5 +++-- examples/visualization/visualizers/image.py | 1 + examples/visualization/visualizers/keypoints.py | 3 ++- examples/visualization/visualizers/map.py | 3 ++- examples/visualization/visualizers/segmentation.py | 5 +++-- 8 files changed, 25 insertions(+), 12 deletions(-) diff --git a/examples/main.py b/examples/main.py index fb2f249..097ff3f 100644 --- a/examples/main.py +++ b/examples/main.py @@ -4,7 +4,7 @@ from utils.parser import setup_parser from visualization.visualize import visualize -FPS_LIMIT = 28 # adding a limit to the FPS to avoid errors on OAK-D Lite. TODO: remove once fixed +FPS_LIMIT = 28 # adding a limit to the FPS to avoid errors on OAK-D Lite. 
TODO: remove once fixed # Initialize the argument parser arg_parser, args = initialize_argparser() @@ -29,7 +29,10 @@ # YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser if parser_name == "YOLO" or parser_name == "SSD": network = pipeline.create(dai.node.DetectionNetwork).build( - cam.requestOutput(input_shape, type=dai.ImgFrame.Type.BGR888p, fps = FPS_LIMIT), nn_archive + cam.requestOutput( + input_shape, type=dai.ImgFrame.Type.BGR888p, fps=FPS_LIMIT + ), + nn_archive, ) parser_queue = network.out.createOutputQueue() else: @@ -47,13 +50,16 @@ manip = pipeline.create(dai.node.ImageManip) manip.initialConfig.setResize(input_shape) large_input_shape = (input_shape[0] * 4, input_shape[1] * 4) - cam.requestOutput(large_input_shape, type=image_type, fps = FPS_LIMIT).link(manip.inputImage) + cam.requestOutput(large_input_shape, type=image_type, fps=FPS_LIMIT).link( + manip.inputImage + ) network = pipeline.create(dai.node.NeuralNetwork).build( manip.out, nn_archive ) else: network = pipeline.create(dai.node.NeuralNetwork).build( - cam.requestOutput(input_shape, type=image_type, fps = FPS_LIMIT), nn_archive + cam.requestOutput(input_shape, type=image_type, fps=FPS_LIMIT), + nn_archive, ) parser = pipeline.create(parser_class) diff --git a/examples/visualization/visualize.py b/examples/visualization/visualize.py index 51e38fd..76effb6 100644 --- a/examples/visualization/visualize.py +++ b/examples/visualization/visualize.py @@ -1,4 +1,5 @@ import depthai as dai + from .visualizers import ( visualize_age_gender, visualize_classification, diff --git a/examples/visualization/visualizers/classification.py b/examples/visualization/visualizers/classification.py index da41601..5a19f2b 100644 --- a/examples/visualization/visualizers/classification.py +++ b/examples/visualization/visualizers/classification.py @@ -1,12 +1,13 @@ import cv2 import depthai as dai + +from depthai_nodes.ml.messages import AgeGender, Classifications + from .utils.message_parsers import ( parse_classification_message, parser_age_gender_message, ) -from depthai_nodes.ml.messages import AgeGender, Classifications - def visualize_classification( frame: dai.ImgFrame, message: Classifications, extraParams: dict diff --git a/examples/visualization/visualizers/detection.py b/examples/visualization/visualizers/detection.py index fbfa5e4..b425ee6 100644 --- a/examples/visualization/visualizers/detection.py +++ b/examples/visualization/visualizers/detection.py @@ -1,6 +1,9 @@ import cv2 import depthai as dai import numpy as np + +from depthai_nodes.ml.messages import ImgDetectionsExtended, Lines + from .utils.colors import get_yolo_colors from .utils.message_parsers import ( parse_detection_message, @@ -8,8 +11,6 @@ parse_yolo_kpts_message, ) -from depthai_nodes.ml.messages import ImgDetectionsExtended, Lines - def visualize_detections( frame: dai.ImgFrame, message: dai.ImgDetections, extraParams: dict diff --git a/examples/visualization/visualizers/image.py b/examples/visualization/visualizers/image.py index c7ab9c8..3beb7fa 100644 --- a/examples/visualization/visualizers/image.py +++ b/examples/visualization/visualizers/image.py @@ -1,5 +1,6 @@ import cv2 import depthai as dai + from .utils.message_parsers import parse_image_message diff --git a/examples/visualization/visualizers/keypoints.py b/examples/visualization/visualizers/keypoints.py index 9ebc14c..ea65c8e 100644 --- a/examples/visualization/visualizers/keypoints.py +++ b/examples/visualization/visualizers/keypoints.py @@ -1,9 +1,10 @@ import 
cv2 import depthai as dai -from .utils.message_parsers import parse_keypoints_message from depthai_nodes.ml.messages import Keypoints +from .utils.message_parsers import parse_keypoints_message + def visualize_keypoints(frame: dai.ImgFrame, message: Keypoints, extraParams: dict): """Visualizes the keypoints on the frame.""" diff --git a/examples/visualization/visualizers/map.py b/examples/visualization/visualizers/map.py index aca29e6..134c15a 100644 --- a/examples/visualization/visualizers/map.py +++ b/examples/visualization/visualizers/map.py @@ -1,10 +1,11 @@ import cv2 import depthai as dai import numpy as np -from .utils.message_parsers import parse_map_message from depthai_nodes.ml.messages import Map2D +from .utils.message_parsers import parse_map_message + def visualize_map(frame: dai.ImgFrame, message: Map2D, extraParams: dict): """Visualizes the map on the frame.""" diff --git a/examples/visualization/visualizers/segmentation.py b/examples/visualization/visualizers/segmentation.py index 90ba317..e61514a 100644 --- a/examples/visualization/visualizers/segmentation.py +++ b/examples/visualization/visualizers/segmentation.py @@ -1,11 +1,12 @@ import cv2 import depthai as dai import numpy as np -from .utils.colors import get_adas_colors, get_ewasr_colors, get_selfie_colors -from .utils.message_parsers import parse_fast_sam_message, parse_segmentation_message from depthai_nodes.ml.messages import SegmentationMasks +from .utils.colors import get_adas_colors, get_ewasr_colors, get_selfie_colors +from .utils.message_parsers import parse_fast_sam_message, parse_segmentation_message + def visualize_segmentation( frame: dai.ImgFrame, message: dai.ImgFrame, extraParams: dict From 2b7b51b6fb07ec54903cddc84d5996e5271386b0 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 10:54:59 +0200 Subject: [PATCH 09/16] feat: add fps_limit argument --- examples/main.py | 11 +++++------ examples/utils/arguments.py | 21 ++++++++++++++++++++- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/examples/main.py b/examples/main.py index 097ff3f..1b921b4 100644 --- a/examples/main.py +++ b/examples/main.py @@ -1,16 +1,15 @@ import depthai as dai -from utils.arguments import initialize_argparser, parse_model_slug +from utils.arguments import initialize_argparser, parse_fps_limit, parse_model_slug from utils.model import get_input_shape, get_model_from_hub, get_parser from utils.parser import setup_parser from visualization.visualize import visualize -FPS_LIMIT = 28 # adding a limit to the FPS to avoid errors on OAK-D Lite. 
TODO: remove once fixed - # Initialize the argument parser arg_parser, args = initialize_argparser() # Parse the model slug model_slug, model_version_slug = parse_model_slug(args) +fps_limit = parse_fps_limit(args) # Get the model from the HubAI nn_archive = get_model_from_hub(model_slug, model_version_slug) @@ -30,7 +29,7 @@ if parser_name == "YOLO" or parser_name == "SSD": network = pipeline.create(dai.node.DetectionNetwork).build( cam.requestOutput( - input_shape, type=dai.ImgFrame.Type.BGR888p, fps=FPS_LIMIT + input_shape, type=dai.ImgFrame.Type.BGR888p, fps=fps_limit ), nn_archive, ) @@ -50,7 +49,7 @@ manip = pipeline.create(dai.node.ImageManip) manip.initialConfig.setResize(input_shape) large_input_shape = (input_shape[0] * 4, input_shape[1] * 4) - cam.requestOutput(large_input_shape, type=image_type, fps=FPS_LIMIT).link( + cam.requestOutput(large_input_shape, type=image_type, fps=fps_limit).link( manip.inputImage ) network = pipeline.create(dai.node.NeuralNetwork).build( @@ -58,7 +57,7 @@ ) else: network = pipeline.create(dai.node.NeuralNetwork).build( - cam.requestOutput(input_shape, type=image_type, fps=FPS_LIMIT), + cam.requestOutput(input_shape, type=image_type, fps=fps_limit), nn_archive, ) diff --git a/examples/utils/arguments.py b/examples/utils/arguments.py index 3551bf5..3b570e8 100644 --- a/examples/utils/arguments.py +++ b/examples/utils/arguments.py @@ -8,7 +8,7 @@ def initialize_argparser(): parser.description = "General example script to run any model available in HubAI on DepthAI device. \ All you need is a model slug of the model and the script will download the model from HubAI and create \ the whole pipeline with visualizations. You also need a DepthAI device connected to your computer. \ - Currently, only RVC2 devices are supported." + Currently, only RVC2 devices are supported. If using OAK-D Lite, please set the FPS limit to 28." parser.add_argument( "-s", @@ -18,6 +18,15 @@ def initialize_argparser(): type=str, ) + parser.add_argument( + "-l", + "--fps_limit", + help="FPS limit for the model runtime.", + required=False, + default=30.0, + type=float, + ) + args = parser.parse_args() return parser, args @@ -41,3 +50,13 @@ def parse_model_slug(args: argparse.Namespace) -> Tuple[str, str]: model_version_slug = model_slug_parts[1] return model_slug, model_version_slug + + +def parse_fps_limit(args: argparse.Namespace) -> int: + """Parse the FPS limit from the arguments. + + Returns the FPS limit. 
+ """ + fps_limit = args.fps_limit + + return fps_limit From 9f387ec55b76ae9aae276f4e4a3650380b3368a0 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 10:56:32 +0200 Subject: [PATCH 10/16] fix: docstring mentioning the removed MonocularDepth parser --- examples/utils/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/utils/parser.py b/examples/utils/parser.py index 3907cdb..d12fd90 100644 --- a/examples/utils/parser.py +++ b/examples/utils/parser.py @@ -64,7 +64,7 @@ def setup_classification_parser(parser: ClassificationParser, params: dict): def setup_map_output_parser(parser: MapOutputParser, params: dict): - """Setup the monocular depth parser with the required metadata.""" + """Setup the map output parser with the required metadata.""" try: min_max_scaling = params["min_max_scaling"] if min_max_scaling: From 32372f6d781af341596d22bc68cce0b96d72cbc5 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 16:48:00 +0200 Subject: [PATCH 11/16] fix: add pytest install to GitHub Actions tests --- .github/workflows/ci.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e0ed2e8..8b21d3b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -93,6 +93,9 @@ jobs: - name: Install package run: pip install -e .[dev] + - name: Install pytest (if not included in dev dependencies) + run: pip install pytest + - name: Run pytest uses: pavelzw/pytest-action@v2 with: From 3a5d090f32ecf9261e4db992c555acd1ff9ba87e Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 18:01:17 +0200 Subject: [PATCH 12/16] fix: change pytest install to 8.3.2 --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8b21d3b..a03ab0c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -94,7 +94,7 @@ jobs: run: pip install -e .[dev] - name: Install pytest (if not included in dev dependencies) - run: pip install pytest + run: pip install pytest==8.3.2 - name: Run pytest uses: pavelzw/pytest-action@v2 From c5d5af9a976372b590c6fd66a719ab3c43dca188 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Mon, 16 Sep 2024 18:27:11 +0200 Subject: [PATCH 13/16] fix: add pytest-cov install to GitHub Actions tests --- .github/workflows/ci.yaml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a03ab0c..0dd4f82 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -93,8 +93,16 @@ jobs: - name: Install package run: pip install -e .[dev] - - name: Install pytest (if not included in dev dependencies) - run: pip install pytest==8.3.2 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest pytest-cov + + - name: Check Python version + run: python --version + + - name: Check installed packages + run: pip list - name: Run pytest uses: pavelzw/pytest-action@v2 From a513bb2f979e2fae513c45640b52592feb6b3fc1 Mon Sep 17 00:00:00 2001 From: aljazkonec1 Date: Mon, 16 Sep 2024 21:09:40 +0200 Subject: [PATCH 14/16] Refactor code. 
--- .../ml/messages/creators/__init__.py | 9 +- .../ml/messages/creators/classification.py | 163 +++++++++++++++++- .../creators/classification_sequence.py | 116 ------------- depthai_nodes/ml/messages/creators/misc.py | 65 +------ depthai_nodes/ml/parsers/__init__.py | 3 +- depthai_nodes/ml/parsers/classification.py | 61 ++++++- depthai_nodes/ml/parsers/ppdet.py | 2 +- .../ml/parsers/vehicle_attributes.py | 61 ------- .../test_classification_sequence.py | 2 +- tests/unittests/test_creators/test_misc.py | 15 +- .../test_creators/test_vehicle_attributes.py | 2 +- 11 files changed, 247 insertions(+), 252 deletions(-) delete mode 100644 depthai_nodes/ml/messages/creators/classification_sequence.py delete mode 100644 depthai_nodes/ml/parsers/vehicle_attributes.py diff --git a/depthai_nodes/ml/messages/creators/__init__.py b/depthai_nodes/ml/messages/creators/__init__.py index 8656ffa..17cdaf0 100644 --- a/depthai_nodes/ml/messages/creators/__init__.py +++ b/depthai_nodes/ml/messages/creators/__init__.py @@ -1,11 +1,14 @@ -from .classification import create_classification_message -from .classification_sequence import create_classification_sequence_message +from .classification import ( + create_classification_message, + create_classification_sequence_message, + create_multi_classification_message, +) from .clusters import create_cluster_message from .detection import create_detection_message, create_line_detection_message from .image import create_image_message from .keypoints import create_hand_keypoints_message, create_keypoints_message from .map import create_map_message -from .misc import create_age_gender_message, create_multi_classification_message +from .misc import create_age_gender_message from .segmentation import create_sam_message, create_segmentation_message from .tracked_features import create_tracked_features_message diff --git a/depthai_nodes/ml/messages/creators/classification.py b/depthai_nodes/ml/messages/creators/classification.py index 42ed263..49b7dd6 100644 --- a/depthai_nodes/ml/messages/creators/classification.py +++ b/depthai_nodes/ml/messages/creators/classification.py @@ -2,7 +2,7 @@ import numpy as np -from ...messages import Classifications +from ...messages import Classifications, CompositeMessage def create_classification_message( @@ -82,3 +82,164 @@ def create_classification_message( classification_msg.scores = scores.tolist() return classification_msg + + +def create_multi_classification_message( + classification_attributes: List[str], + classification_scores: Union[np.ndarray, List[List[float]]], + classification_labels: List[List[str]], +) -> CompositeMessage: + """Create a DepthAI message for multi-classification. + + @param classification_attributes: List of attributes being classified. + @type classification_attributes: List[str] + @param classification_scores: A 2D array or list of classification scores for each + attribute. + @type classification_scores: Union[np.ndarray, List[List[float]]] + @param classification_labels: A 2D list of class labels for each classification + attribute. + @type classification_labels: List[List[str]] + @return: MultiClassification message containing a dictionary of classification + attributes and their respective Classifications. + @rtype: dai.Buffer + @raise ValueError: If number of attributes is not same as number of score-label + pairs. + @raise ValueError: If number of scores is not same as number of labels for each + attribute. + @raise ValueError: If each class score not in the range [0, 1]. 
+ @raise ValueError: If each class score not a probability distribution that sums to + 1. + """ + + if len(classification_attributes) != len(classification_scores) or len( + classification_attributes + ) != len(classification_labels): + raise ValueError( + f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels." + ) + + multi_class_dict = {} + for attribute, scores, labels in zip( + classification_attributes, classification_scores, classification_labels + ): + if len(scores) != len(labels): + raise ValueError( + f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}." + ) + multi_class_dict[attribute] = create_classification_message(labels, scores) + + multi_classification_message = CompositeMessage() + multi_classification_message.setData(multi_class_dict) + + return multi_classification_message + + +def create_classification_sequence_message( + classes: List[str], + scores: Union[np.ndarray, List], + ignored_indexes: List[int] = None, + remove_duplicates: bool = False, + concatenate_text: bool = False, +) -> Classifications: + """Creates a message for a multi-class sequence. The 'scores' array is a sequence of + probabilities for each class at each position in the sequence. The message contains + the class names and their respective scores, ordered according to the sequence. + + @param classes: A list of class names, with length 'n_classes'. + @type classes: List + @param scores: A numpy array of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes. + @type scores: np.ndarray + @param ignored_indexes: A list of indexes to ignore during classification generation (e.g., background class, padding class) + @type ignored_indexes: List[int] + @param remove_duplicates: If True, removes consecutive duplicates from the sequence. + @type remove_duplicates: bool + @param concatenate_text: If True, concatenates consecutive words based on the space character. + @type concatenate_text: bool + @return: A Classification message with attributes `classes` and `scores`, where `classes` is a list of class names and `scores` is a list of corresponding scores. + @rtype: Classifications + @raises ValueError: If 'classes' is not a list of strings. + @raises ValueError: If 'scores' is not a 2D array of list of shape (sequence_length, n_classes). + @raises ValueError: If the number of classes does not match the number of columns in 'scores'. + @raises ValueError: If any score is not in the range [0, 1]. + @raises ValueError: If the probabilities in any row of 'scores' do not sum to 1. + @raises ValueError: If 'ignored_indexes' in not None or a list of valid indexes within the range [0, n_classes - 1]. + """ + + if not isinstance(classes, List): + raise ValueError(f"Classes should be a list, got {type(classes)}.") + + if isinstance(scores, List): + scores = np.array(scores) + + if len(scores.shape) != 2: + raise ValueError(f"Scores should be a 2D array, got {scores.shape}.") + + if scores.shape[1] != len(classes): + raise ValueError( + f"Number of classes and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores." 
+ ) + + if np.any(scores < 0) or np.any(scores > 1): + raise ValueError("Scores should be in the range [0, 1].") + + if np.any(~np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)): + raise ValueError("Each row of scores should sum to 1.") + + if ignored_indexes is not None: + if not isinstance(ignored_indexes, List): + raise ValueError( + f"Ignored indexes should be a list, got {type(ignored_indexes)}." + ) + if not all(isinstance(index, int) for index in ignored_indexes): + raise ValueError("Ignored indexes should be integers.") + if np.any(np.array(ignored_indexes) < 0) or np.any( + np.array(ignored_indexes) >= len(classes) + ): + raise ValueError( + "Ignored indexes should be integers in the range [0, num_classes -1]." + ) + + selection = np.ones(len(scores), dtype=bool) + indexes = np.argmax(scores, axis=1) + + if remove_duplicates: + selection[1:] = indexes[1:] != indexes[:-1] + + if ignored_indexes is not None: + selection &= np.array([index not in ignored_indexes for index in indexes]) + + class_list = [classes[i] for i in indexes[selection]] + score_list = np.max(scores, axis=1)[selection] + + if ( + concatenate_text + and len(class_list) > 1 + and all(len(word) <= 1 for word in class_list) + ): + concatenated_scores = [] + concatenated_words = "".join(class_list).split() + cumsumlist = np.cumsum([len(word) for word in concatenated_words]) + + start_index = 0 + for num_spaces, end_index in enumerate(cumsumlist): + word_scores = score_list[start_index + num_spaces : end_index + num_spaces] + concatenated_scores.append(np.mean(word_scores)) + start_index = end_index + + class_list = concatenated_words + score_list = np.array(concatenated_scores) + + elif ( + concatenate_text + and len(class_list) > 1 + and any(len(word) >= 2 for word in class_list) + ): + class_list = [" ".join(class_list)] + score_list = np.mean(score_list) + + classification_msg = Classifications() + + classification_msg.classes = class_list + classification_msg.scores = score_list.tolist() + + return classification_msg diff --git a/depthai_nodes/ml/messages/creators/classification_sequence.py b/depthai_nodes/ml/messages/creators/classification_sequence.py deleted file mode 100644 index 216905f..0000000 --- a/depthai_nodes/ml/messages/creators/classification_sequence.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import List, Union - -import numpy as np - -from .. import Classifications - - -def create_classification_sequence_message( - classes: List[str], - scores: Union[np.ndarray, List], - ignored_indexes: List[int] = None, - remove_duplicates: bool = False, - concatenate_text: bool = False, -) -> Classifications: - """Creates a message for a multi-class sequence. The 'scores' array is a sequence of - probabilities for each class at each position in the sequence. The message contains - the class names and their respective scores, ordered according to the sequence. - - @param classes: A list of class names, with length 'n_classes'. - @type classes: List - @param scores: A numpy array of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes. - @type scores: np.ndarray - @param ignored_indexes: A list of indexes to ignore during classification generation (e.g., background class, padding class) - @type ignored_indexes: List[int] - @param remove_duplicates: If True, removes consecutive duplicates from the sequence. - @type remove_duplicates: bool - @param concatenate_text: If True, concatenates consecutive words based on the space character. 
- @type concatenate_text: bool - @return: A Classification message with attributes `classes` and `scores`, where `classes` is a list of class names and `scores` is a list of corresponding scores. - @rtype: Classifications - @raises ValueError: If 'classes' is not a list of strings. - @raises ValueError: If 'scores' is not a 2D array of list of shape (sequence_length, n_classes). - @raises ValueError: If the number of classes does not match the number of columns in 'scores'. - @raises ValueError: If any score is not in the range [0, 1]. - @raises ValueError: If the probabilities in any row of 'scores' do not sum to 1. - @raises ValueError: If 'ignored_indexes' in not None or a list of valid indexes within the range [0, n_classes - 1]. - """ - - if not isinstance(classes, List): - raise ValueError(f"Classes should be a list, got {type(classes)}.") - - if isinstance(scores, List): - scores = np.array(scores) - - if len(scores.shape) != 2: - raise ValueError(f"Scores should be a 2D array, got {scores.shape}.") - - if scores.shape[1] != len(classes): - raise ValueError( - f"Number of classes and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores." - ) - - if np.any(scores < 0) or np.any(scores > 1): - raise ValueError("Scores should be in the range [0, 1].") - - if np.any(~np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)): - raise ValueError("Each row of scores should sum to 1.") - - if ignored_indexes is not None: - if not isinstance(ignored_indexes, List): - raise ValueError( - f"Ignored indexes should be a list, got {type(ignored_indexes)}." - ) - if not all(isinstance(index, int) for index in ignored_indexes): - raise ValueError("Ignored indexes should be integers.") - if np.any(np.array(ignored_indexes) < 0) or np.any( - np.array(ignored_indexes) >= len(classes) - ): - raise ValueError( - "Ignored indexes should be integers in the range [0, num_classes -1]." 
- ) - - selection = np.ones(len(scores), dtype=bool) - indexes = np.argmax(scores, axis=1) - - if remove_duplicates: - selection[1:] = indexes[1:] != indexes[:-1] - - if ignored_indexes is not None: - selection &= np.array([index not in ignored_indexes for index in indexes]) - - class_list = [classes[i] for i in indexes[selection]] - score_list = np.max(scores, axis=1)[selection] - - if ( - concatenate_text - and len(class_list) > 1 - and all(len(word) <= 1 for word in class_list) - ): - concatenated_scores = [] - concatenated_words = "".join(class_list).split() - cumsumlist = np.cumsum([len(word) for word in concatenated_words]) - - start_index = 0 - for num_spaces, end_index in enumerate(cumsumlist): - word_scores = score_list[start_index + num_spaces : end_index + num_spaces] - concatenated_scores.append(np.mean(word_scores)) - start_index = end_index - - class_list = concatenated_words - score_list = np.array(concatenated_scores) - - elif ( - concatenate_text - and len(class_list) > 1 - and any(len(word) >= 2 for word in class_list) - ): - class_list = [" ".join(class_list)] - score_list = np.mean(score_list) - - classification_msg = Classifications() - - classification_msg.classes = class_list - classification_msg.scores = score_list.tolist() - - return classification_msg diff --git a/depthai_nodes/ml/messages/creators/misc.py b/depthai_nodes/ml/messages/creators/misc.py index b2999eb..7ef091d 100644 --- a/depthai_nodes/ml/messages/creators/misc.py +++ b/depthai_nodes/ml/messages/creators/misc.py @@ -1,12 +1,9 @@ -from typing import List, Union +from typing import List -import numpy as np +from ...messages import Classifications, CompositeMessage -from ...messages import AgeGender, Classifications, CompositeMessage -from ...messages.creators import create_classification_message - -def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender: +def create_age_gender_message(age: float, gender_prob: List[float]) -> CompositeMessage: """Create a DepthAI message for the age and gender probability. @param age: Detected person age (must be multiplied by 100 to get years). @@ -42,61 +39,11 @@ def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender f"Gender_prob list must contain probabilities and sum to 1, got sum {sum(gender_prob)}." ) - age_gender_message = AgeGender() - age_gender_message.age = age gender = Classifications() gender.classes = ["female", "male"] gender.scores = gender_prob - age_gender_message.gender = gender - - return age_gender_message - - -def create_multi_classification_message( - classification_attributes: List[str], - classification_scores: Union[np.ndarray, List[List[float]]], - classification_labels: List[List[str]], -) -> CompositeMessage: - """Create a DepthAI message for multi-classification. - - @param classification_attributes: List of attributes being classified. - @type classification_attributes: List[str] - @param classification_scores: A 2D array or list of classification scores for each - attribute. - @type classification_scores: Union[np.ndarray, List[List[float]]] - @param classification_labels: A 2D list of class labels for each classification - attribute. - @type classification_labels: List[List[str]] - @return: MultiClassification message containing a dictionary of classification - attributes and their respective Classifications. - @rtype: dai.Buffer - @raise ValueError: If number of attributes is not same as number of score-label - pairs. 
- @raise ValueError: If number of scores is not same as number of labels for each - attribute. - @raise ValueError: If each class score not in the range [0, 1]. - @raise ValueError: If each class score not a probability distribution that sums to - 1. - """ - if len(classification_attributes) != len(classification_scores) or len( - classification_attributes - ) != len(classification_labels): - raise ValueError( - f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels." - ) + age_gender_message = CompositeMessage() + age_gender_message.setData({"age": age, "gender": gender}) - multi_class_dict = {} - for attribute, scores, labels in zip( - classification_attributes, classification_scores, classification_labels - ): - if len(scores) != len(labels): - raise ValueError( - f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}." - ) - multi_class_dict[attribute] = create_classification_message(labels, scores) - - multi_classification_message = CompositeMessage() - multi_classification_message.setData(multi_class_dict) - - return multi_classification_message + return age_gender_message diff --git a/depthai_nodes/ml/parsers/__init__.py b/depthai_nodes/ml/parsers/__init__.py index d72b6c9..61982ec 100644 --- a/depthai_nodes/ml/parsers/__init__.py +++ b/depthai_nodes/ml/parsers/__init__.py @@ -1,5 +1,5 @@ from .age_gender import AgeGenderParser -from .classification import ClassificationParser +from .classification import ClassificationParser, MultiClassificationParser from .fastsam import FastSAMParser from .hrnet import HRNetParser from .image_output import ImageOutputParser @@ -14,7 +14,6 @@ from .scrfd import SCRFDParser from .segmentation import SegmentationParser from .superanimal_landmarker import SuperAnimalParser -from .vehicle_attributes import MultiClassificationParser from .xfeat import XFeatParser from .yolo import YOLOExtendedParser from .yunet import YuNetParser diff --git a/depthai_nodes/ml/parsers/classification.py b/depthai_nodes/ml/parsers/classification.py index 6295e7c..7771673 100644 --- a/depthai_nodes/ml/parsers/classification.py +++ b/depthai_nodes/ml/parsers/classification.py @@ -3,7 +3,10 @@ import depthai as dai import numpy as np -from ..messages.creators import create_classification_message +from ..messages.creators import ( + create_classification_message, + create_multi_classification_message, +) class ClassificationParser(dai.node.ThreadedHostNode): @@ -95,3 +98,59 @@ def run(self): msg.setTimestamp(output.getTimestamp()) self.out.send(msg) + + +class MultiClassificationParser(dai.node.ThreadedHostNode): + """Postprocessing logic for Multiple Classification model. + + Attributes + ---------- + input : Node.Input + Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. + out : Node.Output + Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. + classification_attributes : List[str] + List of attributes to be classified. 
+ classification_labels : List[List[str]] + List of class labels for each attribute in `classification_attributes` + + Output Message/s + ---------------- + **Type**: CompositeMessage + + **Description**: A CompositeMessage containing a dictionary of classification attributes as keys and their respective Classifications as values. + """ + + def __init__( + self, + classification_attributes: List[str], + classification_labels: List[List[str]], + ): + """Initializes the MultipleClassificationParser node.""" + dai.node.ThreadedHostNode.__init__(self) + self.out = self.createOutput() + self.input = self.createInput() + self.classification_attributes: List[str] = classification_attributes + self.classification_labels: List[List[str]] = classification_labels + + def run(self): + while self.isRunning(): + try: + output: dai.NNData = self.input.get() + except dai.MessageQueue.QueueException: + break + + layer_names = output.getAllLayerNames() + + scores = [] + for layer_name in layer_names: + scores.append( + output.getTensor(layer_name, dequantize=True).flatten().tolist() + ) + + multi_classification_message = create_multi_classification_message( + self.classification_attributes, scores, self.classification_labels + ) + multi_classification_message.setTimestamp(output.getTimestamp()) + + self.out.send(multi_classification_message) diff --git a/depthai_nodes/ml/parsers/ppdet.py b/depthai_nodes/ml/parsers/ppdet.py index 8a74fcb..23d81cc 100644 --- a/depthai_nodes/ml/parsers/ppdet.py +++ b/depthai_nodes/ml/parsers/ppdet.py @@ -1,7 +1,7 @@ import depthai as dai from ..messages.creators import create_detection_message -from .utils.ppdet import corners2xyxy, parse_paddle_detection_outputs +from .utils import corners2xyxy, parse_paddle_detection_outputs class PPTextDetectionParser(dai.node.ThreadedHostNode): diff --git a/depthai_nodes/ml/parsers/vehicle_attributes.py b/depthai_nodes/ml/parsers/vehicle_attributes.py deleted file mode 100644 index 5e5d1c1..0000000 --- a/depthai_nodes/ml/parsers/vehicle_attributes.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import List - -import depthai as dai - -from ..messages.creators import create_multi_classification_message - - -class MultiClassificationParser(dai.node.ThreadedHostNode): - """Postprocessing logic for Multiple Classification model. - - Attributes - ---------- - input : Node.Input - Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. - out : Node.Output - Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. - classification_attributes : List[str] - List of attributes to be classified. - classification_labels : List[List[str]] - List of class labels for each attribute in `classification_attributes` - - Output Message/s - ---------------- - **Type**: CompositeMessage - - **Description**: A CompositeMessage containing a dictionary of classification attributes as keys and their respective Classifications as values. 
- """ - - def __init__( - self, - classification_attributes: List[str], - classification_labels: List[List[str]], - ): - """Initializes the MultipleClassificationParser node.""" - dai.node.ThreadedHostNode.__init__(self) - self.out = self.createOutput() - self.input = self.createInput() - self.classification_attributes: List[str] = classification_attributes - self.classification_labels: List[List[str]] = classification_labels - - def run(self): - while self.isRunning(): - try: - output: dai.NNData = self.input.get() - except dai.MessageQueue.QueueException: - break - - layer_names = output.getAllLayerNames() - - scores = [] - for layer_name in layer_names: - scores.append( - output.getTensor(layer_name, dequantize=True).flatten().tolist() - ) - - multi_classification_message = create_multi_classification_message( - self.classification_attributes, scores, self.classification_labels - ) - multi_classification_message.setTimestamp(output.getTimestamp()) - - self.out.send(multi_classification_message) diff --git a/tests/unittests/test_creators/test_classification_sequence.py b/tests/unittests/test_creators/test_classification_sequence.py index f4017e1..a1f85e5 100644 --- a/tests/unittests/test_creators/test_classification_sequence.py +++ b/tests/unittests/test_creators/test_classification_sequence.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from depthai_nodes.ml.messages.creators.classification_sequence import ( +from depthai_nodes.ml.messages.creators import ( create_classification_sequence_message, ) diff --git a/tests/unittests/test_creators/test_misc.py b/tests/unittests/test_creators/test_misc.py index dced104..e49c153 100644 --- a/tests/unittests/test_creators/test_misc.py +++ b/tests/unittests/test_creators/test_misc.py @@ -1,8 +1,8 @@ import numpy as np import pytest -from depthai_nodes.ml.messages import AgeGender -from depthai_nodes.ml.messages.creators.misc import create_age_gender_message +from depthai_nodes.ml.messages import CompositeMessage +from depthai_nodes.ml.messages.creators import create_age_gender_message def test_wrong_age(): @@ -45,10 +45,13 @@ def test_correct_types(): gender = [0.35, 0.65] message = create_age_gender_message(age, gender) - assert isinstance(message, AgeGender) - assert message.age == age - assert message.gender.classes == ["female", "male"] - assert np.all(np.isclose(message.gender.scores, gender)) + assert isinstance(message, CompositeMessage) + result = message.getData() + assert "age" in result + assert "gender" in result + assert result["age"] == age + assert result["gender"].classes == ["female", "male"] + assert np.all(np.isclose(result["gender"].scores, gender)) if __name__ == "__main__": diff --git a/tests/unittests/test_creators/test_vehicle_attributes.py b/tests/unittests/test_creators/test_vehicle_attributes.py index 0736382..a341bfc 100644 --- a/tests/unittests/test_creators/test_vehicle_attributes.py +++ b/tests/unittests/test_creators/test_vehicle_attributes.py @@ -1,7 +1,7 @@ import pytest from depthai_nodes.ml.messages import Classifications, CompositeMessage -from depthai_nodes.ml.messages.creators.misc import create_multi_classification_message +from depthai_nodes.ml.messages.creators import create_multi_classification_message def test_incorect_lengths(): From 7b4e309b5bf57b6df06f8892973bf890f42ce133 Mon Sep 17 00:00:00 2001 From: aljazkonec1 Date: Tue, 17 Sep 2024 08:34:41 +0200 Subject: [PATCH 15/16] remove AgeGender Message. 
--- depthai_nodes/ml/messages/__init__.py | 2 -- depthai_nodes/ml/messages/misc.py | 34 --------------------------- 2 files changed, 36 deletions(-) delete mode 100644 depthai_nodes/ml/messages/misc.py diff --git a/depthai_nodes/ml/messages/__init__.py b/depthai_nodes/ml/messages/__init__.py index 25ae687..835a1c4 100644 --- a/depthai_nodes/ml/messages/__init__.py +++ b/depthai_nodes/ml/messages/__init__.py @@ -8,7 +8,6 @@ from .keypoints import HandKeypoints, Keypoints from .lines import Line, Lines from .map import Map2D -from .misc import AgeGender from .segmentation import SegmentationMasks __all__ = [ @@ -20,7 +19,6 @@ "Lines", "Classifications", "SegmentationMasks", - "AgeGender", "Map2D", "Clusters", "Cluster", diff --git a/depthai_nodes/ml/messages/misc.py b/depthai_nodes/ml/messages/misc.py deleted file mode 100644 index c3aaf83..0000000 --- a/depthai_nodes/ml/messages/misc.py +++ /dev/null @@ -1,34 +0,0 @@ -import depthai as dai - -from ..messages import Classifications - - -class AgeGender(dai.Buffer): - def __init__(self): - super().__init__() - self._age: float = None - self._gender = Classifications() - - @property - def age(self) -> float: - return self._age - - @age.setter - def age(self, value: float): - if not isinstance(value, float): - raise TypeError( - f"start_point must be of type float, instead got {type(value)}." - ) - self._age = value - - @property - def gender(self) -> Classifications: - return self._gender - - @gender.setter - def gender(self, value: Classifications): - if not isinstance(value, Classifications): - raise TypeError( - f"gender must be of type Classifications, instead got {type(value)}." - ) - self._gender = value From c50796f007ad832589be5a130d88fdf9473d2a62 Mon Sep 17 00:00:00 2001 From: aljazkonec1 Date: Tue, 17 Sep 2024 09:14:10 +0200 Subject: [PATCH 16/16] Refactor examples. 
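
Update the age-gender example to the CompositeMessage API introduced in
the previous commit. A sketch of the resulting call pattern inside a
visualizer (the annotate helper, label text, and overlay coordinates
are illustrative, not part of this patch; frame is assumed to be a
cv2-compatible numpy array):

    import cv2

    from depthai_nodes.ml.messages import CompositeMessage

    def annotate(frame, message: CompositeMessage):
        data = message.getData()  # {"age": float, "gender": Classifications}
        gender = data["gender"]
        label = f"age: {data['age'] * 100:.0f}; {gender.classes[0]}: {gender.scores[0]:.2f}"
        # Overlay the parsed values in the top-left corner of the frame.
        cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
        return frame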
--- examples/visualization/visualizers/classification.py | 6 ++++-- .../visualization/visualizers/utils/message_parsers.py | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/visualization/visualizers/classification.py b/examples/visualization/visualizers/classification.py index 5a19f2b..65eebb2 100644 --- a/examples/visualization/visualizers/classification.py +++ b/examples/visualization/visualizers/classification.py @@ -1,7 +1,7 @@ import cv2 import depthai as dai -from depthai_nodes.ml.messages import AgeGender, Classifications +from depthai_nodes.ml.messages import Classifications, CompositeMessage from .utils.message_parsers import ( parse_classification_message, @@ -36,7 +36,9 @@ def visualize_classification( return False -def visualize_age_gender(frame: dai.ImgFrame, message: AgeGender, extraParams: dict): +def visualize_age_gender( + frame: dai.ImgFrame, message: CompositeMessage, extraParams: dict +): """Visualizes the age and predicted gender on the frame.""" if frame.shape[0] < 128: frame = cv2.resize(frame, (frame.shape[1] * 2, frame.shape[0] * 2)) diff --git a/examples/visualization/visualizers/utils/message_parsers.py b/examples/visualization/visualizers/utils/message_parsers.py index 4961595..389c3bf 100644 --- a/examples/visualization/visualizers/utils/message_parsers.py +++ b/examples/visualization/visualizers/utils/message_parsers.py @@ -1,9 +1,9 @@ import depthai as dai from depthai_nodes.ml.messages import ( - AgeGender, Classifications, Clusters, + CompositeMessage, ImgDetectionsExtended, Keypoints, Lines, @@ -50,11 +50,11 @@ def parse_image_message(message: dai.ImgFrame): return image -def parser_age_gender_message(message: AgeGender): +def parser_age_gender_message(message: CompositeMessage): """Parses the age-gender message and return the age and scores for all genders.""" - - age = message.age - gender = message.gender + message = message.getData() + age = message["age"] + gender = message["gender"] gender_scores = gender.scores gender_classes = gender.classes