4
4
import torch
5
5
from ultralytics .engine .results import Results
6
6
from ultralytics .utils import ops
7
+ from ultralytics .models .yolo .detect import DetectionPredictor
7
8
from yolox .exp import get_exp
8
9
from yolox .utils import postprocess
9
10
from yolox .utils .model_utils import fuse_model
10
11
11
12
from boxmot .utils import logger as LOGGER
13
+ from boxmot .utils .ops import bytetrack_preprocess
12
14
from tracking .detectors .yolo_interface import YoloInterface
13
15
14
16
# default model weights for these model names
@@ -48,6 +50,7 @@ class YoloXStrategy(YoloInterface):
48
50
def __init__ (self , model , device , args ):
49
51
50
52
self .args = args
53
+ self .imgsz = args .imgsz
51
54
self .pt = False
52
55
self .stride = 32 # max stride in YOLOX
53
56
@@ -80,25 +83,64 @@ def __init__(self, model, device, args):
80
83
map_location = torch .device ('cpu' )
81
84
)
82
85
86
+ self .device = device
83
87
self .model = exp .get_model ()
84
88
self .model .eval ()
85
89
self .model .load_state_dict (ckpt ["model" ])
86
90
self .model = fuse_model (self .model )
87
- self .model .to (device )
91
+ self .model .to (self . device )
88
92
self .model .eval ()
93
+ self .im_paths = []
94
+ self ._preproc_data = []
89
95
90
96
@torch.no_grad()
def __call__(self, im, augment, visualize, embed):
    """Run the wrapped model on a batch of images.

    ``im`` may be a single tensor or a list of tensors. A list whose first
    element is 3D is stacked along a new leading dimension; any other list
    is concatenated with ``torch.vstack``. A lone 3D tensor gets a leading
    dimension of size 1. ``augment``, ``visualize`` and ``embed`` are
    accepted but not used by this method.

    Returns:
        Whatever ``self.model`` returns for the resulting 4D tensor.
    """
    if isinstance(im, list):
        # The rank of the first element decides how the list is merged.
        combine = torch.stack if len(im[0].shape) == 3 else torch.vstack
        im = combine(im)

    if len(im.shape) == 3:
        im = im.unsqueeze(0)

    assert len(im.shape) == 4, f"Expected 4D tensor as input, got { im .shape } "

    return self.model(im)
94
111
95
112
def warmup(self, imgsz):
    """Do nothing: no warmup is performed here and ``imgsz`` is ignored."""
    return None
97
114
98
- def postprocess (self , path , preds , im , im0s ):
115
def update_im_paths(self, predictor: DetectionPredictor):
    """Save the image paths of the batch that is about to be predicted.

    Meant to be passed as the ``on_predict_batch_start`` callback. The first
    element of ``predictor.batch`` is stored in ``self.im_paths``.

    Args:
        predictor: the ultralytics predictor driving inference.

    Raises:
        TypeError: if ``predictor`` is not a ``DetectionPredictor``.
    """
    # The previous `assert (cond, "msg")` asserted a non-empty tuple, which
    # is always truthy, so the type check never ran. Check explicitly.
    if not isinstance(predictor, DetectionPredictor):
        raise TypeError("Only ultralytics predictors are supported")
    self.im_paths = predictor.batch[0]
123
+
124
def preprocess(self, im) -> torch.Tensor:
    """Preprocess a list of images and pack them into one batch tensor.

    Each image is passed to ``bytetrack_preprocess`` with ``self.imgsz`` as
    the input size. The ratio returned for each image is stored, in batch
    order, in ``self._preproc_data``, replacing whatever the previous call
    left there.

    Args:
        im: list of images, one per batch element.

    Returns:
        The preprocessed images stacked with ``torch.vstack`` after each one
        is given a leading dimension and moved to ``self.device``.
    """
    assert isinstance(im, list)
    # Reset the per-batch ratios first; `ratios` aliases the attribute.
    ratios = self._preproc_data = []
    batch = []
    for frame in im:
        resized, scale = bytetrack_preprocess(frame, input_size=self.imgsz)
        batch.append(torch.Tensor(resized).unsqueeze(0).to(self.device))
        ratios.append(scale)

    return torch.vstack(batch)
138
+
139
+ def postprocess (self , preds , im , im0s ):
99
140
100
141
results = []
101
142
for i , pred in enumerate (preds ):
143
+ im_path = self .im_paths [i ] if len (self .im_paths ) else ""
102
144
103
145
pred = postprocess (
104
146
pred .unsqueeze (0 ), # YOLOX postprocessor expects 3D arary
@@ -111,25 +153,27 @@ def postprocess(self, path, preds, im, im0s):
111
153
if pred is None :
112
154
pred = torch .empty ((0 , 6 ))
113
155
r = Results (
114
- path = path ,
156
+ path = im_path ,
115
157
boxes = pred ,
116
158
orig_img = im0s [i ],
117
159
names = self .names
118
160
)
119
161
results .append (r )
120
162
else :
121
- # (x, y, x, y, conf, obj, cls) --> (x, y, x, y, conf, cls)
122
- pred [:, 4 ] = pred [:, 4 ] * pred [:, 5 ]
163
+ ratio = self ._preproc_data [i ]
164
+ pred [:, 0 ] = pred [:, 0 ] / ratio
165
+ pred [:, 1 ] = pred [:, 1 ] / ratio
166
+ pred [:, 2 ] = pred [:, 2 ] / ratio
167
+ pred [:, 3 ] = pred [:, 3 ] / ratio
168
+ pred [:, 4 ] *= pred [:, 5 ]
123
169
pred = pred [:, [0 , 1 , 2 , 3 , 4 , 6 ]]
124
170
125
- pred [:, :4 ] = ops .scale_boxes (im .shape [2 :], pred [:, :4 ], im0s [i ].shape )
126
-
127
171
# filter boxes by classes
128
172
if self .args .classes :
129
173
pred = pred [torch .isin (pred [:, 5 ].cpu (), torch .as_tensor (self .args .classes ))]
130
174
131
175
r = Results (
132
- path = path ,
176
+ path = im_path ,
133
177
boxes = pred ,
134
178
orig_img = im0s [i ],
135
179
names = self .names
0 commit comments