forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 1
/
hubconf.py
175 lines (134 loc) · 7.79 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5
Usage:
import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # official model
model = torch.hub.load('ultralytics/yolov5:master', 'yolov5s') # from branch
model = torch.hub.load('ultralytics/yolov5', 'custom', 'yolov5s.pt') # custom/local model
model = torch.hub.load('.', 'custom', 'yolov5s.pt', source='local') # local repo
"""
import torch
import os
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None, cfg=None):
    """Creates or loads a YOLOv5 model.

    Arguments:
        name (str): model name 'yolov5s' or path 'path/to/best.pt'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes; also exported as the `nc` environment variable
        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
        verbose (bool): print all information to screen
        device (str, torch.device, None): device to use for model parameters
        cfg (str, Path, None): when given, exported as the `cfg` environment variable
            (presumably read by this fork's model-building code -- TODO confirm against the consumer)

    Returns:
        YOLOv5 model

    Raises:
        Exception: if the model cannot be created or loaded; the original error is chained as the cause
    """
    from pathlib import Path

    from models.common import AutoShape, DetectMultiBackend
    from models.experimental import attempt_load
    from models.yolo import ClassificationModel, DetectionModel, SegmentationModel
    from utils.downloads import attempt_download
    from utils.general import LOGGER, ROOT, check_requirements, intersect_dicts, logging
    from utils.torch_utils import select_device

    if not verbose:
        LOGGER.setLevel(logging.WARNING)
    check_requirements(ROOT / 'requirements.txt', exclude=('opencv-python', 'tensorboard', 'thop'))
    name = Path(name)
    path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name  # checkpoint path

    # os.environ only accepts str values: assigning the default cfg=None used to raise TypeError
    # for every caller that did not pass cfg. Export it only when provided, and drop any value
    # left behind by an earlier call so it cannot leak into this one.
    if cfg is None:
        os.environ.pop('cfg', None)
    else:
        os.environ['cfg'] = str(cfg)
    os.environ['nc'] = str(classes)

    try:
        device = select_device(device)
        if pretrained and channels == 3:
            try:
                model = DetectMultiBackend(path, device=device, fuse=False)  # detection model
                if autoshape:
                    if model.pt and isinstance(model.model, ClassificationModel):
                        LOGGER.warning('WARNING ⚠️ YOLOv5 ClassificationModel is not yet AutoShape compatible. '
                                       'You must pass torch tensors in BCHW to this model, i.e. shape(1,3,224,224).')
                    elif model.pt and isinstance(model.model, SegmentationModel):
                        LOGGER.warning('WARNING ⚠️ YOLOv5 SegmentationModel is not yet AutoShape compatible. '
                                       'You will not be able to run inference with this model.')
                    else:
                        model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
            except Exception as e:
                # Best-effort fallback for checkpoints DetectMultiBackend cannot handle; report the
                # reason through the logger instead of a bare print.
                LOGGER.warning(f'DetectMultiBackend could not load {path} ({e}), falling back to attempt_load')
                model = attempt_load(weights=path, device=device, fuse=False)  # arbitrary model
        else:
            # Local name kept distinct from the `cfg` parameter so the argument is not shadowed.
            model_cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0]  # model.yaml path
            model = DetectionModel(model_cfg, channels, classes)  # create model
            if pretrained:
                ckpt = torch.load(attempt_download(path), map_location=device)  # load
                csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
                csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors'])  # intersect
                model.load_state_dict(csd, strict=False)  # load
                if len(ckpt['model'].names) == classes:
                    model.names = ckpt['model'].names  # set class names attribute
            # NOTE(review): the source this was recovered from had lost its indentation; this wrap is
            # assumed to apply to every model built in this branch -- confirm the intended nesting.
            model = AutoShape(model)
        return model.to(device)
    except Exception as e:
        help_url = 'https://docs.ultralytics.com/yolov5/tutorials/pytorch_hub_model_loading'
        s = f'{e}. Cache may be out of date, try `force_reload=True` or see {help_url} for help.'
        raise Exception(s) from e
    finally:
        # Restore the logger on failure as well as on success so a failed quiet load does not
        # leave the shared logger silenced.
        if not verbose:
            LOGGER.setLevel(logging.INFO)  # reset to default
def custom(path='path/to/model.pt', autoshape=True, classes=80, cfg=None, _verbose=True, device=None):
    """Load a custom or local YOLOv5 model from `path`, always requesting pretrained weights."""
    options = dict(
        pretrained=True,
        autoshape=autoshape,
        classes=classes,
        cfg=cfg,
        verbose=_verbose,
        device=device,
    )
    return _create(path, **options)
def yolov5n(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-nano model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5n',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-small model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5s',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-medium model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5m',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-large model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5l',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-xlarge model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5x',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5n6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-nano-P6 model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5n6',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-small-P6 model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5s6',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5m6',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-large-P6 model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5l6',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=True, device=None):
    """YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5"""
    return _create(
        name='yolov5x6',
        pretrained=pretrained,
        channels=channels,
        classes=classes,
        autoshape=autoshape,
        verbose=_verbose,
        device=device,
    )
if __name__ == '__main__':
    import argparse
    from pathlib import Path

    import numpy as np
    from PIL import Image

    from utils.general import cv2, print_args

    # Command-line options: only the model name is configurable
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', type=str, default='yolov5s', help='model name')
    opt = cli.parse_args()
    print_args(vars(opt))

    # Build the requested model with the stock settings
    model = _create(name=opt.model, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)
    # model = custom(path='path/to/model.pt')  # custom

    # One sample per input type accepted by the autoshaped model
    zidane = 'data/images/zidane.jpg'
    bus = 'data/images/bus.jpg'
    imgs = [
        zidane,  # filename
        Path(zidane),  # Path
        'https://ultralytics.com/images/zidane.jpg',  # URI
        cv2.imread(bus)[:, :, ::-1],  # OpenCV, channel axis reversed
        Image.open(bus),  # PIL
        np.zeros((320, 640, 3)),  # numpy
    ]

    # Batched inference over all samples
    results = model(imgs, size=320)

    # Report and persist the detections
    results.print()
    results.save()