Skip to content

Commit 7ebb7a3

Browse files
Merge pull request #1739 from Fleyderer/fix-yolox
Fix YOLOX
2 parents 755cb4d + ead2604 commit 7ebb7a3

File tree

8 files changed

+209
-72
lines changed

8 files changed

+209
-72
lines changed

boxmot/utils/ops.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Tuple, Union
77

88

9-
109
def xyxy2xywh(x):
1110
"""
1211
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
@@ -186,4 +185,34 @@ def letterbox(
186185
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
187186
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
188187

189-
return img, ratio, (dw, dh)
188+
return img, ratio, (dw, dh)
189+
190+
191+
# This preprocess differs from the current version of YOLOX preprocess, but ByteTrack uses it
192+
# https://github.com/ifzhang/ByteTrack/blob/d1bf0191adff59bc8fcfeaa0b33d3d1642552a99/yolox/data/data_augment.py#L189
193+
def bytetrack_preprocess(image, input_size,
194+
mean=(0.485, 0.456, 0.406),
195+
std=(0.229, 0.224, 0.225),
196+
swap=(2, 0, 1)):
197+
if len(image.shape) == 3:
198+
padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
199+
else:
200+
padded_img = np.ones(input_size) * 114.0
201+
img = np.array(image)
202+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
203+
resized_img = cv2.resize(
204+
img,
205+
(int(img.shape[1] * r), int(img.shape[0] * r)),
206+
interpolation=cv2.INTER_LINEAR,
207+
).astype(np.float32)
208+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
209+
210+
padded_img = padded_img[:, :, ::-1]
211+
padded_img /= 255.0
212+
if mean is not None:
213+
padded_img -= mean
214+
if std is not None:
215+
padded_img /= std
216+
padded_img = padded_img.transpose(swap)
217+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
218+
return padded_img, r

examples/det/yolox_boxmot.ipynb

+19-12
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"from yolox.utils import postprocess\n",
2828
"from yolox.utils.model_utils import fuse_model\n",
2929
"from boxmot import BotSort\n",
30-
"from boxmot.utils.ops import letterbox\n",
30+
"from boxmot.utils.ops import bytetrack_preprocess\n",
3131
"\n",
3232
"\n",
3333
"# Dictionary for YOLOX model weights URLs\n",
@@ -40,7 +40,7 @@
4040
"}\n",
4141
"\n",
4242
"# Preprocessing pipeline\n",
43-
"preprocess = transforms.Compose([transforms.ToTensor()])\n",
43+
"input_size = [800, 1440]\n",
4444
"device = torch.device('cpu')\n",
4545
"yolox_model = 'yolox_s.pt'\n",
4646
"yolox_model_path = Path(yolox_model)\n",
@@ -59,8 +59,15 @@
5959
"model = fuse_model(model).to(device).eval()\n",
6060
"\n",
6161
"# Initialize tracker\n",
62-
"tracker = BotSort(reid_weights=Path('osnet_x0_25_msmt17.pt'), device=device, half=False)\n",
63-
"\n",
62+
"tracker = BotSort(reid_weights=Path('osnet_x0_25_msmt17.pt'), device=device, half=False)"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": null,
68+
"metadata": {},
69+
"outputs": [],
70+
"source": [
6471
"# Video capture setup\n",
6572
"vid = cv2.VideoCapture(0)\n",
6673
"\n",
@@ -70,20 +77,20 @@
7077
" break\n",
7178
"\n",
7279
" # Preprocess frame\n",
73-
" frame_letterbox, ratio, (dw, dh) = letterbox(frame, new_shape=[640, 640], auto=False, scaleFill=True)\n",
74-
" frame_tensor = preprocess(frame_letterbox).unsqueeze(0).to(device)\n",
80+
" frame_img, ratio = bytetrack_preprocess(frame, input_size=input_size)\n",
81+
" frame_tensor = torch.Tensor(frame_img).unsqueeze(0).to(device)\n",
7582
"\n",
7683
" # Detection with YOLOX\n",
7784
" with torch.no_grad():\n",
7885
" dets = model(frame_tensor)\n",
79-
" dets = postprocess(dets, 1, 0.5, 0.2, class_agnostic=True)[0]\n",
86+
" dets = postprocess(dets, 1, 0.5, 0.7, class_agnostic=True)[0]\n",
8087
"\n",
8188
" if dets is not None:\n",
8289
" # Rescale coordinates from letterbox back to the original frame size\n",
83-
" dets[:, 0] = (dets[:, 0] - dw) / ratio[0]\n",
84-
" dets[:, 1] = (dets[:, 1] - dh) / ratio[1]\n",
85-
" dets[:, 2] = (dets[:, 2] - dw) / ratio[0]\n",
86-
" dets[:, 3] = (dets[:, 3] - dh) / ratio[1]\n",
90+
" dets[:, 0] = (dets[:, 0]) / ratio\n",
91+
" dets[:, 1] = (dets[:, 1]) / ratio\n",
92+
" dets[:, 2] = (dets[:, 2]) / ratio\n",
93+
" dets[:, 3] = (dets[:, 3]) / ratio\n",
8794
" dets[:, 4] *= dets[:, 5]\n",
8895
" dets = dets[:, [0, 1, 2, 3, 4, 6]].cpu().numpy()\n",
8996
" else:\n",
@@ -121,7 +128,7 @@
121128
"name": "python",
122129
"nbconvert_exporter": "python",
123130
"pygments_lexer": "ipython3",
124-
"version": "3.11.5"
131+
"version": "3.12.4"
125132
}
126133
},
127134
"nbformat": 4,

tracking/detectors/__init__.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,29 @@
55

66
checker = RequirementsChecker()
77

8+
UL_MODELS = ['yolov8', 'yolov9', 'yolov10', 'yolo11', 'rtdetr', 'sam']
9+
10+
11+
def is_ultralytics_model(yolo_name):
12+
return any(yolo in str(yolo_name) for yolo in UL_MODELS)
13+
14+
15+
def is_yolox_model(yolo_name):
16+
return 'yolox' in str(yolo_name)
17+
18+
19+
def default_imgsz(yolo_name):
20+
if is_ultralytics_model(yolo_name):
21+
return [640, 640]
22+
elif is_yolox_model(yolo_name):
23+
return [800, 1440]
24+
else:
25+
return [640, 640]
26+
827

928
def get_yolo_inferer(yolo_model):
1029

11-
if 'yolox' in str(yolo_model):
30+
if is_yolox_model(yolo_model):
1231
try:
1332
import yolox # for linear_assignment
1433
assert yolox.__version__

tracking/detectors/yolo_interface.py

+4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ class YoloInterface(ABC):
1414
def __call__(self, im):
1515
pass
1616

17+
@abstractmethod
18+
def preprocess(self, ims):
19+
pass
20+
1721
@abstractmethod
1822
def postprocess(self, preds):
1923
pass

tracking/detectors/yolox.py

+52-8
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import torch
55
from ultralytics.engine.results import Results
66
from ultralytics.utils import ops
7+
from ultralytics.models.yolo.detect import DetectionPredictor
78
from yolox.exp import get_exp
89
from yolox.utils import postprocess
910
from yolox.utils.model_utils import fuse_model
1011

1112
from boxmot.utils import logger as LOGGER
13+
from boxmot.utils.ops import bytetrack_preprocess
1214
from tracking.detectors.yolo_interface import YoloInterface
1315

1416
# default model weigths for these model names
@@ -48,6 +50,7 @@ class YoloXStrategy(YoloInterface):
4850
def __init__(self, model, device, args):
4951

5052
self.args = args
53+
self.imgsz = args.imgsz
5154
self.pt = False
5255
self.stride = 32 # max stride in YOLOX
5356

@@ -80,25 +83,64 @@ def __init__(self, model, device, args):
8083
map_location=torch.device('cpu')
8184
)
8285

86+
self.device = device
8387
self.model = exp.get_model()
8488
self.model.eval()
8589
self.model.load_state_dict(ckpt["model"])
8690
self.model = fuse_model(self.model)
87-
self.model.to(device)
91+
self.model.to(self.device)
8892
self.model.eval()
93+
self.im_paths = []
94+
self._preproc_data = []
8995

9096
@torch.no_grad()
9197
def __call__(self, im, augment, visualize, embed):
98+
if isinstance(im, list):
99+
if len(im[0].shape) == 3:
100+
im = torch.stack(im)
101+
else:
102+
im = torch.vstack(im)
103+
104+
if len(im.shape) == 3:
105+
im = im.unsqueeze(0)
106+
107+
assert len(im.shape) == 4, f"Expected 4D tensor as input, got {im.shape}"
108+
92109
preds = self.model(im)
93110
return preds
94111

95112
def warmup(self, imgsz):
96113
pass
97114

98-
def postprocess(self, path, preds, im, im0s):
115+
def update_im_paths(self, predictor: DetectionPredictor):
116+
"""
117+
This function saves image paths for the current batch,
118+
being passed as callback on_predict_batch_start
119+
"""
120+
assert (isinstance(predictor, DetectionPredictor),
121+
"Only ultralytics predictors are supported")
122+
self.im_paths = predictor.batch[0]
123+
124+
def preprocess(self, im) -> torch.Tensor:
125+
assert isinstance(im, list)
126+
im_preprocessed = []
127+
self._preproc_data = []
128+
for i, img in enumerate(im):
129+
img_pre, ratio = bytetrack_preprocess(img, input_size=self.imgsz)
130+
img_pre = torch.Tensor(img_pre).unsqueeze(0).to(self.device)
131+
132+
im_preprocessed.append(img_pre)
133+
self._preproc_data.append(ratio)
134+
135+
im_preprocessed = torch.vstack(im_preprocessed)
136+
137+
return im_preprocessed
138+
139+
def postprocess(self, preds, im, im0s):
99140

100141
results = []
101142
for i, pred in enumerate(preds):
143+
im_path = self.im_paths[i] if len(self.im_paths) else ""
102144

103145
pred = postprocess(
104146
pred.unsqueeze(0), # YOLOX postprocessor expects 3D arary
@@ -111,25 +153,27 @@ def postprocess(self, path, preds, im, im0s):
111153
if pred is None:
112154
pred = torch.empty((0, 6))
113155
r = Results(
114-
path=path,
156+
path=im_path,
115157
boxes=pred,
116158
orig_img=im0s[i],
117159
names=self.names
118160
)
119161
results.append(r)
120162
else:
121-
# (x, y, x, y, conf, obj, cls) --> (x, y, x, y, conf, cls)
122-
pred[:, 4] = pred[:, 4] * pred[:, 5]
163+
ratio = self._preproc_data[i]
164+
pred[:, 0] = pred[:, 0] / ratio
165+
pred[:, 1] = pred[:, 1] / ratio
166+
pred[:, 2] = pred[:, 2] / ratio
167+
pred[:, 3] = pred[:, 3] / ratio
168+
pred[:, 4] *= pred[:, 5]
123169
pred = pred[:, [0, 1, 2, 3, 4, 6]]
124170

125-
pred[:, :4] = ops.scale_boxes(im.shape[2:], pred[:, :4], im0s[i].shape)
126-
127171
# filter boxes by classes
128172
if self.args.classes:
129173
pred = pred[torch.isin(pred[:, 5].cpu(), torch.as_tensor(self.args.classes))]
130174

131175
r = Results(
132-
path=path,
176+
path=im_path,
133177
boxes=pred,
134178
orig_img=im0s[i],
135179
names=self.names

tracking/track.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from boxmot.tracker_zoo import create_tracker
1313
from boxmot.utils import ROOT, WEIGHTS, TRACKER_CONFIGS
1414
from boxmot.utils.checks import RequirementsChecker
15-
from tracking.detectors import get_yolo_inferer
15+
from tracking.detectors import (get_yolo_inferer, default_imgsz,
16+
is_ultralytics_model, is_yolox_model)
1617

1718
checker = RequirementsChecker()
1819
checker.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install
@@ -56,11 +57,13 @@ def on_predict_start(predictor, persist=False):
5657

5758
@torch.no_grad()
5859
def run(args):
59-
60-
ul_models = ['yolov8', 'yolov9', 'yolov10', 'yolo11', 'rtdetr', 'sam']
60+
61+
if args.imgsz is None:
62+
args.imgsz = default_imgsz(args.yolo_model)
6163

6264
yolo = YOLO(
63-
args.yolo_model if any(yolo in str(args.yolo_model) for yolo in ul_models) else 'yolov8n.pt',
65+
args.yolo_model if is_ultralytics_model(args.yolo_model)
66+
else 'yolov8n.pt',
6467
)
6568

6669
results = yolo.track(
@@ -87,15 +90,23 @@ def run(args):
8790

8891
yolo.add_callback('on_predict_start', partial(on_predict_start, persist=True))
8992

90-
if not any(yolo in str(args.yolo_model) for yolo in ul_models):
93+
if not is_ultralytics_model(args.yolo_model):
9194
# replace yolov8 model
9295
m = get_yolo_inferer(args.yolo_model)
93-
model = m(
94-
model=args.yolo_model,
95-
device=yolo.predictor.device,
96-
args=yolo.predictor.args
97-
)
98-
yolo.predictor.model = model
96+
yolo_model = m(model=args.yolo_model, device=yolo.predictor.device,
97+
args=yolo.predictor.args)
98+
yolo.predictor.model = yolo_model
99+
100+
# If current model is YOLOX, change the preprocess and postprocess
101+
if is_yolox_model(args.yolo_model):
102+
# add callback to save image paths for further processing
103+
yolo.add_callback("on_predict_batch_start",
104+
lambda p: yolo_model.update_im_paths(p))
105+
yolo.predictor.preprocess = (
106+
lambda imgs: yolo_model.preprocess(im=imgs))
107+
yolo.predictor.postprocess = (
108+
lambda preds, im, im0s:
109+
yolo_model.postprocess(preds=preds, im=im, im0s=im0s))
99110

100111
# store custom args in predictor
101112
yolo.predictor.custom_args = args
@@ -112,6 +123,7 @@ def run(args):
112123

113124

114125
def parse_opt():
126+
115127
parser = argparse.ArgumentParser()
116128
parser.add_argument('--yolo-model', type=Path, default=WEIGHTS / 'yolov8n',
117129
help='yolo model path')
@@ -121,7 +133,7 @@ def parse_opt():
121133
help='deepocsort, botsort, strongsort, ocsort, bytetrack, imprassoc')
122134
parser.add_argument('--source', type=str, default='0',
123135
help='file/dir/URL/glob, 0 for webcam')
124-
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640],
136+
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=None,
125137
help='inference size h,w')
126138
parser.add_argument('--conf', type=float, default=0.5,
127139
help='confidence threshold')

tracking/utils.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -371,13 +371,13 @@ def write_mot_results(txt_path: Path, mot_results: np.ndarray) -> None:
371371
path to the file will be created as well if necessary.
372372
"""
373373
if mot_results is not None:
374-
if mot_results.size != 0:
375-
# Ensure the parent directory of the txt_path exists
376-
txt_path.parent.mkdir(parents=True, exist_ok=True)
374+
# Ensure the parent directory of the txt_path exists
375+
txt_path.parent.mkdir(parents=True, exist_ok=True)
377376

378-
# Ensure the file exists before opening
379-
txt_path.touch(exist_ok=True)
377+
# Ensure the file exists before opening
378+
txt_path.touch(exist_ok=True)
380379

380+
if mot_results.size != 0:
381381
# Open the file in append mode and save the MOT results
382382
with open(str(txt_path), 'a') as file:
383383
np.savetxt(file, mot_results, fmt='%d,%d,%d,%d,%d,%d,%d,%d,%.6f')

0 commit comments

Comments
 (0)