Skip to content

Commit

Permalink
RVC4 support for parsers. (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeroo authored Aug 19, 2024
1 parent 4fffe46 commit bee80cb
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 13 deletions.
2 changes: 1 addition & 1 deletion depthai_nodes/ml/parsers/image_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run(self):
raise ValueError(
f"Expected 1 output layer, got {len(output_layer_names)}."
)
output = output.getTensor(output_layer_names[0])
output = output.getTensor(output_layer_names[0], dequantize=True)

if len(output.shape) == 4:
image = output[0]
Expand Down
4 changes: 3 additions & 1 deletion depthai_nodes/ml/parsers/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def run(self):
f"Expected 1 output layer, got {len(output_layer_names)}."
)

keypoints = output.getTensor(output_layer_names[0])
keypoints = output.getTensor(output_layer_names[0], dequantize=True).astype(
np.float32
)
num_coords = int(np.prod(keypoints.shape) / self.num_keypoints)

if num_coords not in [2, 3]:
Expand Down
18 changes: 15 additions & 3 deletions depthai_nodes/ml/parsers/mediapipe_hand_landmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,21 @@ def run(self):
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

landmarks = output.getTensor("Identity").reshape(21, 3).astype(np.float32)
hand_score = output.getTensor("Identity_1").reshape(-1).astype(np.float32)
handedness = output.getTensor("Identity_2").reshape(-1).astype(np.float32)
landmarks = (
output.getTensor("Identity", dequantize=True)
.reshape(21, 3)
.astype(np.float32)
)
hand_score = (
output.getTensor("Identity_1", dequantize=True)
.reshape(-1)
.astype(np.float32)
)
handedness = (
output.getTensor("Identity_2", dequantize=True)
.reshape(-1)
.astype(np.float32)
)
hand_score = hand_score[0]
handedness = handedness[0]

Expand Down
12 changes: 10 additions & 2 deletions depthai_nodes/ml/parsers/mediapipe_palm_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,16 @@ def run(self):
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

bboxes = output.getTensor("Identity").reshape(2016, 18).astype(np.float32)
scores = output.getTensor("Identity_1").reshape(2016).astype(np.float32)
bboxes = (
output.getTensor("Identity", dequantize=True)
.reshape(2016, 18)
.astype(np.float32)
)
scores = (
output.getTensor("Identity_1", dequantize=True)
.reshape(2016)
.astype(np.float32)
)

decoded_bboxes = generate_anchors_and_decode(
bboxes=bboxes, scores=scores, threshold=self.score_threshold, scale=192
Expand Down
8 changes: 5 additions & 3 deletions depthai_nodes/ml/parsers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ def run(self):
f"Expected 1 output layer, got {len(output_layer_names)}."
)

segmentation_mask = output.getTensor(output_layer_names[0])[
0
] # num_clases x H x W
segmentation_mask = output.getTensor(output_layer_names[0], dequantize=True)
if len(segmentation_mask.shape) == 4:
segmentation_mask = segmentation_mask[0]
else:
segmentation_mask = segmentation_mask.transpose(2, 0, 1)

if len(segmentation_mask.shape) != 3:
raise ValueError(
Expand Down
5 changes: 4 additions & 1 deletion depthai_nodes/ml/parsers/superanimal_landmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def run(self):
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

heatmaps = output.getTensor("heatmaps").astype(np.float32)
heatmaps = output.getTensor("heatmaps", dequantize=True).astype(np.float32)

if len(heatmaps.shape) == 3:
heatmaps = heatmaps.reshape((1,) + heatmaps.shape)

heatmaps_scale_factor = (
self.scale_factor / heatmaps.shape[1],
Expand Down
13 changes: 11 additions & 2 deletions depthai_nodes/ml/parsers/xfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,17 @@ def run(self):
except dai.MessageQueue.QueueException:
break # Pipeline was stopped

feats = output.getTensor("feats").astype(np.float32)
keypoints = output.getTensor("keypoints").astype(np.float32)
feats = output.getTensor("feats", dequantize=True).astype(np.float32)
keypoints = output.getTensor("keypoints", dequantize=True).astype(
np.float32
)

if len(feats.shape) == 3:
feats = feats.reshape((1,) + feats.shape).transpose(0, 3, 1, 2)
if len(keypoints.shape) == 3:
keypoints = keypoints.reshape((1,) + keypoints.shape).transpose(
0, 3, 1, 2
)

result = detect_and_compute(
feats, keypoints, resize_rate_w, resize_rate_h, self.input_size
Expand Down

0 comments on commit bee80cb

Please sign in to comment.