From 4390b71e12a2c2f5e9a6b8bbb5e511e450f168e8 Mon Sep 17 00:00:00 2001 From: Randall Date: Fri, 28 Oct 2022 10:00:35 +0800 Subject: [PATCH] add rknn opt for v6.2 --- README_rkopt_manual.md | 16 +++++ export.py | 76 +++++++++++++++++++++- models/common.py | 126 ++++++++++++++++++++++++++---------- models/common_rk_plug_in.py | 30 +++++++++ models/experimental.py | 3 +- models/yolo.py | 7 ++ 6 files changed, 220 insertions(+), 38 deletions(-) create mode 100644 README_rkopt_manual.md create mode 100644 models/common_rk_plug_in.py diff --git a/README_rkopt_manual.md b/README_rkopt_manual.md new file mode 100644 index 000000000000..cbbb45266517 --- /dev/null +++ b/README_rkopt_manual.md @@ -0,0 +1,16 @@ +# YOLOv5 - rkopt 仓库 + +- 基于 https://github.com/ultralytics/yolov5 代码修改,设配 rknpu 设备的部署优化 +- 切换分支 git checkout {分支名} +- 目前支持分支: + - master + - maxpool/ focus 优化,输出改为个branch分支的输出。以上优化代码使用插入宏实现,不影响原来的训练逻辑,这个优化兼容修改前的权重,故支持官方给的预训练权重。 + + - 修改激活函数 silu 为 relu + + - 训练的相关内容请参考 README.md 说明 + + - 导出模型时 python export.py --rknpu {rk_platform} 即可导出优化模型 + + (rk_platform支持 rk1808, rv1109, rv1126, rk3399pro, rk3566, rk3568, rk3588, rv1103, rv1106) + diff --git a/export.py b/export.py index 595039b24bce..37f9c38f2934 100644 --- a/export.py +++ b/export.py @@ -52,6 +52,18 @@ import warnings from pathlib import Path + +# activate rknn hack +if len(sys.argv)>=3 and '--rknpu' in sys.argv: + _index = sys.argv.index('--rknpu') + if sys.argv[_index+1].upper() in ['RK1808', 'RV1109', 'RV1126','RK3399PRO']: + os.environ['RKNN_model_hack'] = 'npu_1' + elif sys.argv[_index+1].upper() in ['RK3566', 'RK3568', 'RK3588','RK3588S','RV1106','RV1103']: + os.environ['RKNN_model_hack'] = 'npu_2' + else: + assert False,"{} not recognized".format(sys.argv[_index+1]) + + import pandas as pd import torch import yaml @@ -514,11 +526,69 @@ def run( m.onnx_dynamic = dynamic m.export = True + if os.getenv('RKNN_model_hack', '0') == 'npu_1': + from models.common import Focus + from models.common import Conv + from models.common_rk_plug_in import surrogate_focus + if isinstance(model.model[0], Focus): + # For yolo v5 version + surrogate_focous = surrogate_focus(int(model.model[0].conv.conv.weight.shape[1]/4), + model.model[0].conv.conv.weight.shape[0], + k=tuple(model.model[0].conv.conv.weight.shape[2:4]), + s=model.model[0].conv.conv.stride, + p=model.model[0].conv.conv.padding, + g=model.model[0].conv.conv.groups, + act=True) + surrogate_focous.conv.conv.weight = model.model[0].conv.conv.weight + surrogate_focous.conv.conv.bias = model.model[0].conv.conv.bias + surrogate_focous.conv.act = model.model[0].conv.act + temp_i = model.model[0].i + temp_f = model.model[0].f + + model.model[0] = surrogate_focous + model.model[0].i = temp_i + model.model[0].f = temp_f + model.model[0].eval() + elif isinstance(model.model[0], Conv) and model.model[0].conv.kernel_size == (6, 6): + # For yolo v6 version + surrogate_focous = surrogate_focus(model.model[0].conv.weight.shape[1], + model.model[0].conv.weight.shape[0], + k=(3,3), # 6/2, 6/2 + s=1, + p=(1,1), # 2/2, 2/2 + g=model.model[0].conv.groups, + act=hasattr(model.model[0], 'act')) + surrogate_focous.conv.conv.weight[:,:3,:,:] = model.model[0].conv.weight[:,:,::2,::2] + surrogate_focous.conv.conv.weight[:,3:6,:,:] = model.model[0].conv.weight[:,:,1::2,::2] + surrogate_focous.conv.conv.weight[:,6:9,:,:] = model.model[0].conv.weight[:,:,::2,1::2] + surrogate_focous.conv.conv.weight[:,9:,:,:] = model.model[0].conv.weight[:,:,1::2,1::2] + surrogate_focous.conv.conv.bias = model.model[0].conv.bias + surrogate_focous.conv.act = model.model[0].act + temp_i = model.model[0].i + temp_f = model.model[0].f + + model.model[0] = surrogate_focous + model.model[0].i = temp_i + model.model[0].f = temp_f + model.model[0].eval() + + if isinstance(model.model[-1], Detect): + # save anchors + print('---> save anchors for RKNN') + RK_anchors = model.model[-1].stride.reshape(3,1).repeat(1,3).reshape(-1,1)* model.model[-1].anchors.reshape(9,2) + with open('RK_anchors.txt', 'w') as anf: + # anf.write(str(model.model[-1].na)+'\n') + for _v in RK_anchors.numpy().flatten(): + anf.write(str(_v)+'\n') + RK_anchors = RK_anchors.tolist() + print(RK_anchors) + + for _ in range(2): y = model(im) # dry runs if half and not coreml: im, model = im.half(), model.half() # to FP16 - shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape + shape = tuple((y[0] if (isinstance(y, tuple) or (isinstance(y, list))) else y).shape) # model output shape LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") # Exports @@ -599,8 +669,9 @@ def parse_opt(): parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold') parser.add_argument('--include', nargs='+', - default=['torchscript', 'onnx'], + default=['torchscript'], help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs') + parser.add_argument('--rknpu', default=None, help='RKNN npu platform') opt = parser.parse_args() print_args(vars(opt)) return opt @@ -613,4 +684,5 @@ def main(opt): if __name__ == "__main__": opt = parse_opt() + del opt.rknpu main(opt) diff --git a/models/common.py b/models/common.py index 17e40e60d7d7..3f1d33ece9f3 100644 --- a/models/common.py +++ b/models/common.py @@ -2,7 +2,7 @@ """ Common modules """ - +import os import json import math import platform @@ -40,7 +40,8 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.bn = nn.BatchNorm2d(c2) - self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + # self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) def forward(self, x): return self.act(self.bn(self.conv(x))) @@ -120,7 +121,8 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) self.cv4 = Conv(2 * c_, c2, 1, 1) self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) - self.act = nn.SiLU() + # self.act = nn.SiLU() + self.act = nn.ReLU() self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) def forward(self, x): @@ -189,38 +191,92 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n))) -class SPP(nn.Module): - # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729 - def __init__(self, c1, c2, k=(5, 9, 13)): - super().__init__() - c_ = c1 // 2 # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) - self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) - - def forward(self, x): - x = self.cv1(x) - with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning - return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) - - -class SPPF(nn.Module): - # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher - def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) - super().__init__() - c_ = c1 // 2 # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = Conv(c_ * 4, c2, 1, 1) - self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) - - def forward(self, x): - x = self.cv1(x) - with warnings.catch_warnings(): - warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning - y1 = self.m(x) - y2 = self.m(y1) - return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1)) +if os.getenv('RKNN_model_hack', '0') == '0': + class SPP(nn.Module): + # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729 + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) +elif os.getenv('RKNN_model_hack', '0') in ['npu_1', 'npu_2']: + # TODO remove this hack when rknn-toolkit1/2 add this optimize rules + class SPP(nn.Module): + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + for value in k: + assert (value%2 == 1) and (value!= 1), "value in [{}] only support odd number for RKNN model hack" + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + y = [x] + for maxpool in self.m: + kernel_size = maxpool.kernel_size + m = x + for i in range(math.floor(kernel_size/2)): + m = torch.nn.functional.max_pool2d(m, 3, 1, 1) + y = [*y, m] + return self.cv2(torch.cat(y, 1)) + + +if os.getenv('RKNN_model_hack', '0') in ['0','npu_2']: + class SPPF(nn.Module): + # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher + def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * 4, c2, 1, 1) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + y1 = self.m(x) + y2 = self.m(y1) + return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) +elif os.getenv('RKNN_model_hack', '0') == 'npu_1': + class SPPF(nn.Module): + # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher + def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * 4, c2, 1, 1) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x): + x = self.cv1(x) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + y1 = self.m(x) + y2 = self.m(y1) + + with warnings.catch_warnings(): + warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning + y = [x] + kernel_size = self.m.kernel_size + _3x3_stack = math.floor(kernel_size/2) + for i in range(3): + m = y[-1] + for _ in range(_3x3_stack): + m = torch.nn.functional.max_pool2d(m, 3, 1, 1) + y = [*y, m] + return self.cv2(torch.cat(y, 1)) class Focus(nn.Module): diff --git a/models/common_rk_plug_in.py b/models/common_rk_plug_in.py new file mode 100644 index 000000000000..0cff98239223 --- /dev/null +++ b/models/common_rk_plug_in.py @@ -0,0 +1,30 @@ +# This file contains modules common to various models + +import torch +import torch.nn as nn +from models.common import Conv + + +class surrogate_focus(nn.Module): + # surrogate_focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super(surrogate_focus, self).__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + + with torch.no_grad(): + self.convsp = nn.Conv2d(3, 12, (2, 2), groups=1, bias=False, stride=(2, 2)) + self.convsp.weight.data = torch.zeros(self.convsp.weight.shape).float() + for i in range(4): + for j in range(3): + ch = i*3 + j + if ch>=0 and ch<3: + self.convsp.weight[ch:ch+1, j:j+1, 0, 0] = 1 + elif ch>=3 and ch<6: + self.convsp.weight[ch:ch+1, j:j+1, 1, 0] = 1 + elif ch>=6 and ch<9: + self.convsp.weight[ch:ch+1, j:j+1, 0, 1] = 1 + elif ch>=9 and ch<12: + self.convsp.weight[ch:ch+1, j:j+1, 1, 1] = 1 + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(self.convsp(x)) \ No newline at end of file diff --git a/models/experimental.py b/models/experimental.py index cb32d01ba46a..244bd8d76918 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -52,7 +52,8 @@ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kern self.m = nn.ModuleList([ nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)]) self.bn = nn.BatchNorm2d(c2) - self.act = nn.SiLU() + # self.act = nn.SiLU() + self.act = nn.ReLU() def forward(self, x): return self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) diff --git a/models/yolo.py b/models/yolo.py index df4209726e0d..3f45552f3434 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -55,6 +55,10 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer def forward(self, x): z = [] # inference output for i in range(self.nl): + if os.getenv('RKNN_model_hack', '0') != '0': + z.append(torch.sigmoid(self.m[i](x[i]))) + continue + x[i] = self.m[i](x[i]) # conv bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() @@ -74,6 +78,9 @@ def forward(self, x): y = torch.cat((xy, wh, conf), 4) z.append(y.view(bs, -1, self.no)) + if os.getenv('RKNN_model_hack', '0') != '0': + return z + return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):