Skip to content

Commit

Permalink
Merge pull request #71 from luxonis/corner-support-for-ppdet
Browse files Browse the repository at this point in the history
[HOT-FIX] Update ppdet to return polygons.
  • Loading branch information
aljazkonec1 authored Sep 17, 2024
2 parents f61e1db + a79b8ad commit 7345faa
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 8 deletions.
2 changes: 2 additions & 0 deletions depthai_nodes/ml/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .clusters import Cluster, Clusters
from .composite import CompositeMessage
from .img_detections import (
CornerDetections,
ImgDetectionExtended,
ImgDetectionsExtended,
)
Expand All @@ -23,4 +24,5 @@
"Clusters",
"Cluster",
"CompositeMessage",
"CornerDetections",
]
7 changes: 6 additions & 1 deletion depthai_nodes/ml/messages/creators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
create_multi_classification_message,
)
from .clusters import create_cluster_message
from .detection import create_detection_message, create_line_detection_message
from .detection import (
create_corner_detection_message,
create_detection_message,
create_line_detection_message,
)
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
from .map import create_map_message
Expand All @@ -27,4 +31,5 @@
"create_classification_sequence_message",
"create_cluster_message",
"create_multi_classification_message",
"create_corner_detection_message",
]
54 changes: 54 additions & 0 deletions depthai_nodes/ml/messages/creators/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import numpy as np

from ...messages import (
CornerDetections,
ImgDetectionExtended,
ImgDetectionsExtended,
Line,
Lines,
)
from .keypoints import create_keypoints_message


def create_detection_message(
Expand Down Expand Up @@ -233,3 +235,55 @@ def create_line_detection_message(lines: np.ndarray, scores: np.ndarray):
lines_msg = Lines()
lines_msg.lines = line_detections
return lines_msg


def create_corner_detection_message(
bboxes: np.ndarray,
scores: np.ndarray,
labels: List[int] = None,
) -> CornerDetections:
"""Create a DepthAI message for an object detection.
@param bbox: Bounding boxes of detected objects in corner format of shape (N,4,2) meaning [...,[[x1, y1], [x2, y2], [x3, y3], [x4, y4]],...].
@type bbox: np.ndarray
@param scores: Confidence scores of detected objects of shape (N,).
@type scores: np.ndarray
@param labels: Labels of detected objects of shape (N,).
@type labels: List[int]
@return: CornerDetections message containing a list of corners, a list of labels, and a list of scores.
@rtype: CornerDetections
"""
if bboxes.shape[0] == 0:
return CornerDetections()

if bboxes.shape[1] != 4 or bboxes.shape[2] != 2:
raise ValueError(
f"Bounding boxes should be of shape (N,4,2), got {bboxes.shape}."
)

if bboxes.shape[0] != len(scores):
raise ValueError(
f"Number of bounding boxes and scores should have the same length, got {len(scores)} scores and {bboxes.shape[0]} bounding boxes."
)

if labels is not None:
if len(labels) != len(scores):
raise ValueError(
f"Number of labels and scores should have the same length, got {len(labels)} labels and {len(scores)} scores."
)

corner_boxes = []

for bbox in bboxes:
corner_box = create_keypoints_message(bbox)
corner_boxes.append(corner_box)

message = CornerDetections()
if labels is not None:
message.labels = labels

message.detections = corner_boxes
message.scores = list(scores)

return message
97 changes: 97 additions & 0 deletions depthai_nodes/ml/messages/img_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import depthai as dai
import numpy as np

from .keypoints import Keypoints


class ImgDetectionExtended(dai.ImgDetection):
"""ImgDetectionExtended class for storing image detection with keypoints and masks.
Expand Down Expand Up @@ -137,3 +139,98 @@ def detections(self, value: List[ImgDetectionExtended]):
"Each detection must be an instance of ImgDetectionExtended"
)
self._detections = value


class CornerDetections(dai.Buffer):
"""Detection Class for storing object detections in corner format.
Attributes
----------
detections: List[Keypoints]
List of detections in keypoint format.
labels: List[int]
List of labels for each detection
"""

def __init__(self):
"""Initializes the CornerDetections object."""
dai.Buffer.__init__(self)
self._detections: List[Keypoints] = []
self._scores: List[float] = None
self._labels: List[int] = None

@property
def detections(self) -> List[Keypoints]:
"""Returns the detections.
@return: List of detections.
@rtype: List[Keypoints]
"""
return self._detections

@detections.setter
def detections(self, value: List[Keypoints]):
"""Sets the detections.
@param value: List of detections.
@type value: List[Keypoints]
@raise TypeError: If the detections are not a list.
@raise TypeError: If each detection is not an instance of Keypoints.
"""
if not isinstance(value, list):
raise TypeError("Detections must be a list")
for item in value:
if not isinstance(item, Keypoints):
raise TypeError("Each detection must be an instance of Keypoints")
self._detections = value

@property
def labels(self) -> List[int]:
"""Returns the labels.
@return: List of labels.
@rtype: List[int]
"""
return self._labels

@labels.setter
def labels(self, value: List[int]):
"""Sets the labels.
@param value: List of labels.
@type value: List[int]
@raise TypeError: If the labels are not a list.
@raise TypeError: If each label is not an integer.
"""
if not isinstance(value, list):
raise TypeError("Labels must be a list")
for item in value:
if not isinstance(item, int):
raise TypeError("Each label must be an integer")
self._labels = value

@property
def scores(self) -> List[float]:
"""Returns the scores.
@return: List of scores.
@rtype: List[float]
"""
return self._scores

@scores.setter
def scores(self, value: List[float]):
"""Sets the scores.
@param value: List of scores.
@type value: List[float]
@raise TypeError: If the scores are not a list.
@raise TypeError: If each score is not a float.
"""
if not isinstance(value, list):
raise TypeError("Scores must be a list")
for item in value:
if not isinstance(item, float):
raise TypeError("Each score must be a float")
self._scores = value
8 changes: 3 additions & 5 deletions depthai_nodes/ml/parsers/ppdet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import depthai as dai

from ..messages.creators import create_detection_message
from .utils import corners2xyxy, parse_paddle_detection_outputs
from ..messages.creators import create_corner_detection_message
from .utils import parse_paddle_detection_outputs


class PPTextDetectionParser(dai.node.ThreadedHostNode):
Expand Down Expand Up @@ -91,9 +91,7 @@ def run(self):
self.max_detections,
)

bboxes = corners2xyxy(bboxes)

message = create_detection_message(bboxes, scores)
message = create_corner_detection_message(bboxes, scores)
message.setTimestamp(output.getTimestamp())

self.out.send(message)
4 changes: 2 additions & 2 deletions media/coverage_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 7345faa

Please sign in to comment.