Skip to content

Commit

Permalink
Improved SCRFD decoding. (#21)
Browse files Browse the repository at this point in the history
* Improved SCRFD decoding.

* Confict error fix.

* Variable rename.
  • Loading branch information
kkeroo authored Aug 19, 2024
1 parent bee80cb commit ceb642c
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 86 deletions.
6 changes: 3 additions & 3 deletions depthai_nodes/ml/messages/creators/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def create_detection_message(
if keypoints is not None and len(keypoints) != 0:
if not isinstance(keypoints, List):
raise ValueError(f"keypoints should be list, got {type(keypoints)}.")
for pointcloud in keypoints:
for point in pointcloud:
if not isinstance(point, Tuple):
for object_keypoints in keypoints:
for point in object_keypoints:
if not isinstance(point, Tuple) and not isinstance(point, List):
raise ValueError(
f"keypoint pairs should be list of tuples, got {type(point)}."
)
Expand Down
2 changes: 1 addition & 1 deletion depthai_nodes/ml/messages/img_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def keypoints(self, value: List[Tuple[Union[int, float], Union[int, float]]]):
raise TypeError("Keypoints must be a list")
for item in value:
if (
not isinstance(item, tuple)
not (isinstance(item, tuple) or isinstance(item, list))
or len(item) != 2
or not all(isinstance(i, (int, float)) for i in item)
):
Expand Down
180 changes: 98 additions & 82 deletions depthai_nodes/ml/parsers/scrfd.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import cv2
import depthai as dai
import numpy as np

from ..messages.creators import create_detection_message
from .utils.scrfd import decode_scrfd


class SCRFDParser(dai.node.ThreadedHostNode):
Expand All @@ -20,6 +20,12 @@ class SCRFDParser(dai.node.ThreadedHostNode):
Non-maximum suppression threshold.
top_k : int
Maximum number of detections to keep.
feat_stride_fpn : tuple
Tuple of the feature strides.
num_anchors : int
Number of anchors.
input_size : tuple
Input size of the model.
Output Message/s
----------------
Expand All @@ -28,7 +34,15 @@ class SCRFDParser(dai.node.ThreadedHostNode):
**Description**: ImgDetections message containing bounding boxes, labels, and confidence scores of detected faces.
"""

def __init__(self, score_threshold=0.5, nms_threshold=0.5, top_k=100):
def __init__(
self,
score_threshold=0.5,
nms_threshold=0.5,
top_k=100,
input_size=(640, 640),
feat_stride_fpn=(8, 16, 32),
num_anchors=2,
):
"""Initializes the SCRFDParser node.
@param score_threshold: Confidence score threshold for detected faces.
Expand All @@ -37,6 +51,12 @@ def __init__(self, score_threshold=0.5, nms_threshold=0.5, top_k=100):
@type nms_threshold: float
@param top_k: Maximum number of detections to keep.
@type top_k: int
@param feat_stride_fpn: List of the feature strides.
@type feat_stride_fpn: tuple
@param num_anchors: Number of anchors.
@type num_anchors: int
@param input_size: Input size of the model.
@type input_size: tuple
"""
dai.node.ThreadedHostNode.__init__(self)
self.input = dai.Node.Input(self)
Expand All @@ -46,6 +66,10 @@ def __init__(self, score_threshold=0.5, nms_threshold=0.5, top_k=100):
self.nms_threshold = nms_threshold
self.top_k = top_k

self.feat_stride_fpn = feat_stride_fpn
self.num_anchors = num_anchors
self.input_size = input_size

def setConfidenceThreshold(self, threshold):
"""Sets the confidence score threshold for detected faces.
Expand All @@ -70,108 +94,100 @@ def setTopK(self, top_k):
"""
self.top_k = top_k

def setFeatStrideFPN(self, feat_stride_fpn):
"""Sets the feature stride of the FPN.
@param feat_stride_fpn: Feature stride of the FPN.
@type feat_stride_fpn: list
"""
self.feat_stride_fpn = feat_stride_fpn

def setInputSize(self, input_size):
"""Sets the input size of the model.
@param input_size: Input size of the model.
@type input_size: list
"""
self.input_size = input_size

def setNumAnchors(self, num_anchors):
"""Sets the number of anchors.
@param num_anchors: Number of anchors.
@type num_anchors: int
"""
self.num_anchors = num_anchors

def run(self):
while self.isRunning():
try:
output: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

score_8 = output.getTensor("score_8").flatten().astype(np.float32)
score_16 = output.getTensor("score_16").flatten().astype(np.float32)
score_32 = output.getTensor("score_32").flatten().astype(np.float32)
score_8 = (
output.getTensor("score_8", dequantize=True)
.flatten()
.astype(np.float32)
)
score_16 = (
output.getTensor("score_16", dequantize=True)
.flatten()
.astype(np.float32)
)
score_32 = (
output.getTensor("score_32", dequantize=True)
.flatten()
.astype(np.float32)
)
bbox_8 = (
output.getTensor("bbox_8").reshape(len(score_8), 4).astype(np.float32)
output.getTensor("bbox_8", dequantize=True)
.reshape(len(score_8), 4)
.astype(np.float32)
)
bbox_16 = (
output.getTensor("bbox_16").reshape(len(score_16), 4).astype(np.float32)
output.getTensor("bbox_16", dequantize=True)
.reshape(len(score_16), 4)
.astype(np.float32)
)
bbox_32 = (
output.getTensor("bbox_32").reshape(len(score_32), 4).astype(np.float32)
output.getTensor("bbox_32", dequantize=True)
.reshape(len(score_32), 4)
.astype(np.float32)
)
kps_8 = (
output.getTensor("kps_8").reshape(len(score_8), 5, 2).astype(np.float32)
output.getTensor("kps_8", dequantize=True)
.reshape(len(score_8), 10)
.astype(np.float32)
)
kps_16 = (
output.getTensor("kps_16")
.reshape(len(score_16), 5, 2)
output.getTensor("kps_16", dequantize=True)
.reshape(len(score_16), 10)
.astype(np.float32)
)
kps_32 = (
output.getTensor("kps_32")
.reshape(len(score_32), 5, 2)
output.getTensor("kps_32", dequantize=True)
.reshape(len(score_32), 10)
.astype(np.float32)
)

bboxes = []
keypoints = []

for i in range(len(score_8)):
y = int(np.floor(i / 80)) * 4
x = (i % 160) * 4
bbox = bbox_8[i]
xmin = int(x - bbox[0] * 8)
ymin = int(y - bbox[1] * 8)
xmax = int(x + bbox[2] * 8)
ymax = int(y + bbox[3] * 8)
kps = kps_8[i]
kps_batch = []
for kp in kps:
kpx = int(x + kp[0] * 8)
kpy = int(y + kp[1] * 8)
kps_batch.append([kpx, kpy])
keypoints.append(kps_batch)
bbox = [xmin, ymin, xmax, ymax]
bboxes.append(bbox)

for i in range(len(score_16)):
y = int(np.floor(i / 40)) * 8
x = (i % 80) * 8
bbox = bbox_16[i]
xmin = int(x - bbox[0] * 16)
ymin = int(y - bbox[1] * 16)
xmax = int(x + bbox[2] * 16)
ymax = int(y + bbox[3] * 16)
kps = kps_16[i]
kps_batch = []
for kp in kps:
kpx = int(x + kp[0] * 16)
kpy = int(y + kp[1] * 16)
kps_batch.append([kpx, kpy])
keypoints.append(kps_batch)
bbox = [xmin, ymin, xmax, ymax]
bboxes.append(bbox)

for i in range(len(score_32)):
y = int(np.floor(i / 20)) * 16
x = (i % 40) * 16
bbox = bbox_32[i]
xmin = int(x - bbox[0] * 32)
ymin = int(y - bbox[1] * 32)
xmax = int(x + bbox[2] * 32)
ymax = int(y + bbox[3] * 32)
kps = kps_32[i]
kps_batch = []
for kp in kps:
kpx = int(x + kp[0] * 32)
kpy = int(y + kp[1] * 32)
kps_batch.append([kpx, kpy])
keypoints.append(kps_batch)
bbox = [xmin, ymin, xmax, ymax]
bboxes.append(bbox)

scores = np.concatenate([score_8, score_16, score_32])
indices = cv2.dnn.NMSBoxes(
bboxes,
list(scores),
self.score_threshold,
self.nms_threshold,
top_k=self.top_k,
bboxes_concatenated = [bbox_8, bbox_16, bbox_32]
scores_concatenated = [score_8, score_16, score_32]
kps_concatenated = [kps_8, kps_16, kps_32]

bboxes, scores, keypoints = decode_scrfd(
bboxes_concatenated=bboxes_concatenated,
scores_concatenated=scores_concatenated,
kps_concatenated=kps_concatenated,
feat_stride_fpn=self.feat_stride_fpn,
input_size=self.input_size,
num_anchors=self.num_anchors,
score_threshold=self.score_threshold,
nms_threshold=self.nms_threshold,
)
detection_msg = create_detection_message(
bboxes, scores, None, keypoints.tolist()
)
bboxes = np.array(bboxes)[indices]
keypoints = np.array(keypoints)[indices]
scores = scores[indices]

detection_msg = create_detection_message(bboxes, scores, None, None)
detection_msg.setTimestamp(output.getTimestamp())

self.out.send(detection_msg)
Loading

0 comments on commit ceb642c

Please sign in to comment.