Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Classification parser #15

Merged
merged 9 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

Expand Down Expand Up @@ -162,4 +161,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_Store
.DS_Store
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .classification import Classifications
from .img_detections import ImgDetectionsWithKeypoints, ImgDetectionWithKeypoints
from .keypoints import HandKeypoints, Keypoints
from .lines import Line, Lines
Expand All @@ -9,4 +10,5 @@
"Keypoints",
"Line",
"Lines",
"Classifications",
]
24 changes: 24 additions & 0 deletions depthai_nodes/ml/messages/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import List

import depthai as dai


class Classifications(dai.Buffer):
def __init__(self):
dai.Buffer.__init__(self)
self._classes = []

aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
@property
def classes(self) -> List:
return self._classes

@classes.setter
def classes(self, value: List):
if not isinstance(value, list):
raise TypeError("Must be a list.")
for item in value:
if not isinstance(item, list) or len(item) != 2:
raise TypeError(
"Each item must be a list of [class_name, probability_score], got {item}."
)
self._classes = value
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .classification_message import create_classification_message
from .depth import create_depth_message
from .detection import create_detection_message, create_line_detection_message
from .image import create_image_message
Expand All @@ -16,4 +17,5 @@
"create_tracked_features_message",
"create_keypoints_message",
"create_thermal_message",
"create_classification_message",
]
68 changes: 68 additions & 0 deletions depthai_nodes/ml/messages/creators/classification_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import depthai as dai
import numpy as np

from ...messages import Classifications


def create_classification_message(
scores: np.ndarray, classes: np.ndarray = None
) -> dai.Buffer:
"""Create a message for classification. The message contains the class names and
their respective scores, sorted in descending order of scores.

Parameters
----------
scores : np.ndarray
A numpy array of shape (n_classes,) containing the probability score of each class.

classes : np.ndarray = []
A numpy array of class names. If not provided, class names are set to None.


Returns
--------
Classifications : dai.Buffer
A message with parameter `classes` which is a list of shape (n_classes, 2)
where each item is [class_name, probability_score].
If no class names are provided, class_name is set to None.
"""
if classes is None:
classes = np.array([])

if len(scores) == 0:
raise ValueError("Scores should not be empty.")

if len(scores) != len(scores.flatten()):
raise ValueError(f"Scores should be a 1D array, got {scores.shape}.")

scores = scores.flatten()

if not np.issubdtype(scores.dtype, np.floating):
raise ValueError(f"Scores should be of type float, got {scores.dtype}.")

print("scores", np.sum(scores))
if not np.isclose(np.sum(scores), 1.0, atol=1e-1):
raise ValueError(f"Scores should sum to 1, got {np.sum(scores)}.")

if len(scores) != len(classes) and len(classes) != 0:
raise ValueError(
f"Number of labels and scores mismatch. Provided {len(scores)} scores and {len(classes)} class names."
)

classification_msg = Classifications()

sorted_args = np.argsort(scores)[::-1]
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
scores = scores[sorted_args]

if len(classes) == 0:
classification_msg.classes = [
[None, float(scores[i])] for i in range(len(scores))
]
return classification_msg

classes = classes[sorted_args]
classification_msg.classes = [
[str(classes[i]), float(scores[i])] for i in range(len(classes))
]

return classification_msg
2 changes: 2 additions & 0 deletions depthai_nodes/ml/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .classification_parser import ClassificationParser
from .image_output import ImageOutputParser
from .keypoints import KeypointParser
from .mediapipe_hand_landmarker import MPHandLandmarkParser
Expand All @@ -24,4 +25,5 @@
"MLSDParser",
"XFeatParser",
"ThermalImageParser",
"ClassificationParser",
]
66 changes: 66 additions & 0 deletions depthai_nodes/ml/parsers/classification_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import depthai as dai
import numpy as np

from ..messages.creators import create_classification_message


class ClassificationParser(dai.node.ThreadedHostNode):
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
"""Postprocessing logic for Classification model.

Parameters
----------
classes : list[str]
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
List of class labels.
is_softmax : bool = True
True, if output is already softmaxed.

Returns
-------
Classifications: dai.Buffer
An object with parameter `classes`, which is a list of items like [class_name, probability_score].
If no class names are provided, class_name is set to None.
"""

def __init__(self, classes: list[str] = None, is_softmax: bool = True):
dai.node.ThreadedHostNode.__init__(self)
self.out = self.createOutput()
self.input = self.createInput()
if classes is None:
self.classes = []
else:
self.classes = np.array(classes)
self.n_classes = len(classes)
self.is_softmax = is_softmax

def setClasses(self, classes):
self.classes = classes
self.n_classes = len(classes)

aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
def run(self):
while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

output_layer_names = output.getAllLayerNames()
if len(output_layer_names) != 1:
raise ValueError(
f"Expected 1 output layer, got {len(output_layer_names)}."
)

scores = output.getTensor(output_layer_names[0])
scores = np.array(scores).flatten()

if len(scores) != self.n_classes and self.n_classes != 0:
aljazkonec1 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"Number of labels and scores mismatch. Provided {self.n_classes} class names and {len(scores)} scores."
)

if not self.is_softmax:
ex = np.exp(scores)
scores = ex / np.sum(ex)

msg = create_classification_message(scores, self.classes)

self.out.send(msg)