Skip to content

Commit abd0f05

Browse files
committed
created detector_rewrite_live for frame-by-frame prediction
1 parent 8eb40b4 commit abd0f05

File tree

2 files changed

+160
-6
lines changed

2 files changed

+160
-6
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import sys
2+
import numpy as np
3+
import json
4+
import argparse
5+
import torch
6+
import cv2
7+
import pyzed.sl as sl
8+
from ultralytics import YOLO
9+
10+
from threading import Lock, Thread
11+
from time import sleep
12+
13+
14+
def initialize_camera_params(zed, input_type):
15+
init_params = sl.InitParameters(
16+
input_t=input_type, svo_real_time_mode=True
17+
) # input vs input_t??
18+
init_params.coordinate_units = sl.UNIT.METER
19+
init_params.depth_mode = sl.DEPTH_MODE.ULTRA # QUALITY
20+
init_params.coordinate_system = sl.COORDINATE_SYSTEM.RIGHT_HANDED_Y_UP
21+
init_params.depth_maximum_distance = 50
22+
23+
runtime_params = sl.RuntimeParameters()
24+
status = zed.open(init_params)
25+
26+
if status != sl.ERROR_CODE.SUCCESS:
27+
print(f"Failed to open camera: {repr(status)}")
28+
exit()
29+
30+
image_left_tmp = sl.Mat() # not needed?
31+
32+
positional_tracking_parameters = sl.PositionalTrackingParameters()
33+
zed.enable_positional_tracking(positional_tracking_parameters)
34+
35+
obj_param = sl.ObjectDetectionParameters()
36+
obj_param.detection_model = sl.OBJECT_DETECTION_MODEL.CUSTOM_BOX_OBJECTS
37+
obj_param.enable_tracking = True
38+
obj_param.enable_segmentation = (
39+
False # designed to give person pixel mask with internal OD
40+
)
41+
zed.enable_object_detection(obj_param)
42+
43+
# return runtime_params
44+
45+
46+
def main():
47+
parser = argparse.ArgumentParser()
48+
parser.add_argument(
49+
"--weights", type=str, default="yolov8m.pt", help="Path to YOLO model weights."
50+
)
51+
parser.add_argument("--svo", type=str, required=True, help="Path to the SVO file.")
52+
parser.add_argument(
53+
"--output",
54+
type=str,
55+
default="./output.mp4",
56+
help="Path to save the annotated video.",
57+
)
58+
args = parser.parse_args()
59+
60+
# Initialize ZED camera
61+
zed = sl.Camera()
62+
input_type = sl.InputType()
63+
input_type.set_from_svo_file(args.svo)
64+
initialize_camera_params(zed, input_type)
65+
66+
# Load YOLO model
67+
model = YOLO(args.weights)
68+
69+
svo_image = sl.Mat()
70+
obj_runtime_param = sl.ObjectDetectionRuntimeParameters()
71+
obj_runtime_param.detection_confidence_threshold = 40
72+
73+
# Video Writer setup
74+
image_size = zed.get_camera_information().camera_resolution
75+
width, height = image_size.width, image_size.height
76+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
77+
out = cv2.VideoWriter(args.output, fourcc, 30.0, (width, height))
78+
79+
# Detection parameters
80+
conf = 0.2
81+
iou = 0.45
82+
83+
while zed.grab() == sl.ERROR_CODE.SUCCESS:
84+
# Retrieve the left RGB image and depth map
85+
zed.retrieve_image(svo_image, sl.VIEW.LEFT)
86+
image_net = svo_image.get_data()
87+
img = cv2.cvtColor(image_net, cv2.COLOR_RGBA2RGB)
88+
89+
# YOLO detection
90+
results = model.predict(img, save=False, conf=conf, iou=iou)
91+
detections = results[0].cpu().numpy().boxes
92+
93+
# objects_in = []
94+
# for box in detections:
95+
# tmp = sl.CustomBoxObjectData()
96+
# tmp.unique_object_id = sl.generate_unique_id()
97+
# tmp.probability = box.conf[0].item()
98+
# tmp.label = int(box.cls[0].item())
99+
# tmp.bounding_box_2d = box.xyxy[0].cpu().numpy()
100+
# tmp.is_grounded = True
101+
# objects_in.append(tmp)
102+
103+
objects_in = []
104+
for box in detections:
105+
tmp = sl.CustomBoxObjectData()
106+
tmp.unique_object_id = sl.generate_unique_id()
107+
tmp.probability = box.conf
108+
tmp.label = int(box.class_id)
109+
tmp.bounding_box_2d = box.bounding_box
110+
tmp.is_grounded = (
111+
True # objects are moving on the floor plane and tracked in 2D only
112+
)
113+
objects_in.append(tmp)
114+
115+
# Ingest custom 3D objects into ZED SDK for tracking
116+
zed.ingest_custom_box_objects(objects_in)
117+
118+
# Retrieve 3D bounding boxes from ZED
119+
objects = sl.Objects()
120+
zed.retrieve_objects(objects, obj_runtime_param)
121+
122+
if objects.object_list:
123+
first_object = objects.object_list[0]
124+
print(f"Object ID: {first_object.id}, Position: {first_object.position}")
125+
126+
# Draw 3D bounding box
127+
for obj in objects.object_list:
128+
bbox = obj.bounding_box
129+
if bbox:
130+
for i in range(4):
131+
start = (int(bbox[i][0]), int(bbox[i][1]))
132+
end = (int(bbox[(i + 1) % 4][0]), int(bbox[(i + 1) % 4][1]))
133+
cv2.line(img, start, end, (0, 255, 0), 2)
134+
for i in range(4, 8):
135+
start = (int(bbox[i][0]), int(bbox[i][1]))
136+
end = (
137+
int(bbox[(i + 1) % 4 + 4][0]),
138+
int(bbox[(i + 1) % 4 + 4][1]),
139+
)
140+
cv2.line(img, start, end, (0, 255, 0), 2)
141+
for i in range(4):
142+
cv2.line(
143+
img,
144+
(int(bbox[i][0]), int(bbox[i][1])),
145+
(int(bbox[i + 4][0]), int(bbox[i + 4][1])),
146+
(0, 255, 0),
147+
2,
148+
)
149+
150+
out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
151+
152+
out.release()
153+
zed.close()
154+
155+
156+
if __name__ == "__main__":
157+
main()

vision/workflow-test/rewrite.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def initialize_camera_params(zed, input_type):
3939

4040
def main():
4141
parser = argparse.ArgumentParser()
42-
parser.add_argument('--weights', type=str, default='yolov8m.pt', help='model.pt path(s)')
42+
parser.add_argument(
43+
"--weights", type=str, default="yolov11n.pt", help="model.pt path(s)"
44+
)
4345
parser.add_argument('--svo', type=str, default=None, help=' svo file')
4446

4547
args = parser.parse_args()
@@ -81,7 +83,6 @@ def main():
8183

8284
objects = sl.Objects() # Structure containing all the detected objects
8385

84-
8586
zed.retrieve_objects(objects, obj_runtime_param) # Retrieve the 3D tracked objects
8687

8788
obj_runtime_param = sl.ObjectDetectionRuntimeParameters()
@@ -90,7 +91,6 @@ def main():
9091
for object in objects.object_list:
9192
print("{} {}".format(object.id, object.position))
9293

93-
9494
object_id = object.id # Get the object id
9595
object_label = object.raw_label; # Get the label
9696
object_position = object.position # Get the object position
@@ -100,8 +100,5 @@ def main():
100100
print("Object {0} is tracked\n".format(object_id))
101101

102102

103-
104-
105-
106103
if __name__ == '__main__':
107104
main()

0 commit comments

Comments
 (0)