+ from sort import *
import numpy as np
import cv2
from ultralytics import YOLO
import copy

- import pandas as pd
-
- # object detector model
+ # Object detector model
from object_detector import predict

# Worker class
from worker import Worker

- # display
+ # Display
from displays import prepare_display

# from PIL import Image as im

DEBUG_MODE = False

# Output video
- SAVE_OUTPUT = True
+ SAVE_OUTPUT = False
OUT_PATH = f"../data/debug/result.mp4"

#
FRAME_COUNT = 1000

# -------------------------------------------------------
- ### Configurations
- # Scaling percentage of original frame
+ # Configurations
+
CONF_LEVEL = 0.4
- # Threshold of centers ( old\new)
- THR_CENTERS = 200
- # Number of max frames to consider a object lost
- FRAME_MAX = 24
- # Number of max tracked centers stored
- PATIENCE = 100
- # ROI area color transparency
- ALPHA = 0.1  # unused
# -------------------------------------------------------
# Reading video with cv2
video = cv2.VideoCapture(VIDEO_PATH)

# Objects to detect Yolo
class_IDS = [0]  # default id for person is 0

- # Auxiliary variables
- centers_old = {}
- obj_id = 0
- count_p = 0
- last_key = ""
- # -------------------------------------------------------
-
-
- # temp funcs
- def detectWorkers():
-     return
-
-
- def filter_tracks(centers, PATIENCE):
-     """Function to filter track history"""
-     filter_dict = {}
-     for k, i in centers.items():
-         d_frames = i.items()
-         filter_dict[k] = dict(list(d_frames)[-PATIENCE:])
-
-     return filter_dict
-
-
- def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
-     """Function to update track of objects"""
-     is_new = 0
-     lastpos = [
-         (k, list(center.keys())[-1], list(center.values())[-1])
-         for k, center in centers_old.items()
-     ]
-     lastpos = [(i[0], i[2]) for i in lastpos if abs(i[1] - frame) <= FRAME_MAX]
-     # Calculating distance from existing centers points
-     previous_pos = [
-         (k, obj_center)
-         for k, centers in lastpos
-         if (np.linalg.norm(np.array(centers) - np.array(obj_center)) < THR_CENTERS)
-     ]
-     # if distance less than a threshold, it will update its positions
-     if previous_pos:
-         id_obj = previous_pos[0][0]
-         centers_old[id_obj][frame] = obj_center
-
-     # Else a new ID will be set to the given object
-     else:
-         if last_key:
-             last = last_key.split("D")[1]
-             id_obj = "ID" + str(int(last) + 1)
-         else:
-             id_obj = "ID0"
-
-         is_new = 1
-         centers_old[id_obj] = {frame: obj_center}
-         last_key = list(centers_old.keys())[-1]
-
-     return centers_old, id_obj, is_new, last_key
-
-
# loading a YOLO model
model = YOLO("yolov8n.pt")
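
This commit replaces the hand-rolled tracker deleted above (update_tracking matched detections to existing tracks by centroid distance; filter_tracks pruned the history) with SORT, which matches detections to Kalman-predicted boxes by IoU. A minimal sketch of how the Sort class brought in by `from sort import *` is typically driven, assuming the reference abewley/sort implementation; the constructor values shown are that implementation's defaults, not settings made in this commit:

import numpy as np
from sort import Sort

tracker = Sort(max_age=1, min_hits=3, iou_threshold=0.3)

# One frame's detections as [x1, y1, x2, y2] rows; the reference
# implementation also accepts an optional fifth score column.
detections = np.array([[100.0, 120.0, 180.0, 300.0],
                       [400.0, 80.0, 470.0, 260.0]])

# update() returns [x1, y1, x2, y2, track_id] rows; IDs stay stable
# across frames as long as the boxes keep overlapping.
tracks = tracker.update(detections)
for x1, y1, x2, y2, track_id in tracks.astype(np.int32):
    print(track_id, (x1, y1), (x2, y2))
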
@@ -116,64 +50,39 @@ def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
# Output video properties
frame_width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
if SAVE_OUTPUT:
    fps = int(video.get(cv2.CAP_PROP_FPS))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(OUT_PATH, fourcc, fps, (frame_width, frame_height))

+ MOT_DETECTOR = Sort()
+
for i in range(FRAME_COUNT):
    success, frame = video.read()

    # Continue until desired frame rate.
    if success:
+         # Copy frame for display
        annotated_frame = copy.deepcopy(frame)
-         y_hat = model.predict(frame, conf=CONF_LEVEL, classes=class_IDS)
-
-         boxes = y_hat[0].boxes.xyxy.cpu().numpy()
-         conf = y_hat[0].boxes.conf.cpu().numpy()
-         classes = y_hat[0].boxes.cls.cpu().numpy()
+         # Convert frame to RGB for the models
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         # Run the human detection model
+         humans_detected = model(frame, conf=CONF_LEVEL, classes=class_IDS)

-         # Storing the above information in a dataframe
-         positions_frame = pd.DataFrame(
-             y_hat[0].cpu().numpy().boxes.boxes,
-             columns=["xmin", "ymin", "xmax", "ymax", "conf", "class"],
-         )
+         # Prepare detected person boxes as input for MOT_DETECTOR
+         # boxes.data columns: ["x1", "y1", "x2", "y2", "conf", "class"]
+         idx = [0, 1, 2, 3]
+         pos_frame = humans_detected[0].boxes.data.cpu().numpy()[:, idx]

-         # Translating the numeric class labels to text
-         labels = [dict_classes[i] for i in classes]
+         # Update MOT_DETECTOR tracker object with respect to human detections
+         track_bbs_ids = MOT_DETECTOR.update(pos_frame).astype(np.int32)

+         # Containers to save detected workers
        worker_info = []  # (id, coord1, coord2)
-         worker_Images = []
-
-         # For each people, draw the bounding-box and add scaled and cropped images to list
-         for ix, row in enumerate(positions_frame.iterrows()):
-             # Getting the coordinates of each vehicle (row)
-             (
-                 x1,
-                 y2,
-                 x2,
-                 y1,
-                 confidence,
-                 category,
-             ) = row[
-                 1
-             ].astype("int")
-
-             # Calculating the center of the bounding-box
-             center_x, center_y = int(((x2 + x1)) / 2), int((y1 + y2) / 2)
-
-             # Updating the tracking for each object
-             centers_old, id_obj, is_new, last_key = update_tracking(
-                 centers_old,
-                 (center_x, center_y),
-                 THR_CENTERS,
-                 last_key,
-                 i,
-                 FRAME_MAX,
-             )
-
-             # Updating people in roi
-             count_p += is_new
+         worker_images = []
+         for person in track_bbs_ids:
+             x1, y1, x2, y2, track_id = person

            # Crop and save persons from images
            # Expand the person according to expand constant.
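
The new detection path reads boxes straight off the Ultralytics result object instead of building a pandas DataFrame first. A short sketch of the layout this hunk relies on, assuming a standard YOLOv8 result; the image path is a placeholder:

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model("site.jpg", conf=0.4, classes=[0])  # placeholder image path

# boxes.data is an (N, 6) tensor, one row per detection:
# [x1, y1, x2, y2, conf, class]
data = results[0].boxes.data.cpu().numpy()
xyxy = data[:, [0, 1, 2, 3]]  # the slice this commit feeds to Sort.update()
scores = data[:, 4]           # confidences, which this commit drops
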
@@ -196,37 +105,22 @@ def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
                x2_expanded = frame_width - 1

            # Cropping worker image
-             workerImg = frame[y2_expanded:y1_expanded, x1_expanded:x2_expanded]
-
-             # Drawing center and bounding-box in the given frame
-             cv2.rectangle(
-                 annotated_frame, (x1, y2), (x2, y1), (0, 0, 255), 2
-             )  # box
-             """
-             for center_x,center_y in centers_old[id_obj].values():
-                 cv2.circle(annotated_frame, (center_x,center_y), 5,(0,0,255),-1) # center of box
-             """
-
-             # Drawing above the bounding-box the name of class recognized.
-             """
-             cv2.putText(img=annotated_frame, text=id_obj+':'+str(np.round(conf[ix],2)),
-                 org= (x1,y2-10), fontFace=cv2.FONT_HERSHEY_TRIPLEX, fontScale=0.8, color=(0, 0, 255),thickness=1)
-             """
-
-             worker_Images.append(workerImg)
-             worker_info.append((last_key, (x1, y1), (x2, y2)))
-
-         # for worker in workers_cropped -> run predict -> append to worker objects list
-         worker_objects = []
-
-         for i in range(len(worker_Images)):
-             # coordinates
+             worker_img = frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded]
+             # Save detected workers for equipment detection
+             worker_images.append(worker_img)
+             worker_id = person[4]
+             worker_info.append((worker_id, (x1, y1), (x2, y2)))
+
+         worker_objects = []  # Container to store worker objects
+         # Run equipment detection for all workers
+         for i in range(len(worker_images)):
+             # Set coordinates
            worker_topLeft = worker_info[i][1]
            worker_bottomRight = worker_info[i][2]

-             # equipments
-             worker_helmet = predict(worker_Images[i], "helmet")  # status, conf
-             worker_vest = predict(worker_Images[i], "vest")  # status, conf
+             # Detect equipment
+             worker_helmet = predict(worker_images[i], "helmet")  # status, conf
+             worker_vest = predict(worker_images[i], "vest")  # status, conf

            equipments = {}
            equipments["helmet"] = worker_helmet
@@ -237,20 +131,10 @@ def update_tracking(centers_old, obj_center, THR_CENTERS, last_key, frame, FRAME_MAX):
            )
            worker_objects.append(worker_instance)

-         # display
+         # Prepare display and show result
        annotated_frame = prepare_display(annotated_frame, worker_objects)
-
-         # drawing the number of people
-         """
-         cv2.putText(img=annotated_frame, text=f'Counts People in ROI: {count_p}',
-             org= (30,40), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
-             fontScale=1.5, color=(255, 0, 0), thickness=1)
-         """
-
-         # Filtering tracks history
-         centers_old = filter_tracks(centers_old, PATIENCE)
-
        cv2.imshow("Safety Equipment Detector", annotated_frame)
+
        if SAVE_OUTPUT == True:
            out.write(annotated_frame)
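
The hunk ends inside the frame loop, so key handling and teardown are outside this diff. The usual tail of an OpenCV display loop like this one looks roughly as follows; this is an assumption about the unshown remainder of the script, not part of the commit:

        # cv2.imshow only renders once waitKey pumps the event loop;
        # pressing 'q' ends the run early.
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

# After the loop: release capture/writer handles and close the window.
video.release()
if SAVE_OUTPUT:
    out.release()
cv2.destroyAllWindows()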