Commit 56b3abd: Format code

HonzaCuhel committed Aug 29, 2024
1 parent 852e1fb commit 56b3abd
Showing 9 changed files with 274 additions and 122 deletions.
2 changes: 1 addition & 1 deletion depthai_nodes/ml/messages/__init__.py
@@ -5,8 +5,8 @@
)
from .keypoints import HandKeypoints, Keypoints
from .lines import Line, Lines
-from .segmentation import SegmentationMasks
from .misc import AgeGender
+from .segmentation import SegmentationMasks

__all__ = [
"ImgDetectionWithAdditionalOutput",
2 changes: 1 addition & 1 deletion depthai_nodes/ml/messages/creators/__init__.py
@@ -3,8 +3,8 @@
from .detection import create_detection_message, create_line_detection_message
from .image import create_image_message
from .keypoints import create_hand_keypoints_message, create_keypoints_message
-from .segmentation import create_sam_message, create_segmentation_message
from .misc import create_age_gender_message
+from .segmentation import create_sam_message, create_segmentation_message
from .thermal import create_thermal_message
from .tracked_features import create_tracked_features_message

8 changes: 6 additions & 2 deletions depthai_nodes/ml/messages/creators/detection.py
@@ -15,7 +15,9 @@ def create_detection_message(
bboxes: np.ndarray,
scores: np.ndarray,
labels: List[int] = None,
-keypoints: Union[List[Tuple[float, float]], List[Tuple[float, float, float]]] = None,
+keypoints: Union[
+    List[Tuple[float, float]], List[Tuple[float, float, float]]
+] = None,
masks: List[np.ndarray] = None,
) -> dai.ImgDetections:
"""Create a DepthAI message for an object detection.
@@ -118,7 +120,9 @@ def create_detection_message(
if not isinstance(mask, np.ndarray):
raise ValueError(f"mask should be numpy array, got {type(mask)}.")
if len(mask.shape) != 2:
raise ValueError(f"mask should be of shape (H/4, W/4), got {mask.shape}.")
raise ValueError(
f"mask should be of shape (H/4, W/4), got {mask.shape}."
)

if len(masks) != bboxes.shape[0]:
raise ValueError(
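For reference, a minimal usage sketch of the create_detection_message signature reformatted above. The argument names and the one-mask-per-bbox rule come from this diff; the exact array shapes are illustrative assumptions, not a documented contract.

```python
import numpy as np

from depthai_nodes.ml.messages.creators import create_detection_message

# Two detections: bboxes (N, 4) and scores (N,) are assumptions inferred
# from the validation code in this diff.
bboxes = np.array([[10, 20, 50, 60], [30, 40, 90, 100]], dtype=np.float32)
scores = np.array([0.90, 0.75], dtype=np.float32)

# Keypoints in the flat (x, y) form the annotation declares; an
# (x, y, confidence) form is equally valid per the Union type.
keypoints = [(12.0, 22.0), (40.0, 55.0)]

# One 2D mask per detection; the diff's check rejects non-2D arrays.
masks = [np.zeros((160, 160), dtype=np.float32) for _ in range(2)]

msg = create_detection_message(
    bboxes, scores, labels=[0, 1], keypoints=keypoints, masks=masks
)
```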
17 changes: 13 additions & 4 deletions depthai_nodes/ml/messages/img_detections.py
@@ -1,4 +1,3 @@
-
from typing import List, Tuple, Union

import depthai as dai
@@ -19,11 +18,15 @@ class ImgDetectionWithAdditionalOutput(dai.ImgDetection):
def __init__(self):
"""Initializes the ImgDetectionWithAdditionalOutput object."""
dai.ImgDetection.__init__(self) # TODO: change to super().__init__()?
-self._keypoints: Union[List[Tuple[float, float]], List[Tuple[float, float, float]]] = []
+self._keypoints: Union[
+    List[Tuple[float, float]], List[Tuple[float, float, float]]
+] = []
self._mask: np.ndarray = np.array([])

@property
-def keypoints(self) -> Union[List[Tuple[float, float]], List[Tuple[float, float, float]]]:
+def keypoints(
+    self,
+) -> Union[List[Tuple[float, float]], List[Tuple[float, float, float]]]:
"""Returns the keypoints.
@return: List of keypoints.
@@ -32,7 +35,13 @@ def keypoints(self) -> Union[List[Tuple[float, float,
return self._keypoints

@keypoints.setter
-def keypoints(self, value: Union[List[Tuple[Union[int, float], Union[int, float]]], List[Tuple[Union[int, float, float], Union[int, float, float]]]]):
+def keypoints(
+    self,
+    value: Union[
+        List[Tuple[Union[int, float], Union[int, float]]],
+        List[Tuple[Union[int, float, float], Union[int, float, float]]],
+    ],
+):
"""Sets the keypoints.
@param value: List of keypoints.
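A short sketch of how the reformatted keypoints property and setter are exercised; a hypothetical snippet that assumes ImgDetectionWithAdditionalOutput is importable from depthai_nodes.ml.messages, as suggested by the __init__.py diff above.

```python
from depthai_nodes.ml.messages import ImgDetectionWithAdditionalOutput

det = ImgDetectionWithAdditionalOutput()

# The setter accepts a list of (x, y) tuples, or (x, y, confidence)
# tuples, matching the Union annotation in this diff.
det.keypoints = [(0.25, 0.50), (0.75, 0.50)]

print(det.keypoints)  # [(0.25, 0.5), (0.75, 0.5)]
```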
1 change: 0 additions & 1 deletion depthai_nodes/ml/messages/segmentation.py
@@ -1,4 +1,3 @@
-
import depthai as dai
import numpy as np

82 changes: 58 additions & 24 deletions depthai_nodes/ml/parsers/fastsam.py
@@ -15,17 +15,17 @@

class FastSAMParser(YOLOParser):
def __init__(
-self,
-confidence_threshold: int = 0.5,
-num_classes: int = 1,
-iou_threshold: int = 0.5,
-mask_conf: float = 0.5,
-input_shape: Tuple[int, int] = (640, 640),
-prompt: str = "everything",
-points: Optional[Tuple[int, int]] = None,
-point_label: Optional[int] = None,
-bbox: Optional[Tuple[int, int, int, int]] = None
-):
+self,
+confidence_threshold: int = 0.5,
+num_classes: int = 1,
+iou_threshold: int = 0.5,
+mask_conf: float = 0.5,
+input_shape: Tuple[int, int] = (640, 640),
+prompt: str = "everything",
+points: Optional[Tuple[int, int]] = None,
+point_label: Optional[int] = None,
+bbox: Optional[Tuple[int, int, int, int]] = None,
+):
"""Initialize the YOLOParser node.
@param confidence_threshold: The confidence threshold for the detections
@@ -47,7 +47,9 @@ def __init__(
@param bbox: The bounding box
@type bbox: Optional[Tuple[int, int, int, int]]
"""
-YOLOParser.__init__(self, confidence_threshold, num_classes, iou_threshold, mask_conf)
+YOLOParser.__init__(
+    self, confidence_threshold, num_classes, iou_threshold, mask_conf
+)
self.input_shape = input_shape
self.prompt = prompt
self.points = points
@@ -99,21 +101,36 @@ def setBoundingBox(self, bbox):
def run(self):
while self.isRunning():
try:
-nnDataIn : dai.NNData = self.input.get()
+nnDataIn: dai.NNData = self.input.get()
except dai.MessageQueue.QueueException:
-break # Pipeline was stopped, no more data
+break  # Pipeline was stopped, no more data
# Get all the layer names
layer_names = nnDataIn.getAllLayerNames()

outputs_names = sorted([name for name in layer_names if "_yolo" in name])
-outputs_values = [nnDataIn.getTensor(o, dequantize=True).astype(np.float32) for o in outputs_names]
+outputs_values = [
+    nnDataIn.getTensor(o, dequantize=True).astype(np.float32)
+    for o in outputs_names
+]
# Get the segmentation outputs
-masks_outputs_values, protos_output, protos_len = self._get_segmentation_outputs(nnDataIn)
+(
+    masks_outputs_values,
+    protos_output,
+    protos_len,
+) = self._get_segmentation_outputs(nnDataIn)

if len(outputs_values[0].shape) != 4:
# RVC4
-outputs_values = [o.transpose((2, 0, 1))[np.newaxis, ...] for o in outputs_values]
-protos_output, protos_len, masks_outputs_values = self._reshape_seg_outputs(protos_output, protos_len, masks_outputs_values)
+outputs_values = [
+    o.transpose((2, 0, 1))[np.newaxis, ...] for o in outputs_values
+]
+(
+    protos_output,
+    protos_len,
+    masks_outputs_values,
+) = self._reshape_seg_outputs(
+    protos_output, protos_len, masks_outputs_values
+)

# Decode the outputs
results = decode_fastsam_output(
@@ -123,25 +140,42 @@ def run(self):
img_shape=self.input_shape[::-1],
conf_thres=self.confidence_threshold,
iou_thres=self.iou_threshold,
-num_classes=self.num_classes
+num_classes=self.num_classes,
)

bboxes, masks = [], []
for i in range(results.shape[0]):
-bbox, conf, label, seg_coeff = results[i, :4].astype(int), results[i, 4], results[i, 5].astype(int), results[i, 6:].astype(int)
+bbox, conf, label, seg_coeff = (
+    results[i, :4].astype(int),
+    results[i, 4],
+    results[i, 5].astype(int),
+    results[i, 6:].astype(int),
+)
bboxes.append(bbox.tolist() + [conf, int(label)])
hi, ai, xi, yi = seg_coeff
-mask_coeff = masks_outputs_values[hi][0, ai*protos_len:(ai+1)*protos_len, yi, xi]
-mask = process_single_mask(protos_output[0], mask_coeff, self.mask_conf, self.input_shape, bbox)
+mask_coeff = masks_outputs_values[hi][
+    0, ai * protos_len : (ai + 1) * protos_len, yi, xi
+]
+mask = process_single_mask(
+    protos_output[0], mask_coeff, self.mask_conf, self.input_shape, bbox
+)
masks.append(mask)

results_bboxes = np.array(bboxes)
results_masks = np.array(masks)

if self.prompt == "bbox":
-results_masks = box_prompt(results_masks, bbox=self.bbox, orig_shape=self.input_shape[::-1])
+results_masks = box_prompt(
+    results_masks, bbox=self.bbox, orig_shape=self.input_shape[::-1]
+)
elif self.prompt == "point":
-results_masks = point_prompt(results_bboxes, results_masks, points=self.points, pointlabel=self.point_label, orig_shape=self.input_shape[::-1])
+results_masks = point_prompt(
+    results_bboxes,
+    results_masks,
+    points=self.points,
+    pointlabel=self.point_label,
+    orig_shape=self.input_shape[::-1],
+)

segmentation_message = create_sam_message(results_masks)
self.out.send(segmentation_message)
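
For orientation, a hedged sketch of constructing the parser whose run() loop is reformatted above. Only arguments visible in this diff are used; the import path assumes the parser is re-exported from depthai_nodes.ml.parsers, and pipeline wiring is omitted.

```python
from depthai_nodes.ml.parsers import FastSAMParser

parser = FastSAMParser(
    confidence_threshold=0.5,
    num_classes=1,
    iou_threshold=0.5,
    mask_conf=0.5,
    input_shape=(640, 640),
    prompt="point",      # run() branches on "bbox" and "point"; default is "everything"
    points=(320, 240),   # pixel coordinates of the point prompt
    point_label=1,       # 1 selects masks containing the point, per point_prompt below
)
```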
60 changes: 45 additions & 15 deletions depthai_nodes/ml/parsers/utils/fastsam.py
@@ -75,10 +75,17 @@ def point_prompt(bboxes, masks, points, pointlabel, orig_shape):  # numpy
h = masks[0]["segmentation"].shape[0]
w = masks[0]["segmentation"].shape[1]
if h != target_height or w != target_width:
-points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
+points = [
+    [int(point[0] * w / target_width), int(point[1] * h / target_height)]
+    for point in points
+]
onemask = np.zeros((h, w))
for annotation in masks:
-mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
+mask = (
+    annotation["segmentation"]
+    if isinstance(annotation, dict)
+    else annotation
+)
for i, point in enumerate(points):
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
onemask += mask
@@ -89,7 +96,9 @@ def point_prompt(bboxes, masks, points, pointlabel, orig_shape):  # numpy
return masks
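
To make the point-prompt rule above concrete, a standalone NumPy toy example of the selection logic it implements (masks containing a label-1 point are accumulated, label-0 points subtract); this is a sketch of the rule, not a call into the library.

```python
import numpy as np

# Two candidate 4x4 masks and one foreground point at (x=1, y=1).
masks = [np.zeros((4, 4)), np.zeros((4, 4))]
masks[0][0:2, 0:2] = 1  # covers the point
masks[1][2:4, 2:4] = 1  # does not

points, pointlabel = [(1, 1)], [1]

onemask = np.zeros((4, 4))
for mask in masks:
    for (x, y), label in zip(points, pointlabel):
        if mask[y, x] == 1 and label == 1:
            onemask += mask  # foreground point: include this mask
        elif mask[y, x] == 1 and label == 0:
            onemask -= mask  # background point: carve this mask out

result = onemask >= 1  # only masks[0] survives
```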


-def adjust_bboxes_to_image_border(boxes: np.ndarray, image_shape: Tuple[int, int], threshold: int = 20) -> np.ndarray:
+def adjust_bboxes_to_image_border(
+    boxes: np.ndarray, image_shape: Tuple[int, int], threshold: int = 20
+) -> np.ndarray:
"""
Source: https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/utils.py#L6 (Ultralytics)
Adjust bounding boxes to stick to image border if they are within a certain threshold.
@@ -114,12 +123,12 @@ def adjust_bboxes_to_image_border(boxes: np.ndarray, image_shape: Tuple[int, int


def bbox_iou(
-box1: np.ndarray,
-boxes: np.ndarray,
-iou_thres: float = 0.9,
-image_shape: Tuple[int, int] = (640, 640),
-raw_output: bool = False
-) -> np.ndarray:
+box1: np.ndarray,
+boxes: np.ndarray,
+iou_thres: float = 0.9,
+image_shape: Tuple[int, int] = (640, 640),
+raw_output: bool = False,
+) -> np.ndarray:
"""
Source: https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/utils.py#L30 (Ultralytics - rewritten to numpy)
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
@@ -161,7 +170,15 @@ def bbox_iou(
return np.flatnonzero(iou > iou_thres)


-def decode_fastsam_output(yolo_outputs, strides, anchors, img_shape: Tuple[int, int], conf_thres=0.5, iou_thres=0.45, num_classes=1):
+def decode_fastsam_output(
+    yolo_outputs,
+    strides,
+    anchors,
+    img_shape: Tuple[int, int],
+    conf_thres=0.5,
+    iou_thres=0.45,
+    num_classes=1,
+):
"""
Decode the bounding boxes
@@ -182,13 +199,20 @@ def decode_fastsam_output(yolo_outputs, strides, anchors, img_shape: Tuple[int,
conf_thres=conf_thres,
iou_thres=iou_thres,
num_classes=num_classes,
-kpts_mode=False
+kpts_mode=False,
)[0]

full_box = np.zeros(output_nms.shape[1])
-full_box[2], full_box[3], full_box[4], full_box[6:] = img_shape[1], img_shape[0], 1.0, 1.0
+full_box[2], full_box[3], full_box[4], full_box[6:] = (
+    img_shape[1],
+    img_shape[0],
+    1.0,
+    1.0,
+)
full_box = full_box.reshape((1, -1))
-critical_iou_index = bbox_iou(full_box[0][:4], output_nms[:, :4], iou_thres=0.9, image_shape=img_shape)
+critical_iou_index = bbox_iou(
+    full_box[0][:4], output_nms[:, :4], iou_thres=0.9, image_shape=img_shape
+)

if critical_iou_index.size > 0:
full_box[0][4] = output_nms[critical_iou_index][:, 4]
@@ -216,8 +240,14 @@ def crop_mask(masks, box):
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))


-def process_single_mask(protos, mask_coeff, mask_conf, img_shape: Tuple[int, int], bbox: Tuple[int, int, int, int]) -> np.ndarray:
-mask = sigmoid(np.sum(protos * mask_coeff[..., np.newaxis, np.newaxis], axis = 0))
+def process_single_mask(
+    protos,
+    mask_coeff,
+    mask_conf,
+    img_shape: Tuple[int, int],
+    bbox: Tuple[int, int, int, int],
+) -> np.ndarray:
+mask = sigmoid(np.sum(protos * mask_coeff[..., np.newaxis, np.newaxis], axis=0))
mask = cv2.resize(mask, img_shape, interpolation=cv2.INTER_NEAREST)
mask = crop_mask(mask, np.array(bbox))
return (mask > mask_conf).astype(np.uint8)
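
The reformatted process_single_mask is the standard prototype-mask decode used by YOLO-style segmentation heads: a per-detection coefficient vector linearly combines the prototype maps, followed by sigmoid, resize, crop, and threshold. A self-contained numeric sketch of the combination step, with toy shapes and without the OpenCV resize:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

protos = np.random.rand(32, 160, 160)  # (protos_len, H/4, W/4), toy values
mask_coeff = np.random.rand(32)        # one coefficient per prototype map

# Weighted sum over the prototype axis, as in the diff above:
mask = sigmoid(np.sum(protos * mask_coeff[..., np.newaxis, np.newaxis], axis=0))
binary = (mask > 0.5).astype(np.uint8)  # 0.5 plays the role of mask_conf
print(binary.shape)  # (160, 160)
```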
