Skip to content

Commit

Permalink
add rknn opt for v6.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Randall committed Oct 28, 2022
1 parent d3ea0df commit 4390b71
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 38 deletions.
16 changes: 16 additions & 0 deletions README_rkopt_manual.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# YOLOv5 - rkopt 仓库

- 基于 https://github.com/ultralytics/yolov5 代码修改,设配 rknpu 设备的部署优化
- 切换分支 git checkout {分支名}
- 目前支持分支:
- master
- maxpool/ focus 优化,输出改为个branch分支的输出。以上优化代码使用插入宏实现,不影响原来的训练逻辑,这个优化兼容修改前的权重,故支持官方给的预训练权重。

- 修改激活函数 silu 为 relu

- 训练的相关内容请参考 README.md 说明

- 导出模型时 python export.py --rknpu {rk_platform} 即可导出优化模型

(rk_platform支持 rk1808, rv1109, rv1126, rk3399pro, rk3566, rk3568, rk3588, rv1103, rv1106)

76 changes: 74 additions & 2 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@
import warnings
from pathlib import Path


# activate rknn hack
if len(sys.argv)>=3 and '--rknpu' in sys.argv:
_index = sys.argv.index('--rknpu')
if sys.argv[_index+1].upper() in ['RK1808', 'RV1109', 'RV1126','RK3399PRO']:
os.environ['RKNN_model_hack'] = 'npu_1'
elif sys.argv[_index+1].upper() in ['RK3566', 'RK3568', 'RK3588','RK3588S','RV1106','RV1103']:
os.environ['RKNN_model_hack'] = 'npu_2'
else:
assert False,"{} not recognized".format(sys.argv[_index+1])


import pandas as pd
import torch
import yaml
Expand Down Expand Up @@ -514,11 +526,69 @@ def run(
m.onnx_dynamic = dynamic
m.export = True

if os.getenv('RKNN_model_hack', '0') == 'npu_1':
from models.common import Focus
from models.common import Conv
from models.common_rk_plug_in import surrogate_focus
if isinstance(model.model[0], Focus):
# For yolo v5 version
surrogate_focous = surrogate_focus(int(model.model[0].conv.conv.weight.shape[1]/4),
model.model[0].conv.conv.weight.shape[0],
k=tuple(model.model[0].conv.conv.weight.shape[2:4]),
s=model.model[0].conv.conv.stride,
p=model.model[0].conv.conv.padding,
g=model.model[0].conv.conv.groups,
act=True)
surrogate_focous.conv.conv.weight = model.model[0].conv.conv.weight
surrogate_focous.conv.conv.bias = model.model[0].conv.conv.bias
surrogate_focous.conv.act = model.model[0].conv.act
temp_i = model.model[0].i
temp_f = model.model[0].f

model.model[0] = surrogate_focous
model.model[0].i = temp_i
model.model[0].f = temp_f
model.model[0].eval()
elif isinstance(model.model[0], Conv) and model.model[0].conv.kernel_size == (6, 6):
# For yolo v6 version
surrogate_focous = surrogate_focus(model.model[0].conv.weight.shape[1],
model.model[0].conv.weight.shape[0],
k=(3,3), # 6/2, 6/2
s=1,
p=(1,1), # 2/2, 2/2
g=model.model[0].conv.groups,
act=hasattr(model.model[0], 'act'))
surrogate_focous.conv.conv.weight[:,:3,:,:] = model.model[0].conv.weight[:,:,::2,::2]
surrogate_focous.conv.conv.weight[:,3:6,:,:] = model.model[0].conv.weight[:,:,1::2,::2]
surrogate_focous.conv.conv.weight[:,6:9,:,:] = model.model[0].conv.weight[:,:,::2,1::2]
surrogate_focous.conv.conv.weight[:,9:,:,:] = model.model[0].conv.weight[:,:,1::2,1::2]
surrogate_focous.conv.conv.bias = model.model[0].conv.bias
surrogate_focous.conv.act = model.model[0].act
temp_i = model.model[0].i
temp_f = model.model[0].f

model.model[0] = surrogate_focous
model.model[0].i = temp_i
model.model[0].f = temp_f
model.model[0].eval()

if isinstance(model.model[-1], Detect):
# save anchors
print('---> save anchors for RKNN')
RK_anchors = model.model[-1].stride.reshape(3,1).repeat(1,3).reshape(-1,1)* model.model[-1].anchors.reshape(9,2)
with open('RK_anchors.txt', 'w') as anf:
# anf.write(str(model.model[-1].na)+'\n')
for _v in RK_anchors.numpy().flatten():
anf.write(str(_v)+'\n')
RK_anchors = RK_anchors.tolist()
print(RK_anchors)


for _ in range(2):
y = model(im) # dry runs
if half and not coreml:
im, model = im.half(), model.half() # to FP16
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
shape = tuple((y[0] if (isinstance(y, tuple) or (isinstance(y, list))) else y).shape) # model output shape
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")

# Exports
Expand Down Expand Up @@ -599,8 +669,9 @@ def parse_opt():
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include',
nargs='+',
default=['torchscript', 'onnx'],
default=['torchscript'],
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
parser.add_argument('--rknpu', default=None, help='RKNN npu platform')
opt = parser.parse_args()
print_args(vars(opt))
return opt
Expand All @@ -613,4 +684,5 @@ def main(opt):

if __name__ == "__main__":
opt = parse_opt()
del opt.rknpu
main(opt)
126 changes: 91 additions & 35 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
Common modules
"""

import os
import json
import math
import platform
Expand Down Expand Up @@ -40,7 +40,8 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
# self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
self.act = nn.ReLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

def forward(self, x):
return self.act(self.bn(self.conv(x)))
Expand Down Expand Up @@ -120,7 +121,8 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.SiLU()
# self.act = nn.SiLU()
self.act = nn.ReLU()
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

def forward(self, x):
Expand Down Expand Up @@ -189,38 +191,92 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))


class SPP(nn.Module):
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
if os.getenv('RKNN_model_hack', '0') == '0':
class SPP(nn.Module):
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
elif os.getenv('RKNN_model_hack', '0') in ['npu_1', 'npu_2']:
# TODO remove this hack when rknn-toolkit1/2 add this optimize rules
class SPP(nn.Module):
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
for value in k:
assert (value%2 == 1) and (value!= 1), "value in [{}] only support odd number for RKNN model hack"

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y = [x]
for maxpool in self.m:
kernel_size = maxpool.kernel_size
m = x
for i in range(math.floor(kernel_size/2)):
m = torch.nn.functional.max_pool2d(m, 3, 1, 1)
y = [*y, m]
return self.cv2(torch.cat(y, 1))


if os.getenv('RKNN_model_hack', '0') in ['0','npu_2']:
class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
elif os.getenv('RKNN_model_hack', '0') == 'npu_1':
class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)

with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y = [x]
kernel_size = self.m.kernel_size
_3x3_stack = math.floor(kernel_size/2)
for i in range(3):
m = y[-1]
for _ in range(_3x3_stack):
m = torch.nn.functional.max_pool2d(m, 3, 1, 1)
y = [*y, m]
return self.cv2(torch.cat(y, 1))


class Focus(nn.Module):
Expand Down
30 changes: 30 additions & 0 deletions models/common_rk_plug_in.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This file contains modules common to various models

import torch
import torch.nn as nn
from models.common import Conv


class surrogate_focus(nn.Module):
# surrogate_focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super(surrogate_focus, self).__init__()
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)

with torch.no_grad():
self.convsp = nn.Conv2d(3, 12, (2, 2), groups=1, bias=False, stride=(2, 2))
self.convsp.weight.data = torch.zeros(self.convsp.weight.shape).float()
for i in range(4):
for j in range(3):
ch = i*3 + j
if ch>=0 and ch<3:
self.convsp.weight[ch:ch+1, j:j+1, 0, 0] = 1
elif ch>=3 and ch<6:
self.convsp.weight[ch:ch+1, j:j+1, 1, 0] = 1
elif ch>=6 and ch<9:
self.convsp.weight[ch:ch+1, j:j+1, 0, 1] = 1
elif ch>=9 and ch<12:
self.convsp.weight[ch:ch+1, j:j+1, 1, 1] = 1

def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
return self.conv(self.convsp(x))
3 changes: 2 additions & 1 deletion models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kern
self.m = nn.ModuleList([
nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
self.bn = nn.BatchNorm2d(c2)
self.act = nn.SiLU()
# self.act = nn.SiLU()
self.act = nn.ReLU()

def forward(self, x):
return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
Expand Down
7 changes: 7 additions & 0 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
def forward(self, x):
z = [] # inference output
for i in range(self.nl):
if os.getenv('RKNN_model_hack', '0') != '0':
z.append(torch.sigmoid(self.m[i](x[i])))
continue

x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
Expand All @@ -74,6 +78,9 @@ def forward(self, x):
y = torch.cat((xy, wh, conf), 4)
z.append(y.view(bs, -1, self.no))

if os.getenv('RKNN_model_hack', '0') != '0':
return z

return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
Expand Down

0 comments on commit 4390b71

Please sign in to comment.