5
5
from bytetracker .kalman_filter import KalmanFilter
6
6
7
7
8
- def xywh2xyxy (x ):
9
- # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
8
+ def xywh2xyxy (x : np .ndarray ):
9
+ """
10
+ Converts bounding boxes from [x, y, w, h] format to [x1, y1, x2, y2] format
11
+
12
+ Parameters
13
+ ----------
14
+ x: Array at [x, y, w, h] format
15
+ Returns
16
+ -------
17
+ y: Array [x1, y1, x2, y2] format
18
+ """
10
19
y = np .copy (x )
11
20
y [:, 0 ] = x [:, 0 ] - x [:, 2 ] / 2 # top left x
12
21
y [:, 1 ] = x [:, 1 ] - x [:, 3 ] / 2 # top left y
@@ -15,8 +24,18 @@ def xywh2xyxy(x):
15
24
return y
16
25
17
26
18
- def xyxy2xywh (x ):
19
- # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
27
+ def xyxy2xywh (x : np .ndarray ):
28
+ """
29
+ Converts bounding boxes from [x1, y1, x2, y2] format to [x, y, w, h] format
30
+
31
+ Parameters
32
+ ----------
33
+ x: Array at [x1, y1, x2, y2] format
34
+
35
+ Returns
36
+ -------
37
+ y: Array at [x, y, w, h] format
38
+ """
20
39
y = np .copy (x )
21
40
y [:, 0 ] = (x [:, 0 ] + x [:, 2 ]) / 2 # x center
22
41
y [:, 1 ] = (x [:, 1 ] + x [:, 3 ]) / 2 # y center
@@ -39,13 +58,25 @@ def __init__(self, tlwh, score, cls):
39
58
self .cls = cls
40
59
41
60
def predict (self ):
61
+ """
62
+ updates the mean and covariance using a Kalman filter prediction, with a condition
63
+ based on the state of the track.
64
+ """
42
65
mean_state = self .mean .copy ()
43
66
if self .state != TrackState .Tracked :
44
67
mean_state [7 ] = 0
45
68
self .mean , self .covariance = self .kalman_filter .predict (mean_state , self .covariance )
46
69
47
70
@staticmethod
48
- def multi_predict (stracks ):
71
+ def multi_predict (stracks : list ["STrack" ]):
72
+ """
73
+ takes a list of tracks, updates their mean and covariance values, and
74
+ performs a Kalman filter prediction step.
75
+
76
+ Parameters
77
+ ----------
78
+ stracks (list): list of STrack objects
79
+ """
49
80
if len (stracks ) > 0 :
50
81
multi_mean = np .asarray ([st .mean .copy () for st in stracks ])
51
82
multi_covariance = np .asarray ([st .covariance for st in stracks ])
@@ -59,8 +90,16 @@ def multi_predict(stracks):
59
90
stracks [i ].mean = mean
60
91
stracks [i ].covariance = cov
61
92
62
- def activate (self , kalman_filter , frame_id ):
63
- """Start a new tracklet"""
93
+ def activate (self , kalman_filter : KalmanFilter , frame_id : int ):
94
+ """
95
+ initializes a new tracklet with a Kalman filter and assigns a track ID and
96
+ state based on the frame ID.
97
+
98
+ Parameters
99
+ ----------
100
+ kalman_filter: Kalman filter object
101
+ frame_id (int): The `frame_id` parameter in the `activate` method.
102
+ """
64
103
self .kalman_filter = kalman_filter
65
104
self .track_id = self .next_id ()
66
105
self .mean , self .covariance = self .kalman_filter .initiate (self .tlwh_to_xyah (self ._tlwh ))
@@ -72,7 +111,19 @@ def activate(self, kalman_filter, frame_id):
72
111
self .frame_id = frame_id
73
112
self .start_frame = frame_id
74
113
75
- def re_activate (self , new_track , frame_id , new_id = False ):
114
+ def re_activate (self , new_track : "STrack" , frame_id : int , new_id : bool = False ):
115
+ """
116
+ Updates a track using Kalman filtering
117
+
118
+ Parameters
119
+ ----------
120
+ new_track : STrack
121
+ The new track object to update.
122
+ frame_id : int
123
+ The frame ID.
124
+ new_id : bool
125
+ Whether to assign a new ID to the track, by default False.
126
+ """
76
127
self .mean , self .covariance = self .kalman_filter .update (
77
128
self .mean , self .covariance , self .tlwh_to_xyah (new_track .tlwh )
78
129
)
@@ -84,13 +135,16 @@ def re_activate(self, new_track, frame_id, new_id=False):
84
135
self .score = new_track .score
85
136
self .cls = new_track .cls
86
137
87
- def update (self , new_track , frame_id ):
138
+ def update (self , new_track : "STrack" , frame_id : int ):
88
139
"""
89
- Update a matched track
90
- :type new_track: STrack
91
- :type frame_id: int
92
- :type update_feature: bool
93
- :return:
140
+ Update a matched track.
141
+
142
+ Parameters
143
+ ----------
144
+ new_track : STrack
145
+ The new track object to update.
146
+ frame_id : int
147
+ The frame ID.
94
148
"""
95
149
self .frame_id = frame_id
96
150
self .cls = new_track .cls
@@ -120,7 +174,8 @@ def tlwh(self):
120
174
@property
121
175
# @jit(nopython=True)
122
176
def tlbr (self ):
123
- """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
177
+ """
178
+ Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
124
179
`(top left, bottom right)`.
125
180
"""
126
181
ret = self .tlwh .copy ()
@@ -130,7 +185,8 @@ def tlbr(self):
130
185
@staticmethod
131
186
# @jit(nopython=True)
132
187
def tlwh_to_xyah (tlwh ):
133
- """Convert bounding box to format `(center x, center y, aspect ratio,
188
+ """
189
+ Convert bounding box to format `(center x, center y, aspect ratio,
134
190
height)`, where the aspect ratio is `width / height`.
135
191
"""
136
192
ret = np .asarray (tlwh ).copy ()
@@ -167,7 +223,23 @@ def reset(self):
167
223
self .kalman_filter = KalmanFilter ()
168
224
BaseTrack ._count = 0
169
225
170
- def update (self , dets , frame_id ):
226
+ def update (self , dets : np .ndarray , frame_id : int ):
227
+ """
228
+ Performs object tracking by associating detections with existing tracks and updating their states accordingly.
229
+
230
+ Parameters
231
+ ----------
232
+ dets : np.ndarray
233
+ Detection boxes of objects in the format (n x 6), where each row contains (x1, y1, x2, y2, score, class).
234
+ frame_id : int
235
+ The ID of the current frame in the video.
236
+
237
+ Returns
238
+ -------
239
+ np.ndarray
240
+ An array of outputs containing bounding box coordinates, track ID, class label, and
241
+ score for each tracked object.
242
+ """
171
243
self .frame_id = frame_id
172
244
activated_starcks = []
173
245
refind_stracks = []
@@ -318,7 +390,23 @@ def update(self, dets, frame_id):
318
390
return outputs
319
391
320
392
321
- def joint_stracks (tlista , tlistb ):
393
+ def joint_stracks (tlista : list ["STrack" ], tlistb : list ["STrack" ]):
394
+ """
395
+ Merges two lists of objects based on a specific attribute while
396
+ ensuring no duplicates are added.
397
+
398
+ Parameters
399
+ ----------
400
+ tlista : List[STrack]
401
+ list of STrack objects.
402
+ tlistb : List[STrack]
403
+ list of STrack objects.
404
+
405
+ Returns
406
+ -------
407
+ List[STrack]
408
+ A list containing all unique elements from both input lists.
409
+ """
322
410
exists = {}
323
411
res = []
324
412
for t in tlista :
@@ -332,7 +420,22 @@ def joint_stracks(tlista, tlistb):
332
420
return res
333
421
334
422
335
- def sub_stracks (tlista , tlistb ):
423
+ def sub_stracks (tlista : list ["STrack" ], tlistb : list ["STrack" ]):
424
+ """
425
+ Returns a list of STrack objects that are present in tlista but not in tlistb.
426
+
427
+ Parameters
428
+ ----------
429
+ tlista : List[STrack]
430
+ list of STrack objects.
431
+ tlistb : List[STrack]
432
+ list of STrack objects.
433
+
434
+ Returns
435
+ -------
436
+ List[STrack]
437
+ A list containing STrack objects present in tlista but not in tlistb.
438
+ """
336
439
stracks = {}
337
440
for t in tlista :
338
441
stracks [t .track_id ] = t
@@ -343,7 +446,22 @@ def sub_stracks(tlista, tlistb):
343
446
return list (stracks .values ())
344
447
345
448
346
- def remove_duplicate_stracks (stracksa , stracksb ):
449
+ def remove_duplicate_stracks (stracksa : list ["STrack" ], stracksb : list ["STrack" ]):
450
+ """
451
+ Removes duplicate STrack objects from the input lists based on their frame IDs.
452
+
453
+ Parameters
454
+ ----------
455
+ stracksa : List[STrack]
456
+ list of STrack objects.
457
+ stracksb : List[STrack]
458
+ list of STrack objects.
459
+
460
+ Returns
461
+ -------
462
+ Tuple[List[STrack], List[STrack]]
463
+ Two lists containing unique STrack objects after removing duplicates.
464
+ """
347
465
pdist = matching .iou_distance (stracksa , stracksb )
348
466
pairs = np .where (pdist < 0.15 )
349
467
dupa , dupb = list (), list ()
0 commit comments