-
Notifications
You must be signed in to change notification settings - Fork 0
/
track.py
184 lines (155 loc) · 7.05 KB
/
track.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import hydra
import torch
import cv2
from random import randint
from SORT import * # SORT module
import numpy as np
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
# Module-level SORT tracker instance; stays None until init_tracker() is called.
tracker = None


def init_tracker(max_age=5, min_hits=2, iou_threshold=0.3):
    """Create the SORT tracker and store it in the module-level ``tracker``.

    The defaults reproduce the values this script previously hard-coded, so
    calling ``init_tracker()`` with no arguments behaves as before.

    Args:
        max_age (int): Forwarded to ``Sort`` as ``max_age``.
        min_hits (int): Forwarded to ``Sort`` as ``min_hits``.
        iou_threshold (float): Forwarded to ``Sort`` as ``iou_threshold``.
    """
    global tracker
    tracker = Sort(max_age=max_age, min_hits=min_hits, iou_threshold=iou_threshold)
# Palette of random colour tuples used when drawing track trails;
# empty until random_color_list() has been called.
rand_color_list = []


def random_color_list(count=5005):
    """Refill the module-level ``rand_color_list`` with random colours.

    Each entry is a 3-tuple of integers drawn with ``randint(0, 255)``.
    The list is updated in place, so references obtained before the call
    (for example via ``from module import rand_color_list``) see the new
    palette as well.

    Args:
        count (int): Number of colours to generate. Defaults to 5005, the
            size this script previously hard-coded.
    """
    rand_color_list[:] = [
        (randint(0, 255), randint(0, 255), randint(0, 255))
        for _ in range(count)
    ]
def draw_boxes(img, bbox, identities=None, categories=None, names=None, offset=(0, 0)):
    """
    Draw bounding boxes on an image along with class labels and centroids.

    Args:
        img (numpy.ndarray): The image to draw on; it is modified in place.
        bbox (iterable): Boxes, each unpacking to exactly [x1, y1, x2, y2].
        identities (list): Per-box track identities (optional). Accepted for
            interface compatibility; the label drawn here does not include it.
        categories (list): Per-box class indices (optional). Class 0 is used
            for every box when omitted.
        names (list | dict): Lookup from class index to label text (optional).
            When omitted, the numeric class index is used as the label.
        offset (tuple): (dx, dy) added to every drawn coordinate (optional).

    Returns:
        numpy.ndarray: The same ``img`` object, after drawing.
    """
    dx, dy = offset[0], offset[1]
    for idx, box in enumerate(bbox):
        x1, y1, x2, y2 = (int(v) for v in box)
        x1 += dx
        x2 += dx
        y1 += dy
        y2 += dy

        cat = int(categories[idx]) if categories is not None else 0
        # Without a name table, fall back to the class index so the default
        # ``names=None`` no longer fails on subscripting.
        label = names[cat] if names is not None else str(cat)

        # Box centre, shifted by the same offset as the corners so the dot
        # stays inside the rectangle when a non-zero offset is given.
        center = (
            int((box[0] + box[2]) / 2) + dx,
            int((box[1] + box[3]) / 2) + dy,
        )

        (text_w, _), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 20), 2)
        # Filled tag just above the top-left corner, as wide as the label text
        cv2.rectangle(img, (x1, y1 - 20), (x1 + text_w, y1), (255, 144, 30), -1)
        cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, [255, 255, 255], 1)
        cv2.circle(img, center, 3, (255, 255, 255), -1)  # Centroid of the box
    return img
class DetectionPredictor(BasePredictor):
    """
    Ultralytics YOLO detection predictor with SORT object tracking.

    Extends ``BasePredictor`` and, in ``write_results``, feeds each frame's
    detections to the module-level ``tracker`` and draws the tracked boxes and
    track trails onto the frame. ``init_tracker()`` and ``random_color_list()``
    must have been called before results are written.
    """

    def get_annotator(self, img):
        """Build an ``Annotator`` for ``img`` using the configured line thickness."""
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
        """Move a numpy image to the model device, cast it and divide by 255.

        Args:
            img (numpy.ndarray): Input image data.

        Returns:
            torch.Tensor: fp16 tensor when ``self.model.fp16`` is set, fp32
            otherwise, with values divided by 255.
        """
        img = torch.from_numpy(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # Convert to fp16/32
        img /= 255
        return img

    def postprocess(self, preds, img, orig_img):
        """Apply NMS and rescale boxes to the original image size.

        Args:
            preds: Raw model output passed to ``ops.non_max_suppression``.
            img (torch.Tensor): Preprocessed input; ``img.shape[2:]`` is taken
                as the inference image size.
            orig_img: Original image, or an indexable collection of images
                when ``self.drone`` is truthy.

        Returns:
            list: Per-image detections whose first four columns are rescaled
            and rounded in place.
        """
        preds = ops.non_max_suppression(preds,
                                        self.args.conf,
                                        self.args.iou,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det)
        for i, pred in enumerate(preds):
            # NOTE(review): ``self.drone`` is not set in this class; presumably
            # provided by the base predictor - confirm.
            shape = orig_img[i].shape if self.drone else orig_img.shape
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
        return preds

    def write_results(self, idx, preds, batch):
        """
        Record the detections for one image and run SORT tracking on them.

        Args:
            idx (int): Index of the current image in the batch.
            preds (List[torch.Tensor]): Per-image predictions; each row is
                unpacked as x1, y1, x2, y2, conf, class.
            batch (tuple): ``(path, preprocessed image, original image)``.

        Returns:
            str: Log string with the image dimensions (prefixed by ``idx``
            when ``self.drone`` is truthy).
        """
        p, im, im0 = batch
        log_string = ""
        if len(im.shape) == 3:
            im = im[None]  # Expand for batch dimension
        self.seen += 1
        im0 = im0.copy()
        if self.drone:  # Batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        # Paths used when saving results
        self.data_path = p
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # Print image dimensions
        self.annotator = self.get_annotator(im0)

        det = preds[idx]
        self.all_outputs.append(det)
        # NOTE(review): the reference SORT implementation documents that
        # update() should be called on every frame, including frames without
        # detections; this early return skips the tracker for such frames -
        # confirm whether that is intended.
        if len(det) == 0:
            return log_string

        # Hand all detections to SORT as one float64 array instead of growing
        # it row by row with np.vstack.
        dets_to_sort = det.cpu().detach().numpy().astype(np.float64)
        tracked_dets = tracker.update(dets_to_sort)
        tracks = tracker.getTrackers()

        # Draw each track's trail as line segments between consecutive centroids
        for track in tracks:
            points = [(int(c[0]), int(c[1])) for c in track.centroidarr]
            if len(points) < 2:
                continue
            # Wrap the id around the palette so long runs with many track ids
            # cannot index past the end of rand_color_list.
            color = rand_color_list[track.id % len(rand_color_list)]
            for start, end in zip(points, points[1:]):
                cv2.line(im0, start, end, color, thickness=3)

        if len(tracked_dets) > 0:
            bbox_xyxy = tracked_dets[:, :4]
            identities = tracked_dets[:, 8]
            categories = tracked_dets[:, 4]
            draw_boxes(im0, bbox_xyxy, identities, categories, self.model.names)

        return log_string
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
    """Run object detection with SORT tracking using the given hydra config.

    Sets up the module-level tracker and colour palette, fills in defaults for
    ``cfg.model`` and ``cfg.source`` when they are unset, validates
    ``cfg.imgsz`` and then runs a ``DetectionPredictor``.

    Args:
        cfg: Hydra configuration; ``model``, ``imgsz`` and ``source`` are read
            and overwritten here.
    """
    init_tracker()
    random_color_list()

    # Any other checkpoint (e.g. 'yolov8n.pt', 'yolov8n-pose.pt' or a custom
    # 'path/to/best.pt') can be selected through cfg.model.
    # NOTE(review): the default is a segmentation checkpoint although the
    # predictor below is a DetectionPredictor. Two further fallbacks
    # ("yolov8n.pt", "yolov8s.pt") followed this line but could never take
    # effect and were removed - confirm which default is intended.
    cfg.model = cfg.model or "yolov8n-seg.pt"
    cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2)  # Check image size
    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"

    predictor = DetectionPredictor(cfg)
    predictor()


if __name__ == "__main__":
    predict()