Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vehicle attributes parser #67

Merged
merged 8 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .classification import Classifications
from .clusters import Cluster, Clusters
from .composite import CompositeMessage
from .img_detections import (
ImgDetectionExtended,
ImgDetectionsExtended,
Expand All @@ -23,4 +24,5 @@
"Map2D",
"Clusters",
"Cluster",
"CompositeMessage",
]
23 changes: 23 additions & 0 deletions depthai_nodes/ml/messages/composite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Dict, List, Union

import depthai as dai


class CompositeMessage(dai.Buffer):
"""CompositeMessage class for storing composite of (dai.Buffer, float, List) data.

Attributes
----------
_data : Dict[str, Union[dai.Buffer, float, List]]
Dictionary of data with keys as string and values as either dai.Buffer, float or List.
"""

def __init__(self):
super().__init__()
self._data: Dict[str, Union[dai.Buffer, float, List]] = {}

def setData(self, data: Dict[str, Union[dai.Buffer, float, List]]):
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
self._data = data

def getData(self) -> Dict[str, Union[dai.Buffer, float, List]]:
return self._data
3 changes: 2 additions & 1 deletion depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .map import create_map_message
from .misc import create_age_gender_message
from .misc import create_age_gender_message, create_multi_classification_message
from .segmentation import create_sam_message, create_segmentation_message
from .tracked_features import create_tracked_features_message

Expand All @@ -23,4 +23,5 @@
"create_map_message",
"create_classification_sequence_message",
"create_cluster_message",
"create_multi_classification_message",
]
5 changes: 5 additions & 0 deletions depthai_nodes/ml/messages/creators/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def create_classification_message(
if not np.issubdtype(scores.dtype, np.floating):
raise ValueError(f"Scores should be of type float, got {scores.dtype}.")

if any([value < 0 or value > 1 for value in scores]):
raise ValueError(
f"Scores list must contain probabilities between 0 and 1, instead got {scores}."
)

if not np.isclose(np.sum(scores), 1.0, atol=1e-2):
raise ValueError(f"Scores should sum to 1, got {np.sum(scores)}.")

Expand Down
57 changes: 55 additions & 2 deletions depthai_nodes/ml/messages/creators/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import List
from typing import List, Union

from ...messages import AgeGender, Classifications
import numpy as np

from ...messages import AgeGender, Classifications, CompositeMessage
from ...messages.creators import create_classification_message


def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender:
Expand Down Expand Up @@ -47,3 +50,53 @@ def create_age_gender_message(age: float, gender_prob: List[float]) -> AgeGender
age_gender_message.gender = gender

return age_gender_message


def create_multi_classification_message(
classification_attributes: List[str],
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
classification_scores: Union[np.ndarray, List[List[float]]],
classification_labels: List[List[str]],
) -> CompositeMessage:
"""Create a DepthAI message for multi-classification.

@param classification_attributes: List of attributes being classified.
@type classification_attributes: List[str]
@param classification_scores: A 2D array or list of classification scores for each
attribute.
@type classification_scores: Union[np.ndarray, List[List[float]]]
@param classification_labels: A 2D list of class labels for each classification
attribute.
@type classification_labels: List[List[str]]
@return: MultiClassification message containing a dictionary of classification
attributes and their respective Classifications.
@rtype: dai.Buffer
@raise ValueError: If number of attributes is not same as number of score-label
pairs.
@raise ValueError: If number of scores is not same as number of labels for each
attribute.
@raise ValueError: If each class score not in the range [0, 1].
@raise ValueError: If each class score not a probability distribution that sums to
1.
"""

if len(classification_attributes) != len(classification_scores) or len(
classification_attributes
) != len(classification_labels):
raise ValueError(
f"Number of classification attributes, scores and labels should be equal. Got {len(classification_attributes)} attributes, {len(classification_scores)} scores and {len(classification_labels)} labels."
)

multi_class_dict = {}
for attribute, scores, labels in zip(
classification_attributes, classification_scores, classification_labels
):
if len(scores) != len(labels):
raise ValueError(
f"Number of scores and labels should be equal for each classification attribute, got {len(scores)} scores, {len(labels)} labels for attribute {attribute}."
)
multi_class_dict[attribute] = create_classification_message(labels, scores)

multi_classification_message = CompositeMessage()
multi_classification_message.setData(multi_class_dict)

return multi_classification_message
2 changes: 2 additions & 0 deletions depthai_nodes/ml/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .scrfd import SCRFDParser
from .segmentation import SegmentationParser
from .superanimal_landmarker import SuperAnimalParser
from .vehicle_attributes import MultiClassificationParser
from .xfeat import XFeatParser
from .yolo import YOLOExtendedParser
from .yunet import YuNetParser
Expand All @@ -38,4 +39,5 @@
"MapOutputParser",
"PaddleOCRParser",
"LaneDetectionParser",
"MultiClassificationParser",
]
61 changes: 61 additions & 0 deletions depthai_nodes/ml/parsers/vehicle_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import List

import depthai as dai

from ..messages.creators import create_multi_classification_message


class MultiClassificationParser(dai.node.ThreadedHostNode):
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
"""Postprocessing logic for Multiple Classification model.

Attributes
----------
input : Node.Input
Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node.
out : Node.Output
Parser sends the processed network results to this output in a form of DepthAI message. It is a linking point from which the processed network results are retrieved.
classification_attributes : List[str]
List of attributes to be classified.
classification_labels : List[List[str]]
List of class labels for each attribute in `classification_attributes`

Output Message/s
----------------
**Type**: CompositeMessage

**Description**: A CompositeMessage containing a dictionary of classification attributes as keys and their respective Classifications as values.
"""

def __init__(
self,
classification_attributes: List[str],
classification_labels: List[List[str]],
):
"""Initializes the MultipleClassificationParser node."""
dai.node.ThreadedHostNode.__init__(self)
self.out = self.createOutput()
self.input = self.createInput()
self.classification_attributes: List[str] = classification_attributes
self.classification_labels: List[List[str]] = classification_labels

def run(self):
while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break

layer_names = output.getAllLayerNames()

scores = []
for layer_name in layer_names:
scores.append(
output.getTensor(layer_name, dequantize=True).flatten().tolist()
)

multi_classification_message = create_multi_classification_message(
self.classification_attributes, scores, self.classification_labels
)
multi_classification_message.setTimestamp(output.getTimestamp())

self.out.send(multi_classification_message)
59 changes: 59 additions & 0 deletions tests/unittests/test_creators/test_vehicle_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

from depthai_nodes.ml.messages import Classifications, CompositeMessage
from depthai_nodes.ml.messages.creators.misc import create_multi_classification_message


def test_incorect_lengths():
with pytest.raises(
ValueError,
match="Number of classification attributes, scores and labels should be equal. Got 1 attributes, 2 scores and 2 labels.",
):
create_multi_classification_message(
["vehicle_type"],
[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
[["car", "truck"], ["red", "blue"]],
)


def test_incorect_score_label_lengths():
with pytest.raises(
ValueError,
match="Number of scores and labels should be equal for each classification attribute, got 4 scores, 2 labels for attribute vehicle_type.",
):
create_multi_classification_message(
["vehicle_type", "vehicle_color"],
[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6]],
[["car", "truck"], ["red", "blue"]],
)


def test_correct_usage():
attrs = ["vehicle_type", "vehicle_color"]
scores = [[0.1, 0.2, 0.3, 0.4], [0.0, 0.1, 0.4, 0.2, 0.2, 0.1]]
names = [
["car", "truck", "van", "bike"],
["red", "blue", "green", "black", "white", "yellow"],
]

res = create_multi_classification_message(attrs, scores, names)

assert isinstance(res, CompositeMessage)
res = res.getData()
assert isinstance(res["vehicle_type"], Classifications)
assert isinstance(res["vehicle_color"], Classifications)
assert res["vehicle_type"].classes == ["bike", "van", "truck", "car"]
assert res["vehicle_type"].scores == [0.4, 0.3, 0.2, 0.1]
assert res["vehicle_color"].classes == [
"green",
"black",
"white",
"blue",
"yellow",
"red",
]
assert res["vehicle_color"].scores == [0.4, 0.2, 0.2, 0.1, 0.1, 0.0]


if __name__ == "__main__":
pytest.main()
Loading