diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e0ed2e8..0dd4f82 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -93,6 +93,17 @@ jobs: - name: Install package run: pip install -e .[dev] + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest pytest-cov + + - name: Check Python version + run: python --version + + - name: Check installed packages + run: pip list + - name: Run pytest uses: pavelzw/pytest-action@v2 with: diff --git a/depthai_nodes/ml/messages/__init__.py b/depthai_nodes/ml/messages/__init__.py index 25ae687..835a1c4 100644 --- a/depthai_nodes/ml/messages/__init__.py +++ b/depthai_nodes/ml/messages/__init__.py @@ -8,7 +8,6 @@ from .keypoints import HandKeypoints, Keypoints from .lines import Line, Lines from .map import Map2D -from .misc import AgeGender from .segmentation import SegmentationMasks __all__ = [ @@ -20,7 +19,6 @@ "Lines", "Classifications", "SegmentationMasks", - "AgeGender", "Map2D", "Clusters", "Cluster", diff --git a/depthai_nodes/ml/messages/creators/__init__.py b/depthai_nodes/ml/messages/creators/__init__.py index 8656ffa..17cdaf0 100644 --- a/depthai_nodes/ml/messages/creators/__init__.py +++ b/depthai_nodes/ml/messages/creators/__init__.py @@ -1,11 +1,14 @@ -from .classification import create_classification_message -from .classification_sequence import create_classification_sequence_message +from .classification import ( + create_classification_message, + create_classification_sequence_message, + create_multi_classification_message, +) from .clusters import create_cluster_message from .detection import create_detection_message, create_line_detection_message from .image import create_image_message from .keypoints import create_hand_keypoints_message, create_keypoints_message from .map import create_map_message -from .misc import create_age_gender_message, create_multi_classification_message +from .misc import create_age_gender_message from .segmentation import create_sam_message, create_segmentation_message from .tracked_features import create_tracked_features_message diff --git a/depthai_nodes/ml/messages/creators/classification.py b/depthai_nodes/ml/messages/creators/classification.py index 42ed263..49b7dd6 100644 --- a/depthai_nodes/ml/messages/creators/classification.py +++ b/depthai_nodes/ml/messages/creators/classification.py @@ -2,7 +2,7 @@ import numpy as np -from ...messages import Classifications +from ...messages import Classifications, CompositeMessage def create_classification_message( @@ -82,3 +82,164 @@ def create_classification_message( classification_msg.scores = scores.tolist() return classification_msg + + +def create_multi_classification_message( + classification_attributes: List[str], + classification_scores: Union[np.ndarray, List[List[float]]], + classification_labels: List[List[str]], +) -> CompositeMessage: + """Create a DepthAI message for multi-classification. + + @param classification_attributes: List of attributes being classified. + @type classification_attributes: List[str] + @param classification_scores: A 2D array or list of classification scores for each + attribute. + @type classification_scores: Union[np.ndarray, List[List[float]]] + @param classification_labels: A 2D list of class labels for each classification + attribute. + @type classification_labels: List[List[str]] + @return: MultiClassification message containing a dictionary of classification + attributes and their respective Classifications. + @rtype: dai.Buffer + @raise ValueError: If number of attributes is not same as number of score-label + pairs. + @raise ValueError: If number of scores is not same as number of labels for each + attribute. + @raise ValueError: If each class score not in the range [0, 1]. + @raise ValueError: If each class score not a probability distribution that sums to + 1. + """ + + if len(classification_attributes) != len(classification_scores) or len( + classification_attributes + ) != len(classification_labels): + raise ValueError( + f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels." + ) + + multi_class_dict = {} + for attribute, scores, labels in zip( + classification_attributes, classification_scores, classification_labels + ): + if len(scores) != len(labels): + raise ValueError( + f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}." + ) + multi_class_dict[attribute] = create_classification_message(labels, scores) + + multi_classification_message = CompositeMessage() + multi_classification_message.setData(multi_class_dict) + + return multi_classification_message + + +def create_classification_sequence_message( + classes: List[str], + scores: Union[np.ndarray, List], + ignored_indexes: List[int] = None, + remove_duplicates: bool = False, + concatenate_text: bool = False, +) -> Classifications: + """Creates a message for a multi-class sequence. The 'scores' array is a sequence of + probabilities for each class at each position in the sequence. The message contains + the class names and their respective scores, ordered according to the sequence. + + @param classes: A list of class names, with length 'n_classes'. + @type classes: List + @param scores: A numpy array of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes. + @type scores: np.ndarray + @param ignored_indexes: A list of indexes to ignore during classification generation (e.g., background class, padding class) + @type ignored_indexes: List[int] + @param remove_duplicates: If True, removes consecutive duplicates from the sequence. + @type remove_duplicates: bool + @param concatenate_text: If True, concatenates consecutive words based on the space character. + @type concatenate_text: bool + @return: A Classification message with attributes `classes` and `scores`, where `classes` is a list of class names and `scores` is a list of corresponding scores. + @rtype: Classifications + @raises ValueError: If 'classes' is not a list of strings. + @raises ValueError: If 'scores' is not a 2D array of list of shape (sequence_length, n_classes). + @raises ValueError: If the number of classes does not match the number of columns in 'scores'. + @raises ValueError: If any score is not in the range [0, 1]. + @raises ValueError: If the probabilities in any row of 'scores' do not sum to 1. + @raises ValueError: If 'ignored_indexes' in not None or a list of valid indexes within the range [0, n_classes - 1]. + """ + + if not isinstance(classes, List): + raise ValueError(f"Classes should be a list, got {type(classes)}.") + + if isinstance(scores, List): + scores = np.array(scores) + + if len(scores.shape) != 2: + raise ValueError(f"Scores should be a 2D array, got {scores.shape}.") + + if scores.shape[1] != len(classes): + raise ValueError( + f"Number of classes and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores." + ) + + if np.any(scores < 0) or np.any(scores > 1): + raise ValueError("Scores should be in the range [0, 1].") + + if np.any(~np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)): + raise ValueError("Each row of scores should sum to 1.") + + if ignored_indexes is not None: + if not isinstance(ignored_indexes, List): + raise ValueError( + f"Ignored indexes should be a list, got {type(ignored_indexes)}." + ) + if not all(isinstance(index, int) for index in ignored_indexes): + raise ValueError("Ignored indexes should be integers.") + if np.any(np.array(ignored_indexes) < 0) or np.any( + np.array(ignored_indexes) >= len(classes) + ): + raise ValueError( + "Ignored indexes should be integers in the range [0, num_classes -1]." + ) + + selection = np.ones(len(scores), dtype=bool) + indexes = np.argmax(scores, axis=1) + + if remove_duplicates: + selection[1:] = indexes[1:] != indexes[:-1] + + if ignored_indexes is not None: + selection &= np.array([index not in ignored_indexes for index in indexes]) + + class_list = [classes[i] for i in indexes[selection]] + score_list = np.max(scores, axis=1)[selection] + + if ( + concatenate_text + and len(class_list) > 1 + and all(len(word) <= 1 for word in class_list) + ): + concatenated_scores = [] + concatenated_words = "".join(class_list).split() + cumsumlist = np.cumsum([len(word) for word in concatenated_words]) + + start_index = 0 + for num_spaces, end_index in enumerate(cumsumlist): + word_scores = score_list[start_index + num_spaces : end_index + num_spaces] + concatenated_scores.append(np.mean(word_scores)) + start_index = end_index + + class_list = concatenated_words + score_list = np.array(concatenated_scores) + + elif ( + concatenate_text + and len(class_list) > 1 + and any(len(word) >= 2 for word in class_list) + ): + class_list = [" ".join(class_list)] + score_list = np.mean(score_list) + + classification_msg = Classifications() + + classification_msg.classes = class_list + classification_msg.scores = score_list.tolist() + + return classification_msg diff --git a/depthai_nodes/ml/messages/creators/classification_sequence.py b/depthai_nodes/ml/messages/creators/classification_sequence.py deleted file mode 100644 index 216905f..0000000 --- a/depthai_nodes/ml/messages/creators/classification_sequence.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import List, Union - -import numpy as np - -from .. import Classifications - - -def create_classification_sequence_message( - classes: List[str], - scores: Union[np.ndarray, List], - ignored_indexes: List[int] = None, - remove_duplicates: bool = False, - concatenate_text: bool = False, -) -> Classifications: - """Creates a message for a multi-class sequence. The 'scores' array is a sequence of - probabilities for each class at each position in the sequence. The message contains - the class names and their respective scores, ordered according to the sequence. - - @param classes: A list of class names, with length 'n_classes'. - @type classes: List - @param scores: A numpy array of shape (sequence_length, n_classes) containing the (row-wise) probability distributions over the classes. - @type scores: np.ndarray - @param ignored_indexes: A list of indexes to ignore during classification generation (e.g., background class, padding class) - @type ignored_indexes: List[int] - @param remove_duplicates: If True, removes consecutive duplicates from the sequence. - @type remove_duplicates: bool - @param concatenate_text: If True, concatenates consecutive words based on the space character. - @type concatenate_text: bool - @return: A Classification message with attributes `classes` and `scores`, where `classes` is a list of class names and `scores` is a list of corresponding scores. - @rtype: Classifications - @raises ValueError: If 'classes' is not a list of strings. - @raises ValueError: If 'scores' is not a 2D array of list of shape (sequence_length, n_classes). - @raises ValueError: If the number of classes does not match the number of columns in 'scores'. - @raises ValueError: If any score is not in the range [0, 1]. - @raises ValueError: If the probabilities in any row of 'scores' do not sum to 1. - @raises ValueError: If 'ignored_indexes' in not None or a list of valid indexes within the range [0, n_classes - 1]. - """ - - if not isinstance(classes, List): - raise ValueError(f"Classes should be a list, got {type(classes)}.") - - if isinstance(scores, List): - scores = np.array(scores) - - if len(scores.shape) != 2: - raise ValueError(f"Scores should be a 2D array, got {scores.shape}.") - - if scores.shape[1] != len(classes): - raise ValueError( - f"Number of classes and scores mismatch. Provided {len(classes)} class names and {scores.shape[1]} scores." - ) - - if np.any(scores < 0) or np.any(scores > 1): - raise ValueError("Scores should be in the range [0, 1].") - - if np.any(~np.isclose(scores.sum(axis=1), 1.0, atol=1e-2)): - raise ValueError("Each row of scores should sum to 1.") - - if ignored_indexes is not None: - if not isinstance(ignored_indexes, List): - raise ValueError( - f"Ignored indexes should be a list, got {type(ignored_indexes)}." - ) - if not all(isinstance(index, int) for index in ignored_indexes): - raise ValueError("Ignored indexes should be integers.") - if np.any(np.array(ignored_indexes) < 0) or np.any( - np.array(ignored_indexes) >= len(classes) - ): - raise ValueError( - "Ignored indexes should be integers in the range [0, num_classes -1]." - ) - - selection = np.ones(len(scores), dtype=bool) - indexes = np.argmax(scores, axis=1) - - if remove_duplicates: - selection[1:] = indexes[1:] != indexes[:-1] - - if ignored_indexes is not None: - selection &= np.array([index not in ignored_indexes for index in indexes]) - - class_list = [classes[i] for i in indexes[selection]] - score_list = np.max(scores, axis=1)[selection] - - if ( - concatenate_text - and len(class_list) > 1 - and all(len(word) <= 1 for word in class_list) - ): - concatenated_scores = [] - concatenated_words = "".join(class_list).split() - cumsumlist = np.cumsum([len(word) for word in concatenated_words]) - - start_index = 0 - for num_spaces, end_index in enumerate(cumsumlist): - word_scores = score_list[start_index + num_spaces : end_index + num_spaces] - concatenated_scores.append(np.mean(word_scores)) - start_index = end_index - - class_list = concatenated_words - score_list = np.array(concatenated_scores) - - elif ( - concatenate_text - and len(class_list) > 1 - and any(len(word) >= 2 for word in class_list) - ): - class_list = [" ".join(class_list)] - score_list = np.mean(score_list) - - classification_msg = Classifications() - - classification_msg.classes = class_list - classification_msg.scores = score_list.tolist() - - return classification_msg diff --git a/depthai_nodes/ml/messages/creators/misc.py b/depthai_nodes/ml/messages/creators/misc.py index b2999eb..7ef091d 100644 --- a/depthai_nodes/ml/messages/creators/misc.py +++ b/depthai_nodes/ml/messages/creators/misc.py @@ -1,12 +1,9 @@ -from typing import List, Union +from typing import List -import numpy as np +from ...messages import Classifications, CompositeMessage -from ...messages import AgeGender, Classifications, CompositeMessage -from ...messages.creators import create_classification_message - -def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender: +def create_age_gender_message(age: float, gender_prob: List[float]) -> CompositeMessage: """Create a DepthAI message for the age and gender probability. @param age: Detected person age (must be multiplied by 100 to get years). @@ -42,61 +39,11 @@ def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender f"Gender_prob list must contain probabilities and sum to 1, got sum {sum(gender_prob)}." ) - age_gender_message = AgeGender() - age_gender_message.age = age gender = Classifications() gender.classes = ["female", "male"] gender.scores = gender_prob - age_gender_message.gender = gender - - return age_gender_message - - -def create_multi_classification_message( - classification_attributes: List[str], - classification_scores: Union[np.ndarray, List[List[float]]], - classification_labels: List[List[str]], -) -> CompositeMessage: - """Create a DepthAI message for multi-classification. - - @param classification_attributes: List of attributes being classified. - @type classification_attributes: List[str] - @param classification_scores: A 2D array or list of classification scores for each - attribute. - @type classification_scores: Union[np.ndarray, List[List[float]]] - @param classification_labels: A 2D list of class labels for each classification - attribute. - @type classification_labels: List[List[str]] - @return: MultiClassification message containing a dictionary of classification - attributes and their respective Classifications. - @rtype: dai.Buffer - @raise ValueError: If number of attributes is not same as number of score-label - pairs. - @raise ValueError: If number of scores is not same as number of labels for each - attribute. - @raise ValueError: If each class score not in the range [0, 1]. - @raise ValueError: If each class score not a probability distribution that sums to - 1. - """ - if len(classification_attributes) != len(classification_scores) or len( - classification_attributes - ) != len(classification_labels): - raise ValueError( - f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels." - ) + age_gender_message = CompositeMessage() + age_gender_message.setData({"age": age, "gender": gender}) - multi_class_dict = {} - for attribute, scores, labels in zip( - classification_attributes, classification_scores, classification_labels - ): - if len(scores) != len(labels): - raise ValueError( - f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}." - ) - multi_class_dict[attribute] = create_classification_message(labels, scores) - - multi_classification_message = CompositeMessage() - multi_classification_message.setData(multi_class_dict) - - return multi_classification_message + return age_gender_message diff --git a/depthai_nodes/ml/messages/misc.py b/depthai_nodes/ml/messages/misc.py deleted file mode 100644 index c3aaf83..0000000 --- a/depthai_nodes/ml/messages/misc.py +++ /dev/null @@ -1,34 +0,0 @@ -import depthai as dai - -from ..messages import Classifications - - -class AgeGender(dai.Buffer): - def __init__(self): - super().__init__() - self._age: float = None - self._gender = Classifications() - - @property - def age(self) -> float: - return self._age - - @age.setter - def age(self, value: float): - if not isinstance(value, float): - raise TypeError( - f"start_point must be of type float, instead got {type(value)}." - ) - self._age = value - - @property - def gender(self) -> Classifications: - return self._gender - - @gender.setter - def gender(self, value: Classifications): - if not isinstance(value, Classifications): - raise TypeError( - f"gender must be of type Classifications, instead got {type(value)}." - ) - self._gender = value diff --git a/depthai_nodes/ml/parsers/__init__.py b/depthai_nodes/ml/parsers/__init__.py index d72b6c9..61982ec 100644 --- a/depthai_nodes/ml/parsers/__init__.py +++ b/depthai_nodes/ml/parsers/__init__.py @@ -1,5 +1,5 @@ from .age_gender import AgeGenderParser -from .classification import ClassificationParser +from .classification import ClassificationParser, MultiClassificationParser from .fastsam import FastSAMParser from .hrnet import HRNetParser from .image_output import ImageOutputParser @@ -14,7 +14,6 @@ from .scrfd import SCRFDParser from .segmentation import SegmentationParser from .superanimal_landmarker import SuperAnimalParser -from .vehicle_attributes import MultiClassificationParser from .xfeat import XFeatParser from .yolo import YOLOExtendedParser from .yunet import YuNetParser diff --git a/depthai_nodes/ml/parsers/classification.py b/depthai_nodes/ml/parsers/classification.py index 6295e7c..7771673 100644 --- a/depthai_nodes/ml/parsers/classification.py +++ b/depthai_nodes/ml/parsers/classification.py @@ -3,7 +3,10 @@ import depthai as dai import numpy as np -from ..messages.creators import create_classification_message +from ..messages.creators import ( + create_classification_message, + create_multi_classification_message, +) class ClassificationParser(dai.node.ThreadedHostNode): @@ -95,3 +98,59 @@ def run(self): msg.setTimestamp(output.getTimestamp()) self.out.send(msg) + + +class MultiClassificationParser(dai.node.ThreadedHostNode): + """Postprocessing logic for Multiple Classification model. + + Attributes + ---------- + input : Node.Input + Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. + out : Node.Output + Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. + classification_attributes : List[str] + List of attributes to be classified. + classification_labels : List[List[str]] + List of class labels for each attribute in `classification_attributes` + + Output Message/s + ---------------- + **Type**: CompositeMessage + + **Description**: A CompositeMessage containing a dictionary of classification attributes as keys and their respective Classifications as values. + """ + + def __init__( + self, + classification_attributes: List[str], + classification_labels: List[List[str]], + ): + """Initializes the MultipleClassificationParser node.""" + dai.node.ThreadedHostNode.__init__(self) + self.out = self.createOutput() + self.input = self.createInput() + self.classification_attributes: List[str] = classification_attributes + self.classification_labels: List[List[str]] = classification_labels + + def run(self): + while self.isRunning(): + try: + output: dai.NNData = self.input.get() + except dai.MessageQueue.QueueException: + break + + layer_names = output.getAllLayerNames() + + scores = [] + for layer_name in layer_names: + scores.append( + output.getTensor(layer_name, dequantize=True).flatten().tolist() + ) + + multi_classification_message = create_multi_classification_message( + self.classification_attributes, scores, self.classification_labels + ) + multi_classification_message.setTimestamp(output.getTimestamp()) + + self.out.send(multi_classification_message) diff --git a/depthai_nodes/ml/parsers/ppdet.py b/depthai_nodes/ml/parsers/ppdet.py index 8a74fcb..23d81cc 100644 --- a/depthai_nodes/ml/parsers/ppdet.py +++ b/depthai_nodes/ml/parsers/ppdet.py @@ -1,7 +1,7 @@ import depthai as dai from ..messages.creators import create_detection_message -from .utils.ppdet import corners2xyxy, parse_paddle_detection_outputs +from .utils import corners2xyxy, parse_paddle_detection_outputs class PPTextDetectionParser(dai.node.ThreadedHostNode): diff --git a/depthai_nodes/ml/parsers/vehicle_attributes.py b/depthai_nodes/ml/parsers/vehicle_attributes.py deleted file mode 100644 index 5e5d1c1..0000000 --- a/depthai_nodes/ml/parsers/vehicle_attributes.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import List - -import depthai as dai - -from ..messages.creators import create_multi_classification_message - - -class MultiClassificationParser(dai.node.ThreadedHostNode): - """Postprocessing logic for Multiple Classification model. - - Attributes - ---------- - input : Node.Input - Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. - out : Node.Output - Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved. - classification_attributes : List[str] - List of attributes to be classified. - classification_labels : List[List[str]] - List of class labels for each attribute in `classification_attributes` - - Output Message/s - ---------------- - **Type**: CompositeMessage - - **Description**: A CompositeMessage containing a dictionary of classification attributes as keys and their respective Classifications as values. - """ - - def __init__( - self, - classification_attributes: List[str], - classification_labels: List[List[str]], - ): - """Initializes the MultipleClassificationParser node.""" - dai.node.ThreadedHostNode.__init__(self) - self.out = self.createOutput() - self.input = self.createInput() - self.classification_attributes: List[str] = classification_attributes - self.classification_labels: List[List[str]] = classification_labels - - def run(self): - while self.isRunning(): - try: - output: dai.NNData = self.input.get() - except dai.MessageQueue.QueueException: - break - - layer_names = output.getAllLayerNames() - - scores = [] - for layer_name in layer_names: - scores.append( - output.getTensor(layer_name, dequantize=True).flatten().tolist() - ) - - multi_classification_message = create_multi_classification_message( - self.classification_attributes, scores, self.classification_labels - ) - multi_classification_message.setTimestamp(output.getTimestamp()) - - self.out.send(multi_classification_message) diff --git a/examples/README.md b/examples/README.md index 8c284c7..49d6027 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,8 +1,18 @@ # DepthAI Nodes examples -The `main.py` script lets you run fully-automated pipeline with the model of your choice. To run the script you need the model slug and then move to `examples` folder and run: +The `main.py` script lets you run fully-automated pipeline with the model of your choice. To run the script: + +Make sure you have the `depthai-nodes` package installed: + +``` +cd depthai-nodes +pip install -e . +``` + +Prepare the model slug and run: ``` +cd examples python main.py -s ``` diff --git a/examples/main.py b/examples/main.py index 9954321..1b921b4 100644 --- a/examples/main.py +++ b/examples/main.py @@ -1,5 +1,5 @@ import depthai as dai -from utils.arguments import initialize_argparser, parse_model_slug +from utils.arguments import initialize_argparser, parse_fps_limit, parse_model_slug from utils.model import get_input_shape, get_model_from_hub, get_parser from utils.parser import setup_parser from visualization.visualize import visualize @@ -9,6 +9,7 @@ # Parse the model slug model_slug, model_version_slug = parse_model_slug(args) +fps_limit = parse_fps_limit(args) # Get the model from the HubAI nn_archive = get_model_from_hub(model_slug, model_version_slug) @@ -27,7 +28,10 @@ # YOLO and MobileNet-SSD have native parsers in DAI - no need to create a separate parser if parser_name == "YOLO" or parser_name == "SSD": network = pipeline.create(dai.node.DetectionNetwork).build( - cam.requestOutput(input_shape, type=dai.ImgFrame.Type.BGR888p), nn_archive + cam.requestOutput( + input_shape, type=dai.ImgFrame.Type.BGR888p, fps=fps_limit + ), + nn_archive, ) parser_queue = network.out.createOutputQueue() else: @@ -45,13 +49,16 @@ manip = pipeline.create(dai.node.ImageManip) manip.initialConfig.setResize(input_shape) large_input_shape = (input_shape[0] * 4, input_shape[1] * 4) - cam.requestOutput(large_input_shape, type=image_type).link(manip.inputImage) + cam.requestOutput(large_input_shape, type=image_type, fps=fps_limit).link( + manip.inputImage + ) network = pipeline.create(dai.node.NeuralNetwork).build( manip.out, nn_archive ) else: network = pipeline.create(dai.node.NeuralNetwork).build( - cam.requestOutput(input_shape, type=image_type), nn_archive + cam.requestOutput(input_shape, type=image_type, fps=fps_limit), + nn_archive, ) parser = pipeline.create(parser_class) diff --git a/examples/utils/arguments.py b/examples/utils/arguments.py index 3551bf5..3b570e8 100644 --- a/examples/utils/arguments.py +++ b/examples/utils/arguments.py @@ -8,7 +8,7 @@ def initialize_argparser(): parser.description = "General example script to run any model available in HubAI on DepthAI device. \ All you need is a model slug of the model and the script will download the model from HubAI and create \ the whole pipeline with visualizations. You also need a DepthAI device connected to your computer. \ - Currently, only RVC2 devices are supported." + Currently, only RVC2 devices are supported. If using OAK-D Lite, please set the FPS limit to 28." parser.add_argument( "-s", @@ -18,6 +18,15 @@ def initialize_argparser(): type=str, ) + parser.add_argument( + "-l", + "--fps_limit", + help="FPS limit for the model runtime.", + required=False, + default=30.0, + type=float, + ) + args = parser.parse_args() return parser, args @@ -41,3 +50,13 @@ def parse_model_slug(args: argparse.Namespace) -> Tuple[str, str]: model_version_slug = model_slug_parts[1] return model_slug, model_version_slug + + +def parse_fps_limit(args: argparse.Namespace) -> int: + """Parse the FPS limit from the arguments. + + Returns the FPS limit. + """ + fps_limit = args.fps_limit + + return fps_limit diff --git a/examples/utils/parser.py b/examples/utils/parser.py index 0f56e8b..8dee3c0 100644 --- a/examples/utils/parser.py +++ b/examples/utils/parser.py @@ -5,6 +5,7 @@ FastSAMParser, KeypointParser, LaneDetectionParser, + MapOutputParser, MPPalmDetectionParser, SCRFDParser, SegmentationParser, @@ -63,6 +64,18 @@ def setup_classification_parser(parser: ClassificationParser, params: dict): ) +def setup_map_output_parser(parser: MapOutputParser, params: dict): + """Setup the map output parser with the required metadata.""" + try: + min_max_scaling = params["min_max_scaling"] + if min_max_scaling: + parser.setMinMaxScaling(True) + except Exception: + print( + "This NN archive does not have required metadata for MapOutputParser. Skipping setup..." + ) + + def setup_xfeat_parser(parser: XFeatParser, params: dict): """Setup the XFeat parser with the required metadata.""" try: @@ -144,6 +157,8 @@ def setup_parser(parser: dai.ThreadedNode, nn_archive: dai.NNArchive, parser_nam setup_keypoint_parser(parser, extraParams) elif parser_name == "ClassificationParser": setup_classification_parser(parser, extraParams) + elif parser_name == "MapOutputParser": + setup_map_output_parser(parser, extraParams) elif parser_name == "XFeatParser": setup_xfeat_parser(parser, extraParams) elif parser_name == "YOLOExtendedParser": diff --git a/examples/visualization/mapping.py b/examples/visualization/mapping.py deleted file mode 100644 index eb460d0..0000000 --- a/examples/visualization/mapping.py +++ /dev/null @@ -1,31 +0,0 @@ -from .classification import visualize_age_gender, visualize_classification -from .detection import ( - visualize_detections, - visualize_lane_detections, - visualize_line_detections, - visualize_yolo_extended, -) -from .image import visualize_image -from .keypoints import visualize_keypoints -from .segmentation import visualize_fastsam, visualize_segmentation - -parser_mapping = { - "YuNetParser": visualize_detections, - "SCRFDParser": visualize_detections, - "MPPalmDetectionParser": visualize_detections, - "YOLO": visualize_detections, - "SSD": visualize_detections, - "SegmentationParser": visualize_segmentation, - "MLSDParser": visualize_line_detections, - "KeypointParser": visualize_keypoints, - "HRNetParser": visualize_keypoints, - "SuperAnimalParser": visualize_keypoints, - "MPHandLandmarkParser": visualize_keypoints, - "ClassificationParser": visualize_classification, - "ImageOutputParser": visualize_image, - "MonocularDepthParser": visualize_image, - "AgeGenderParser": visualize_age_gender, - "YOLOExtendedParser": visualize_yolo_extended, - "LaneDetectionParser": visualize_lane_detections, - "FastSAMParser": visualize_fastsam, -} diff --git a/examples/visualization/visualize.py b/examples/visualization/visualize.py index c4ce279..9e3aa49 100644 --- a/examples/visualization/visualize.py +++ b/examples/visualization/visualize.py @@ -1,6 +1,39 @@ import depthai as dai -from .mapping import parser_mapping +from .visualizers import ( + visualize_age_gender, + visualize_classification, + visualize_detections, + visualize_fastsam, + visualize_image, + visualize_keypoints, + visualize_lane_detections, + visualize_line_detections, + visualize_map, + visualize_segmentation, + visualize_yolo_extended, +) + +visualizers = { + "YuNetParser": visualize_detections, + "SCRFDParser": visualize_detections, + "MPPalmDetectionParser": visualize_detections, + "YOLO": visualize_detections, + "SSD": visualize_detections, + "SegmentationParser": visualize_segmentation, + "MLSDParser": visualize_line_detections, + "KeypointParser": visualize_keypoints, + "HRNetParser": visualize_keypoints, + "SuperAnimalParser": visualize_keypoints, + "MPHandLandmarkParser": visualize_keypoints, + "ClassificationParser": visualize_classification, + "ImageOutputParser": visualize_image, + "MapOutputParser": visualize_map, + "AgeGenderParser": visualize_age_gender, + "YOLOExtendedParser": visualize_yolo_extended, + "LaneDetectionParser": visualize_lane_detections, + "FastSAMParser": visualize_fastsam, +} def visualize( @@ -8,5 +41,5 @@ def visualize( ): """Calls the appropriate visualizer based on the parser name and returns True if the pipeline should be stopped.""" - visualizer = parser_mapping[parser_name] + visualizer = visualizers[parser_name] return visualizer(frame, message, extraParams) diff --git a/examples/visualization/visualizers/__init__.py b/examples/visualization/visualizers/__init__.py new file mode 100644 index 0000000..ed13ebd --- /dev/null +++ b/examples/visualization/visualizers/__init__.py @@ -0,0 +1,25 @@ +from .classification import visualize_age_gender, visualize_classification +from .detection import ( + visualize_detections, + visualize_lane_detections, + visualize_line_detections, + visualize_yolo_extended, +) +from .image import visualize_image +from .keypoints import visualize_keypoints +from .map import visualize_map +from .segmentation import visualize_fastsam, visualize_segmentation + +__all__ = [ + "visualize_image", + "visualize_segmentation", + "visualize_keypoints", + "visualize_classification", + "visualize_map", + "visualize_age_gender", + "visualize_yolo_extended", + "visualize_detections", + "visualize_line_detections", + "visualize_lane_detections", + "visualize_fastsam", +] diff --git a/examples/visualization/classification.py b/examples/visualization/visualizers/classification.py similarity index 84% rename from examples/visualization/classification.py rename to examples/visualization/visualizers/classification.py index 90c3bb0..65eebb2 100644 --- a/examples/visualization/classification.py +++ b/examples/visualization/visualizers/classification.py @@ -1,9 +1,12 @@ import cv2 import depthai as dai -from depthai_nodes.ml.messages import AgeGender, Classifications +from depthai_nodes.ml.messages import Classifications, CompositeMessage -from .messages import parse_classification_message, parser_age_gender_message +from .utils.message_parsers import ( + parse_classification_message, + parser_age_gender_message, +) def visualize_classification( @@ -33,7 +36,9 @@ def visualize_classification( return False -def visualize_age_gender(frame: dai.ImgFrame, message: AgeGender, extraParams: dict): +def visualize_age_gender( + frame: dai.ImgFrame, message: CompositeMessage, extraParams: dict +): """Visualizes the age and predicted gender on the frame.""" if frame.shape[0] < 128: frame = cv2.resize(frame, (frame.shape[1] * 2, frame.shape[0] * 2)) diff --git a/examples/visualization/detection.py b/examples/visualization/visualizers/detection.py similarity index 98% rename from examples/visualization/detection.py rename to examples/visualization/visualizers/detection.py index 1af7db2..ae60115 100644 --- a/examples/visualization/detection.py +++ b/examples/visualization/visualizers/detection.py @@ -4,8 +4,8 @@ from depthai_nodes.ml.messages import Clusters, ImgDetectionsExtended, Lines -from .colors import get_yolo_colors -from .messages import ( +from .utils.colors import get_yolo_colors +from .utils.message_parsers import ( parse_cluster_message, parse_detection_message, parse_line_detection_message, diff --git a/examples/visualization/image.py b/examples/visualization/visualizers/image.py similarity index 86% rename from examples/visualization/image.py rename to examples/visualization/visualizers/image.py index 19255cb..3beb7fa 100644 --- a/examples/visualization/image.py +++ b/examples/visualization/visualizers/image.py @@ -1,7 +1,7 @@ import cv2 import depthai as dai -from .messages import parse_image_message +from .utils.message_parsers import parse_image_message def visualize_image(frame: dai.ImgFrame, message: dai.ImgFrame, extraParams: dict): diff --git a/examples/visualization/keypoints.py b/examples/visualization/visualizers/keypoints.py similarity index 90% rename from examples/visualization/keypoints.py rename to examples/visualization/visualizers/keypoints.py index a587c2e..ea65c8e 100644 --- a/examples/visualization/keypoints.py +++ b/examples/visualization/visualizers/keypoints.py @@ -3,7 +3,7 @@ from depthai_nodes.ml.messages import Keypoints -from .messages import parse_keypoints_message +from .utils.message_parsers import parse_keypoints_message def visualize_keypoints(frame: dai.ImgFrame, message: Keypoints, extraParams: dict): diff --git a/examples/visualization/visualizers/map.py b/examples/visualization/visualizers/map.py new file mode 100644 index 0000000..134c15a --- /dev/null +++ b/examples/visualization/visualizers/map.py @@ -0,0 +1,32 @@ +import cv2 +import depthai as dai +import numpy as np + +from depthai_nodes.ml.messages import Map2D + +from .utils.message_parsers import parse_map_message + + +def visualize_map(frame: dai.ImgFrame, message: Map2D, extraParams: dict): + """Visualizes the map on the frame.""" + + map = parse_map_message(message) + + # make color representation of the map + map_normalized = cv2.normalize(map, None, 0, 255, cv2.NORM_MINMAX) + map_normalized = map_normalized.astype(np.uint8) + colored_map = cv2.applyColorMap(map_normalized, cv2.COLORMAP_INFERNO) + frame_height, frame_width, _ = frame.shape + colored_map = cv2.resize( + colored_map, (frame_width, frame_height), interpolation=cv2.INTER_LINEAR + ) + + alpha = 0.6 + overlay = cv2.addWeighted(colored_map, alpha, frame, 1 - alpha, 0) + + cv2.imshow("Map Overlay", overlay) + if cv2.waitKey(1) == ord("q"): + cv2.destroyAllWindows() + return True + + return False diff --git a/examples/visualization/segmentation.py b/examples/visualization/visualizers/segmentation.py similarity index 93% rename from examples/visualization/segmentation.py rename to examples/visualization/visualizers/segmentation.py index 4ebbd82..e61514a 100644 --- a/examples/visualization/segmentation.py +++ b/examples/visualization/visualizers/segmentation.py @@ -4,8 +4,8 @@ from depthai_nodes.ml.messages import SegmentationMasks -from .colors import get_adas_colors, get_ewasr_colors, get_selfie_colors -from .messages import parse_fast_sam_message, parse_segmentation_message +from .utils.colors import get_adas_colors, get_ewasr_colors, get_selfie_colors +from .utils.message_parsers import parse_fast_sam_message, parse_segmentation_message def visualize_segmentation( diff --git a/examples/visualization/visualizers/utils/__init__.py b/examples/visualization/visualizers/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/visualization/colors.py b/examples/visualization/visualizers/utils/colors.py similarity index 100% rename from examples/visualization/colors.py rename to examples/visualization/visualizers/utils/colors.py diff --git a/examples/visualization/messages.py b/examples/visualization/visualizers/utils/message_parsers.py similarity index 86% rename from examples/visualization/messages.py rename to examples/visualization/visualizers/utils/message_parsers.py index facd212..389c3bf 100644 --- a/examples/visualization/messages.py +++ b/examples/visualization/visualizers/utils/message_parsers.py @@ -1,12 +1,13 @@ import depthai as dai from depthai_nodes.ml.messages import ( - AgeGender, Classifications, Clusters, + CompositeMessage, ImgDetectionsExtended, Keypoints, Lines, + Map2D, SegmentationMasks, ) @@ -49,11 +50,11 @@ def parse_image_message(message: dai.ImgFrame): return image -def parser_age_gender_message(message: AgeGender): +def parser_age_gender_message(message: CompositeMessage): """Parses the age-gender message and return the age and scores for all genders.""" - - age = message.age - gender = message.gender + message = message.getData() + age = message["age"] + gender = message["gender"] gender_scores = gender.scores gender_classes = gender.classes @@ -76,3 +77,9 @@ def parse_fast_sam_message(message: SegmentationMasks): """Parses the fast sam message and returns the masks.""" masks = message.masks return masks + + +def parse_map_message(message: Map2D): + """Parses the map message and returns the map.""" + map = message.map + return map diff --git a/tests/unittests/test_creators/test_classification_sequence.py b/tests/unittests/test_creators/test_classification_sequence.py index f4017e1..a1f85e5 100644 --- a/tests/unittests/test_creators/test_classification_sequence.py +++ b/tests/unittests/test_creators/test_classification_sequence.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from depthai_nodes.ml.messages.creators.classification_sequence import ( +from depthai_nodes.ml.messages.creators import ( create_classification_sequence_message, ) diff --git a/tests/unittests/test_creators/test_misc.py b/tests/unittests/test_creators/test_misc.py index dced104..e49c153 100644 --- a/tests/unittests/test_creators/test_misc.py +++ b/tests/unittests/test_creators/test_misc.py @@ -1,8 +1,8 @@ import numpy as np import pytest -from depthai_nodes.ml.messages import AgeGender -from depthai_nodes.ml.messages.creators.misc import create_age_gender_message +from depthai_nodes.ml.messages import CompositeMessage +from depthai_nodes.ml.messages.creators import create_age_gender_message def test_wrong_age(): @@ -45,10 +45,13 @@ def test_correct_types(): gender = [0.35, 0.65] message = create_age_gender_message(age, gender) - assert isinstance(message, AgeGender) - assert message.age == age - assert message.gender.classes == ["female", "male"] - assert np.all(np.isclose(message.gender.scores, gender)) + assert isinstance(message, CompositeMessage) + result = message.getData() + assert "age" in result + assert "gender" in result + assert result["age"] == age + assert result["gender"].classes == ["female", "male"] + assert np.all(np.isclose(result["gender"].scores, gender)) if __name__ == "__main__": diff --git a/tests/unittests/test_creators/test_vehicle_attributes.py b/tests/unittests/test_creators/test_vehicle_attributes.py index 0736382..a341bfc 100644 --- a/tests/unittests/test_creators/test_vehicle_attributes.py +++ b/tests/unittests/test_creators/test_vehicle_attributes.py @@ -1,7 +1,7 @@ import pytest from depthai_nodes.ml.messages import Classifications, CompositeMessage -from depthai_nodes.ml.messages.creators.misc import create_multi_classification_message +from depthai_nodes.ml.messages.creators import create_multi_classification_message def test_incorect_lengths():