Skip to content

Commit 4fb1460

Browse files
authored
feat: typing functions and adding some docstring (#26)
1 parent 4d8a20d commit 4fb1460

File tree

3 files changed

+186
-31
lines changed

3 files changed

+186
-31
lines changed

bytetracker/byte_tracker.py

+138-20
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,17 @@
55
from bytetracker.kalman_filter import KalmanFilter
66

77

8-
def xywh2xyxy(x):
9-
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
8+
def xywh2xyxy(x: np.ndarray):
9+
"""
10+
Converts bounding boxes from [x, y, w, h] format to [x1, y1, x2, y2] format
11+
12+
Parameters
13+
----------
14+
x: Array of boxes in [x, y, w, h] format, where (x, y) is the box center
15+
Returns
16+
-------
17+
y: Array of boxes in [x1, y1, x2, y2] format
18+
"""
1019
y = np.copy(x)
1120
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
1221
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
@@ -15,8 +24,18 @@ def xywh2xyxy(x):
1524
return y
1625

1726

18-
def xyxy2xywh(x):
19-
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
27+
def xyxy2xywh(x: np.ndarray):
28+
"""
29+
Converts bounding boxes from [x1, y1, x2, y2] format to [x, y, w, h] format
30+
31+
Parameters
32+
----------
33+
x: Array of boxes in [x1, y1, x2, y2] format
34+
35+
Returns
36+
-------
37+
y: Array of boxes in [x, y, w, h] format, where (x, y) is the box center
38+
"""
2039
y = np.copy(x)
2140
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
2241
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
@@ -39,13 +58,25 @@ def __init__(self, tlwh, score, cls):
3958
self.cls = cls
4059

4160
def predict(self):
61+
"""
62+
Updates the mean and covariance with a Kalman filter prediction step. If the track is
63+
not in the Tracked state, element 7 of the mean is set to 0 before predicting.
64+
"""
4265
mean_state = self.mean.copy()
4366
if self.state != TrackState.Tracked:
4467
mean_state[7] = 0
4568
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
4669

4770
@staticmethod
48-
def multi_predict(stracks):
71+
def multi_predict(stracks: list["STrack"]):
72+
"""
73+
takes a list of tracks, updates their mean and covariance values, and
74+
performs a Kalman filter prediction step.
75+
76+
Parameters
77+
----------
78+
stracks (list): list of STrack objects
79+
"""
4980
if len(stracks) > 0:
5081
multi_mean = np.asarray([st.mean.copy() for st in stracks])
5182
multi_covariance = np.asarray([st.covariance for st in stracks])
@@ -59,8 +90,16 @@ def multi_predict(stracks):
5990
stracks[i].mean = mean
6091
stracks[i].covariance = cov
6192

62-
def activate(self, kalman_filter, frame_id):
63-
"""Start a new tracklet"""
93+
def activate(self, kalman_filter: KalmanFilter, frame_id: int):
94+
"""
95+
initializes a new tracklet with a Kalman filter and assigns a track ID and
96+
state based on the frame ID.
97+
98+
Parameters
99+
----------
100+
kalman_filter: Kalman filter object
101+
frame_id (int): ID of the current frame; stored as both the track's frame ID and its start frame.
102+
"""
64103
self.kalman_filter = kalman_filter
65104
self.track_id = self.next_id()
66105
self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))
@@ -72,7 +111,19 @@ def activate(self, kalman_filter, frame_id):
72111
self.frame_id = frame_id
73112
self.start_frame = frame_id
74113

75-
def re_activate(self, new_track, frame_id, new_id=False):
114+
def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False):
115+
"""
116+
Updates a track using Kalman filtering
117+
118+
Parameters
119+
----------
120+
new_track : STrack
121+
The new track object to update.
122+
frame_id : int
123+
The frame ID.
124+
new_id : bool
125+
Whether to assign a new ID to the track, by default False.
126+
"""
76127
self.mean, self.covariance = self.kalman_filter.update(
77128
self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
78129
)
@@ -84,13 +135,16 @@ def re_activate(self, new_track, frame_id, new_id=False):
84135
self.score = new_track.score
85136
self.cls = new_track.cls
86137

87-
def update(self, new_track, frame_id):
138+
def update(self, new_track: "STrack", frame_id: int):
88139
"""
89-
Update a matched track
90-
:type new_track: STrack
91-
:type frame_id: int
92-
:type update_feature: bool
93-
:return:
140+
Update a matched track.
141+
142+
Parameters
143+
----------
144+
new_track : STrack
145+
The new track object to update.
146+
frame_id : int
147+
The frame ID.
94148
"""
95149
self.frame_id = frame_id
96150
self.cls = new_track.cls
@@ -120,7 +174,8 @@ def tlwh(self):
120174
@property
121175
# @jit(nopython=True)
122176
def tlbr(self):
123-
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
177+
"""
178+
Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
124179
`(top left, bottom right)`.
125180
"""
126181
ret = self.tlwh.copy()
@@ -130,7 +185,8 @@ def tlbr(self):
130185
@staticmethod
131186
# @jit(nopython=True)
132187
def tlwh_to_xyah(tlwh):
133-
"""Convert bounding box to format `(center x, center y, aspect ratio,
188+
"""
189+
Convert bounding box to format `(center x, center y, aspect ratio,
134190
height)`, where the aspect ratio is `width / height`.
135191
"""
136192
ret = np.asarray(tlwh).copy()
@@ -167,7 +223,23 @@ def reset(self):
167223
self.kalman_filter = KalmanFilter()
168224
BaseTrack._count = 0
169225

170-
def update(self, dets, frame_id):
226+
def update(self, dets: np.ndarray, frame_id: int):
227+
"""
228+
Performs object tracking by associating detections with existing tracks and updating their states accordingly.
229+
230+
Parameters
231+
----------
232+
dets : np.ndarray
233+
Detection boxes of objects in the format (n x 6), where each row contains (x1, y1, x2, y2, score, class).
234+
frame_id : int
235+
The ID of the current frame in the video.
236+
237+
Returns
238+
-------
239+
np.ndarray
240+
An array of outputs containing bounding box coordinates, track ID, class label, and
241+
score for each tracked object.
242+
"""
171243
self.frame_id = frame_id
172244
activated_starcks = []
173245
refind_stracks = []
@@ -318,7 +390,23 @@ def update(self, dets, frame_id):
318390
return outputs
319391

320392

321-
def joint_stracks(tlista, tlistb):
393+
def joint_stracks(tlista: list["STrack"], tlistb: list["STrack"]):
394+
"""
395+
Merges two lists of objects based on a specific attribute while
396+
ensuring no duplicates are added.
397+
398+
Parameters
399+
----------
400+
tlista : List[STrack]
401+
list of STrack objects.
402+
tlistb : List[STrack]
403+
list of STrack objects.
404+
405+
Returns
406+
-------
407+
List[STrack]
408+
A list containing all unique elements from both input lists.
409+
"""
322410
exists = {}
323411
res = []
324412
for t in tlista:
@@ -332,7 +420,22 @@ def joint_stracks(tlista, tlistb):
332420
return res
333421

334422

335-
def sub_stracks(tlista, tlistb):
423+
def sub_stracks(tlista: list["STrack"], tlistb: list["STrack"]):
424+
"""
425+
Returns a list of STrack objects that are present in tlista but not in tlistb.
426+
427+
Parameters
428+
----------
429+
tlista : List[STrack]
430+
list of STrack objects.
431+
tlistb : List[STrack]
432+
list of STrack objects.
433+
434+
Returns
435+
-------
436+
List[STrack]
437+
A list containing STrack objects present in tlista but not in tlistb.
438+
"""
336439
stracks = {}
337440
for t in tlista:
338441
stracks[t.track_id] = t
@@ -343,7 +446,22 @@ def sub_stracks(tlista, tlistb):
343446
return list(stracks.values())
344447

345448

346-
def remove_duplicate_stracks(stracksa, stracksb):
449+
def remove_duplicate_stracks(stracksa: list["STrack"], stracksb: list["STrack"]):
450+
"""
451+
Removes duplicate STrack objects from the input lists; pairs whose IoU distance is below 0.15 are treated as duplicates.
452+
453+
Parameters
454+
----------
455+
stracksa : List[STrack]
456+
list of STrack objects.
457+
stracksb : List[STrack]
458+
list of STrack objects.
459+
460+
Returns
461+
-------
462+
Tuple[List[STrack], List[STrack]]
463+
Two lists containing unique STrack objects after removing duplicates.
464+
"""
347465
pdist = matching.iou_distance(stracksa, stracksb)
348466
pairs = np.where(pdist < 0.15)
349467
dupa, dupb = list(), list()

bytetracker/kalman_filter.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def initiate(self, measurement):
6868
covariance = np.diag(np.square(std))
6969
return mean, covariance
7070

71-
def predict(self, mean, covariance):
71+
def predict(self, mean: np.ndarray, covariance: np.ndarray):
7272
"""Run Kalman filter prediction step.
7373
7474
Parameters
@@ -109,7 +109,7 @@ def predict(self, mean, covariance):
109109

110110
return mean, covariance
111111

112-
def project(self, mean, covariance):
112+
def project(self, mean: np.ndarray, covariance: np.ndarray):
113113
"""Project state distribution to measurement space.
114114
115115
Parameters
@@ -138,7 +138,7 @@ def project(self, mean, covariance):
138138
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
139139
return mean, covariance + innovation_cov
140140

141-
def multi_predict(self, mean, covariance):
141+
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
142142
"""Run Kalman filter prediction step (Vectorized version).
143143
Parameters
144144
----------
@@ -179,7 +179,7 @@ def multi_predict(self, mean, covariance):
179179

180180
return mean, covariance
181181

182-
def update(self, mean, covariance, measurement):
182+
def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
183183
"""Run Kalman filter correction step.
184184
185185
Parameters

bytetracker/matching.py

+44-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,21 @@
22
import numpy as np
33

44

5-
def linear_assignment(cost_matrix, thresh):
5+
def linear_assignment(cost_matrix: np.ndarray, thresh: float):
6+
"""
7+
Assigns detections to existing tracks based on a cost matrix using linear assignment.
8+
9+
Parameters
10+
----------
11+
cost_matrix : np.ndarray
12+
The cost matrix representing the association cost between detections and tracks.
13+
thresh : float
14+
The threshold for cost matching.
15+
16+
Returns
17+
-------
18+
Tuple containing matches, unmatched detections, and unmatched tracks.
19+
"""
620
if cost_matrix.size == 0:
721
return (
822
np.empty((0, 2), dtype=int),
@@ -22,11 +36,19 @@ def linear_assignment(cost_matrix, thresh):
2236

2337
def ious(atlbrs, btlbrs):
2438
"""
25-
Compute cost based on IoU
26-
:type atlbrs: list[tlbr] | np.ndarray
27-
:type atlbrs: list[tlbr] | np.ndarray
39+
Compute Intersection over Union (IoU) between bounding box pairs
40+
41+
Parameters
42+
----------
43+
atlbrs : Union[list, np.ndarray]
44+
The bounding boxes of the first set in (min x, min y, max x, max y) format.
45+
btlbrs : Union[list, np.ndarray]
46+
The bounding boxes of the second set in (min x, min y, max x, max y) format.
2847
29-
:rtype ious np.ndarray
48+
Returns
49+
-------
50+
np.ndarray
51+
An array containing IoU values for each pair of bounding boxes.
3052
"""
3153
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
3254
if ious.size == 0:
@@ -63,7 +85,22 @@ def iou_distance(atracks, btracks):
6385
return cost_matrix
6486

6587

66-
def fuse_score(cost_matrix, detections):
88+
def fuse_score(cost_matrix: np.ndarray, detections: np.ndarray):
89+
"""
90+
Fuse detection scores with similarity scores from a cost matrix.
91+
92+
Parameters
93+
----------
94+
cost_matrix : np.ndarray
95+
The cost matrix representing the dissimilarity between tracks and detections.
96+
detections : np.ndarray
97+
The array of detections, each containing a score.
98+
99+
Returns
100+
-------
101+
np.ndarray
102+
The fused cost matrix, incorporating both similarity scores and detection scores.
103+
"""
67104
if cost_matrix.size == 0:
68105
return cost_matrix
69106
iou_sim = 1 - cost_matrix
@@ -74,7 +111,7 @@ def fuse_score(cost_matrix, detections):
74111
return fuse_cost
75112

76113

77-
def bbox_ious(boxes, query_boxes):
114+
def bbox_ious(boxes: np.ndarray, query_boxes: np.ndarray):
78115
"""
79116
Parameters
80117
----------

0 commit comments

Comments
 (0)