
Commit d680a32

Add SORT tracking

* Merge new tracker with model
* Update displays.py
* Provides better tracking

1 parent f9a4387
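
For context, the core of the change: per-frame YOLO detections are now handed to a SORT tracker, which assigns persistent ids across frames. A minimal sketch of that loop, assuming the single-file SORT implementation (sort.py) that exposes a Sort class, and mirroring main.py in passing only the four box coordinates per detection:

import numpy as np
from sort import Sort  # assumes a sort.py providing the Sort class

tracker = Sort()

# One detection per row: [x1, y1, x2, y2] box corners in pixels
detections = np.array([[100, 120, 180, 300],
                       [400,  90, 470, 280]])

# update() matches detections to existing tracks and returns one row
# per live track: [x1, y1, x2, y2, track_id]
tracks = tracker.update(detections).astype(np.int32)
for x1, y1, x2, y2, track_id in tracks:
    print(f"Worker {track_id}: ({x1}, {y1}) -> ({x2}, {y2})")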

File tree

2 files changed: +40 −156 lines changed

source/displays.py (+1 −1)

@@ -95,7 +95,7 @@ def prepare_display(frame, workers):
         # Put worker id text
         cv2.putText(
             display_frame,
-            f"Worker ID: {worker_id[2:]}",
+            f"Worker ID: {worker_id}",
             (equipments_x1 + 20, equipments_y1 - 10),
             TEXT_FONT,
             TEXT_SIZE,
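
Why the slice was dropped: the old tracker keyed workers with strings like "ID0"/"ID1" (built in update_tracking, removed from main.py below), so the display stripped the two-character "ID" prefix; SORT returns plain integer track ids, so no slicing is needed. A minimal illustration:

# Old scheme: string keys such as "ID7" -> slice off the "ID" prefix
worker_id = "ID7"
print(f"Worker ID: {worker_id[2:]}")  # Worker ID: 7

# New scheme: SORT supplies integer track ids directly
worker_id = 7
print(f"Worker ID: {worker_id}")      # Worker ID: 7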

source/main.py (+39 −155)
@@ -1,25 +1,24 @@
+from sort import *
 import numpy as np
 import cv2
 from ultralytics import YOLO
 import copy

-import pandas as pd
-
-# object detector model
+# Object detector model
 from object_detector import predict

 # Worker class
 from worker import Worker

-# display
+# Display
 from displays import prepare_display

 # from PIL import Image as im

 DEBUG_MODE = False

 # Output video
-SAVE_OUTPUT = True
+SAVE_OUTPUT = False
 OUT_PATH = f"../data/debug/result.mp4"

 #
@@ -29,81 +28,16 @@
 FRAME_COUNT = 1000

 # -------------------------------------------------------
-### Configurations
-# Scaling percentage of original frame
+# Configurations
+
 CONF_LEVEL = 0.4
-# Threshold of centers ( old\new)
-THR_CENTERS = 200
-# Number of max frames to consider a object lost
-FRAME_MAX = 24
-# Number of max tracked centers stored
-PATIENCE = 100
-# ROI area color transparency
-ALPHA = 0.1 # unused
 # -------------------------------------------------------
 # Reading video with cv2
 video = cv2.VideoCapture(VIDEO_PATH)

 # Objects to detect Yolo
 class_IDS = [0] # default id for person is 0

-# Auxiliary variables
-centers_old = {}
-obj_id = 0
-count_p = 0
-last_key = ""
-# -------------------------------------------------------
-
-
-# temp funcs
-def detectWorkers():
-    return
-
-
-def filter_tracks(centers, PATIENCE):
-    """Function to filter track history"""
-    filter_dict = {}
-    for k, i in centers.items():
-        d_frames = i.items()
-        filter_dict[k] = dict(list(d_frames)[-PATIENCE:])
-
-    return filter_dict
-
-
-def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
-    """Function to update track of objects"""
-    is_new = 0
-    lastpos = [
-        (k, list(center.keys())[-1], list(center.values())[-1])
-        for k, center in centers_old.items()
-    ]
-    lastpos = [(i[0], i[2]) for i in lastpos if abs(i[1] - frame) <= FRAME_MAX]
-    # Calculating distance from existing centers points
-    previous_pos = [
-        (k, obj_center)
-        for k, centers in lastpos
-        if (np.linalg.norm(np.array(centers) - np.array(obj_center)) < THR_CENTERS)
-    ]
-    # if distance less than a threshold, it will update its positions
-    if previous_pos:
-        id_obj = previous_pos[0][0]
-        centers_old[id_obj][frame] = obj_center
-
-    # Else a new ID will be set to the given object
-    else:
-        if last_key:
-            last = last_key.split("D")[1]
-            id_obj = "ID" + str(int(last) + 1)
-        else:
-            id_obj = "ID0"
-
-        is_new = 1
-        centers_old[id_obj] = {frame: obj_center}
-        last_key = list(centers_old.keys())[-1]
-
-    return centers_old, id_obj, is_new, last_key
-
-
 # loading a YOLO model
 model = YOLO("yolov8n.pt")

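Aside on what the deleted helpers did versus what SORT does: update_tracking above matched a new detection to an existing track whenever the center-to-center distance fell below THR_CENTERS, while SORT predicts each track's box with a Kalman filter and associates detections to those predictions by IoU via Hungarian assignment. For intuition, a minimal IoU helper of the kind SORT relies on (illustrative, not copied from sort.py):

def iou(box_a, box_b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    x1 = max(box_a[0], box_b[0])
    y1 = max(box_a[1], box_b[1])
    x2 = min(box_a[2], box_b[2])
    y2 = min(box_a[3], box_b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)
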
@@ -116,64 +50,39 @@ def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
 # Output video properties
 frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
 frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
 if SAVE_OUTPUT:
     fps = int(video.get(cv2.CAP_PROP_FPS))
     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
     out = cv2.VideoWriter(OUT_PATH, fourcc, fps, (frame_width, frame_height))

+MOT_DETECTOR = Sort()
+
 for i in range(FRAME_COUNT):
     success, frame = video.read()

     # Continue until desired frame rate.
     if success:
+        # Copy frame for display
         annotated_frame = copy.deepcopy(frame)
-        y_hat = model.predict(frame, conf=CONF_LEVEL, classes=class_IDS)
-
-        boxes = y_hat[0].boxes.xyxy.cpu().numpy()
-        conf = y_hat[0].boxes.conf.cpu().numpy()
-        classes = y_hat[0].boxes.cls.cpu().numpy()
+        # Convert frame to RGB for models
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        # Run human detection model
+        humans_detected = model(frame, conf=CONF_LEVEL, classes=class_IDS)

-        # Storing the above information in a dataframe
-        positions_frame = pd.DataFrame(
-            y_hat[0].cpu().numpy().boxes.boxes,
-            columns=["xmin", "ymin", "xmax", "ymax", "conf", "class"],
-        )
+        # Prepare detected persons with initial id's for MOT_DETECTOR
+        # columns = ["x1", "y2", "x2", "y1", "conf", "class"]
+        idx = [0, 1, 2, 3]
+        pos_frame = humans_detected[0].boxes.data.numpy()[::, idx]

-        # Translating the numeric class labels to text
-        labels = [dict_classes[i] for i in classes]
+        # Update MOT_DETECTOR tracker object with respect to human detections
+        track_bbs_ids = MOT_DETECTOR.update(pos_frame).astype(np.int32)

+        # Containers to save detected workers
         worker_info = [] # (id, coord1, coord2)
-        worker_Images = []
-
-        # For each people, draw the bounding-box and add scaled and cropped images to list
-        for ix, row in enumerate(positions_frame.iterrows()):
-            # Getting the coordinates of each vehicle (row)
-            (
-                x1,
-                y2,
-                x2,
-                y1,
-                confidence,
-                category,
-            ) = row[
-                1
-            ].astype("int")
-
-            # Calculating the center of the bounding-box
-            center_x, center_y = int(((x2 + x1)) / 2), int((y1 + y2) / 2)
-
-            # Updating the tracking for each object
-            centers_old, id_obj, is_new, last_key = update_tracking(
-                centers_old,
-                (center_x, center_y),
-                THR_CENTERS,
-                last_key,
-                i,
-                FRAME_MAX,
-            )
-
-            # Updating people in roi
-            count_p += is_new
+        worker_images = []
+        for person in track_bbs_ids:
+            x1, y1, x2, y2, track_id = person

             # Crop and save persons from images
             # Expand the person according to expand constant.
@@ -196,37 +105,22 @@ def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
                 x2_expanded = frame_width - 1

             # Cropping worker image
-            workerImg = frame[y2_expanded:y1_expanded, x1_expanded:x2_expanded]
-
-            # Drawing center and bounding-box in the given frame
-            cv2.rectangle(
-                annotated_frame, (x1, y2), (x2, y1), (0, 0, 255), 2
-            ) # box
-            """
-            for center_x,center_y in centers_old[id_obj].values():
-                cv2.circle(annotated_frame, (center_x,center_y), 5,(0,0,255),-1) # center of box
-            """
-
-            # Drawing above the bounding-box the name of class recognized.
-            """
-            cv2.putText(img=annotated_frame, text=id_obj+':'+str(np.round(conf[ix],2)),
-                org= (x1,y2-10), fontFace=cv2.FONT_HERSHEY_TRIPLEX, fontScale=0.8, color=(0, 0, 255),thickness=1)
-            """
-
-            worker_Images.append(workerImg)
-            worker_info.append((last_key, (x1, y1), (x2, y2)))
-
-        # for worker in workers_cropped -> run predict -> append to the worker objects list
-        worker_objects = []
-
-        for i in range(len(worker_Images)):
-            # coordinates
+            worker_img = frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded]
+            # Save detected workers for equipment detection
+            worker_images.append(worker_img)
+            worker_id = person[4]
+            worker_info.append((worker_id, (x1, y1), (x2, y2)))
+
+        worker_objects = [] # Container to store worker objects
+        # Run equipment detection for all workers
+        for i in range(len(worker_images)):
+            # Set coordinates
             worker_topLeft = worker_info[i][1]
             worker_bottomRight = worker_info[i][2]

-            # equipments
-            worker_helmet = predict(worker_Images[i], "helmet") # status, conf
-            worker_vest = predict(worker_Images[i], "vest") # status, conf
+            # Detect equipments
+            worker_helmet = predict(worker_images[i], "helmet") # status, conf
+            worker_vest = predict(worker_images[i], "vest") # status, conf

             equipments = {}
             equipments["helmet"] = worker_helmet
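
A note on the cropping step above: each tracked box is expanded by an expand constant and clamped to the frame before worker_img is sliced out. Only the clamp x2_expanded = frame_width - 1 is visible in this diff, so the following is a hypothetical reconstruction of that step (EXPAND_CONSTANT and the helper name are illustrative, not from this commit):

EXPAND_CONSTANT = 10  # illustrative value, not from this diff

def expand_and_clamp(x1, y1, x2, y2, frame_width, frame_height,
                     margin=EXPAND_CONSTANT):
    """Grow a box by `margin` pixels per side, clamped to frame bounds."""
    x1_expanded = max(x1 - margin, 0)
    y1_expanded = max(y1 - margin, 0)
    x2_expanded = min(x2 + margin, frame_width - 1)
    y2_expanded = min(y2 + margin, frame_height - 1)
    return x1_expanded, y1_expanded, x2_expanded, y2_expanded

# main.py then slices the crop used for equipment detection:
# worker_img = frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded]
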
@@ -237,20 +131,10 @@ def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
             )
             worker_objects.append(worker_instance)

-        # display
+        # Prepare display and show result
         annotated_frame = prepare_display(annotated_frame, worker_objects)
-
-        # drawing the number of people
-        """
-        cv2.putText(img=annotated_frame, text=f'Counts People in ROI: {count_p}',
-            org= (30,40), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
-            fontScale=1.5, color=(255, 0, 0), thickness=1)
-        """
-
-        # Filtering tracks history
-        centers_old = filter_tracks(centers_old, PATIENCE)
-
         cv2.imshow("Safety Equipment Detector", annotated_frame)
+
         if SAVE_OUTPUT == True:
             out.write(annotated_frame)
