From ec9ad5a6c9536eb218200f485ec29aaef4173c87 Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Fri, 16 Aug 2024 14:26:19 +0200 Subject: [PATCH] fix: HRNetParser formatting, remove comments, add normalization --- depthai_nodes/ml/parsers/hrnet.py | 73 +++++++++++++++++++------------ 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/depthai_nodes/ml/parsers/hrnet.py b/depthai_nodes/ml/parsers/hrnet.py index 081b2dd..7d1b83b 100644 --- a/depthai_nodes/ml/parsers/hrnet.py +++ b/depthai_nodes/ml/parsers/hrnet.py @@ -1,62 +1,79 @@ import depthai as dai import numpy as np -import cv2 from ..messages.creators import create_keypoints_message class HRNetParser(dai.node.ThreadedHostNode): - def __init__(self, score_threshold=0.5, input_size=[256, 256], heatmap_size=[64, 64]): + """Parser class for parsing the output of the HRNet pose estimation model. The code is inspired by https://github.com/ibaiGorordo/ONNX-HRNET-Human-Pose-Estimation. + + Attributes + ---------- + input : Node.Input + Node's input. It is a linking point to which the Neural Network's output is linked. It accepts the output of the Neural Network node. + out : Node.Output + Parser sends the processed network results to this output in the form of a DepthAI message. It is a linking point from which the processed network results are retrieved. + score_threshold : float + Confidence score threshold for detected keypoints. + + Output Message/s + ---------------- + **Type**: Keypoints + + **Description**: Keypoints message containing detected body keypoints. + """ + + def __init__(self, score_threshold=0.5): + """Initializes the HRNetParser node. + + @param score_threshold: Confidence score threshold for detected keypoints. 
+ @type score_threshold: float + """ dai.node.ThreadedHostNode.__init__(self) self.input = dai.Node.Input(self) self.out = dai.Node.Output(self) - self.input_size = input_size - self.heatmap_size = heatmap_size + self.score_threshold = score_threshold def setScoreThreshold(self, threshold): - self.score_threshold = threshold + """Sets the confidence score threshold for the detected body keypoints. - def run(self): - """Postprocessing logic for HRNet pose estimation model. The code is inspired by https://github.com/ibaiGorordo/ONNX-HRNET-Human-Pose-Estimation - - Returns: - ... + @param threshold: Confidence score threshold for detected keypoints. + @type threshold: float """ + self.score_threshold = threshold + def run(self): while self.isRunning(): try: output: dai.NNData = self.input.get() except dai.MessageQueue.QueueException: break # Pipeline was stopped - img_width, img_height = self.input_size - heatmaps = output.getTensor("heatmaps", dequantize=True) - - if len(heatmaps.shape) == 4: # add new axis for batch size - heatmaps = heatmaps[0] - if heatmaps.shape[2] == 16: # HW_ instead of _HW + if len(heatmaps.shape) == 4: + heatmaps = heatmaps[0] + if heatmaps.shape[2] == 16: # HW_ instead of _HW heatmaps = heatmaps.transpose(2, 0, 1) - _, map_h, map_w = heatmaps.shape - # Find the maximum value in each of the heatmaps and its location - max_vals = np.array([np.max(heatmap) for heatmap in heatmaps]) - keypoints = np.array([np.unravel_index(heatmap.argmax(), heatmap.shape) - for heatmap in heatmaps]) + scores = np.array([np.max(heatmap) for heatmap in heatmaps]) + keypoints = np.array( + [ + np.unravel_index(heatmap.argmax(), heatmap.shape) + for heatmap in heatmaps + ] + ) keypoints = keypoints.astype(np.float32) - keypoints[max_vals < self.score_threshold] = np.array([np.nan, np.nan]) - - # Scale keypoints to the image size - # TODO: remove and have relative keypoint values? e.g. * np.array([64 / map_w, 64 / map_h]) to get relative values? 
- keypoints = keypoints[:, ::-1] * np.array([img_width / map_w, img_height / map_h]) + keypoints = keypoints[:, ::-1] / np.array( + [map_w, map_h] + ) # normalize keypoints to [0, 1] keypoints_msg = create_keypoints_message( keypoints=keypoints, - #scores=max_vals, # TODO: add scores - #confidence_threshold=self.confidence_threshold # TODO: add confidence threshold + scores=scores, + confidence_threshold=self.score_threshold, ) self.out.send(keypoints_msg)