Merge pull request #3 from zen-xingle/main
support export rknn optimized type torchscript model
airockchip authored Aug 11, 2023
2 parents feeca81 + 707ba98 commit 6602a0f
Showing 9 changed files with 158 additions and 7 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -1,3 +1,10 @@
#### Get a model optimized for RKNN
To export a model optimized for RKNN, see [RKOPT_README.md](./RKOPT_README.md).



---

<div align="center">
<p>
<a href="https://ultralytics.com/yolov8" target="_blank">
7 changes: 7 additions & 0 deletions README.zh-CN.md
@@ -1,3 +1,10 @@
#### Export a model adapted for rknpu
A model structure adapted for rknpu achieves higher inference efficiency on the npu. For export details, see [RKOPT_README_zh.md](./RKOPT_README_zh.md)



---

<div align="center">
<p>
<a href="https://ultralytics.com/yolov8" target="_blank">
52 changes: 52 additions & 0 deletions RKOPT_README.md
@@ -0,0 +1,52 @@
## Description - exporting a model optimized for RKNPU

### 1. Model structure adjustments

- The dfl structure performs poorly on the NPU, so it is moved outside the model.

Suppose there are 6000 candidate boxes. The original model places the dfl structure before the "box confidence filter", so all 6000 candidates have to go through the dfl computation. If the dfl structure is instead placed after the "box confidence filter" and, say, 100 candidates remain after filtering, the dfl computation drops to 100 candidates, which greatly reduces the use of computing and bandwidth resources.



- Suppose there are 6000 candidate boxes and 80 detection classes. The threshold check then has to be repeated 6000 * 80 = 4.8 * 10^5 times, which takes a lot of time. For this reason, the exported model contains an additional operation that sums the scores of the 80 classes, so that low-confidence candidates can be rejected quickly. (This structure is effective only in some cases, depending on the training results of the model.)

You can comment out this optimization at line 52 to line 54 of **ultralytics/nn/modules/head.py**. The corresponding code is:

```
cls_sum = torch.clamp(y[-1].sum(1, keepdim=True), 0, 1)
y.append(cls_sum)
```




- (optional) Alternatively, following the structure of yolov5, the 80-class output can be extended to 80+1 classes, where the added class serves as the box confidence and acts as a quick filter. With this change, post-processing on the CPU needs 10 to 40 times fewer threshold comparisons.
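
The two adjustments above can be sketched as CPU post-processing in plain Python. This is only an illustration under stated assumptions, not the code used by the repository or by rknn_model_zoo: the function names `dfl_decode` and `postprocess`, the list-based input layout, and the `conf_thres` and `reg_max` parameters are hypothetical. The class-score sum is checked first, and the dfl decode runs only for candidates that survive the filter.

```python
import math


def dfl_decode(logits):
    """Decode one box side: softmax over the bins, then the expected bin index."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return sum(i * e for i, e in enumerate(exps)) / total


def postprocess(box_logits, cls_scores, cls_sums, conf_thres=0.25, reg_max=16):
    """Filter candidates by confidence first, then run dfl only on the survivors.

    box_logits[k]: 4 * reg_max raw values for candidate k (assumed layout)
    cls_scores[k]: per-class scores after sigmoid
    cls_sums[k]:   sum of the class scores, clamped to [0, 1]
    """
    kept = []
    for k, score_sum in enumerate(cls_sums):
        # If the sum is below the threshold, every single class score is too,
        # so one comparison replaces one comparison per class.
        if score_sum < conf_thres:
            continue
        scores = cls_scores[k]
        best = max(range(len(scores)), key=scores.__getitem__)
        if scores[best] < conf_thres:
            continue
        # dfl is computed only here, for candidates that passed the filter.
        sides = [dfl_decode(box_logits[k][j * reg_max:(j + 1) * reg_max]) for j in range(4)]
        kept.append((k, best, scores[best], sides))
    return kept


# Toy input: 3 candidates, 2 classes, 4 bins per box side.
box_logits = [[0.0] * 16 for _ in range(3)]
cls_scores = [[0.02, 0.03], [0.1, 0.8], [0.15, 0.15]]
cls_sums = [0.05, 0.9, 0.3]
kept = postprocess(box_logits, cls_scores, cls_sums, conf_thres=0.25, reg_max=4)
print(kept)  # [(1, 1, 0.8, [1.5, 1.5, 1.5, 1.5])]
```

The decoded values are distances in grid units; converting them to image coordinates with the anchor position and stride is left out of this sketch.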



### 2. Export model operation

After installing the dependencies listed in ./requirements.txt, run the following commands to export the model:

```
# Set the model file path in ./ultralytics/cfg/default.yaml. The default is yolov8n.pt; if you trained the model yourself, point it to the corresponding path
export PYTHONPATH=./
python ./ultralytics/engine/exporter.py
```

After execution, a model with the suffix _rknnopt.torchscript is generated. For example, if the original model is yolov8n.pt, the output is yolov8n_rknnopt.torchscript.
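
The naming rule for the exported file can be sketched with `pathlib`. The helper name `rknnopt_path` is hypothetical and used only for illustration; the exporter in this commit builds the name by replacing the file suffix in the path string.

```python
from pathlib import Path


def rknnopt_path(weights):
    """Return the expected export path for a weights file (illustrative helper)."""
    p = Path(weights)
    # Swap the final suffix (for example ".pt") for "_rknnopt.torchscript".
    return p.with_name(p.stem + "_rknnopt.torchscript")


print(rknnopt_path("yolov8n.pt"))  # yolov8n_rknnopt.torchscript
```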



Export Code Changes Explained

- In ./ultralytics/cfg/default.yaml, the **format** parameter selects the export format; support for 'rknn' has been added.
- When inference reaches the Detect head and format=='rknn', dfl and post-processing are skipped and the raw head outputs are returned.
- Note that this repository has not tested this optimization for the pose head and segment head, so they are currently not supported. You can try adapting them yourself if needed.
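
With format=='rknn', the Detect head in this commit appends three outputs per detection level (box branch output, per-class scores, class-score sum) instead of one concatenated tensor. A minimal sketch of regrouping such a flat output list is shown below; the function name `group_head_outputs` is hypothetical, and the sketch assumes the outputs arrive in the same order in which the head appends them.

```python
def group_head_outputs(outputs, per_level=3):
    """Split a flat output list into (box, cls, cls_sum) triples, one per detection level."""
    if len(outputs) % per_level != 0:
        raise ValueError(f"expected a multiple of {per_level} outputs, got {len(outputs)}")
    return [tuple(outputs[i:i + per_level]) for i in range(0, len(outputs), per_level)]


# Placeholder strings stand in for tensors; three levels give nine outputs.
flat = ["box0", "cls0", "sum0", "box1", "cls1", "sum1", "box2", "cls2", "sum2"]
levels = group_head_outputs(flat)
print(levels[1])  # ('box1', 'cls1', 'sum1')
```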



### 3. Convert to RKNN model, Python demo, C demo

Please refer to https://github.com/airockchip/rknn_model_zoo/tree/main/models/CV/object_detection/yolo
53 changes: 53 additions & 0 deletions RKOPT_README_zh.md
@@ -0,0 +1,53 @@
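## Description - exporting a model adapted for RKNPU

### 1. Model structure adjustments

- The dfl structure performs poorly on the NPU, so it is moved outside the model.

Suppose there are 6000 candidate boxes. The original model places the dfl structure before the "box confidence filter", so all 6000 candidates have to go through the dfl computation. If the dfl structure is placed after the "box confidence filter" and 100 candidates remain after filtering, the dfl computation drops to 100 candidates, which greatly reduces the use of computing and bandwidth resources.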



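- Suppose there are 6000 candidate boxes and 80 detection classes. The threshold check then has to be repeated 6000 * 80 = 4.8 * 10^5 times, which takes considerable time. For this reason, the exported model contains an additional operation that sums the scores of the 80 classes, used to filter by confidence quickly. (This structure is effective only in some cases, depending on the training results of the model.)

You can comment out this optimization at lines 52 to 54 of **./ultralytics/nn/modules/head.py**. The corresponding code is: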

```
cls_sum = torch.clamp(y[-1].sum(1, keepdim=True), 0, 1)
y.append(cls_sum)
```




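- (optional) Alternatively, following the structure of yolov5, the 80-class output can be extended to 80+1 classes, where the added class serves as the box confidence and acts as a quick filter. With this change, post-processing on the cpu needs 10 to 40 times fewer threshold comparisons.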



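### 2. Export model operation

After installing the dependencies listed in ./requirements.txt, run the following commands to export the model:

```
# Set the model file path in ./ultralytics/cfg/default.yaml. The default is yolov8n.pt; if you trained the model yourself, point it to the corresponding path
export PYTHONPATH=./
python ./ultralytics/engine/exporter.py
```

After execution, a model with the suffix _rknnopt.torchscript is generated. For example, if the original model is yolov8n.pt, the output is yolov8n_rknnopt.torchscript.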



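Export code changes explained

- In ./ultralytics/cfg/default.yaml, the format parameter selects the export format; support for 'rknn' has been added.
- When inference reaches the Detect head and format=='rknn', dfl and post-processing are skipped and the inference results are output directly.
- Note that this repository has not tested this optimization for the pose head and segment head, so they are currently not supported. You can try adapting them yourself if needed.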



### 3. Convert to RKNN model, Python demo, C demo

Please refer to https://github.com/airockchip/rknn_model_zoo/tree/main/models/CV/object_detection/yolo

4 changes: 2 additions & 2 deletions ultralytics/cfg/default.yaml
@@ -5,7 +5,7 @@ task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
 mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark

 # Train settings -------------------------------------------------------------------------------------------------------
-model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
+model: yolov8n.pt # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
 data: # (str, optional) path to data file, i.e. coco128.yaml
 epochs: 100 # (int) number of epochs to train for
 patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
@@ -68,7 +68,7 @@ retina_masks: False # (bool) use high-resolution segmentation masks
 boxes: True # (bool) Show boxes in segmentation predictions

 # Export settings ------------------------------------------------------------------------------------------------------
-format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
+format: rknn # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
 keras: False # (bool) use Keras
 optimize: False # (bool) TorchScript: optimize for mobile
 int8: False # (bool) CoreML/TF INT8 quantization
2 changes: 1 addition & 1 deletion ultralytics/data/augment.py
@@ -369,7 +369,7 @@ def apply_bboxes(self, bboxes, M):
         # Create new boxes
         x = xy[:, [0, 2, 4, 6]]
         y = xy[:, [1, 3, 5, 7]]
-        return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
+        return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

     def apply_segments(self, segments, M):
         """
25 changes: 22 additions & 3 deletions ultralytics/engine/exporter.py
@@ -88,8 +88,10 @@ def export_formats():
         ['TensorFlow Lite', 'tflite', '.tflite', True, False],
         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
         ['TensorFlow.js', 'tfjs', '_web_model', True, False],
-        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
-        ['ncnn', 'ncnn', '_ncnn_model', True, True], ]
+        ['PaddlePaddle', 'paddle', '_paddle_model', True, True],
+        ['ncnn', 'ncnn', '_ncnn_model', True, True],
+        ['RKNN', 'rknn', '_rknnopt.torchscript', True, False],
+    ]
     return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])

@@ -157,7 +159,8 @@ def __call__(self, model=None):
         flags = [x == format for x in fmts]
         if sum(flags) != 1:
             raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
-        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
+        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, rknn = flags # export booleans
+

         # Load PyTorch model
         self.device = select_device('cpu' if self.args.device is None else self.args.device)
@@ -262,6 +265,8 @@ def __call__(self, model=None):
             f[10], _ = self.export_paddle()
         if ncnn: # ncnn
             f[11], _ = self.export_ncnn()
+        if rknn:
+            f[12], _ = self.export_rknn()

         # Finish
         f = [str(x) for x in f if x] # filter out '' and None
@@ -297,6 +302,20 @@ def export_torchscript(self, prefix=colorstr('TorchScript:')):
             ts.save(str(f), _extra_files=extra_files)
         return f, None

+    @try_export
+    def export_rknn(self, prefix=colorstr('RKNN:')):
+        """YOLOv8 RKNN model export."""
+        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
+
+        ts = torch.jit.trace(self.model, self.im, strict=False)
+        f = str(self.file).replace(self.file.suffix, f'_rknnopt.torchscript')
+        torch.jit.save(ts, str(f))
+
+        LOGGER.info(f'\n{prefix} feed {f} to RKNN-Toolkit or RKNN-Toolkit2 to generate RKNN model.\n'
+                    'Refer https://github.com/airockchip/rknn_model_zoo/tree/main/models/CV/object_detection/yolo')
+        return f, None
+
+
     @try_export
     def export_onnx(self, prefix=colorstr('ONNX:')):
         """YOLOv8 ONNX export."""
4 changes: 3 additions & 1 deletion ultralytics/nn/autobackend.py
@@ -80,7 +80,7 @@ def __init__(self,
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
         nn_module = isinstance(weights, torch.nn.Module)
-        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton, rknn = \
             self._model_type(w)
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
@@ -385,6 +385,8 @@ def forward(self, im, augment=False, visualize=False):
                 mat_out = self.pyncnn.Mat()
                 ex.extract(output_name, mat_out)
                 y.append(np.array(mat_out)[None])
+        elif getattr(self, 'rknn', False):
+            assert "for inference, please refer to https://github.com/airockchip/rknn_model_zoo/tree/main/models/CV/object_detection/yolo"
         elif self.triton: # NVIDIA Triton Inference Server
             y = self.model(im)
         else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
11 changes: 11 additions & 0 deletions ultralytics/nn/modules/head.py
@@ -43,6 +43,17 @@ def __init__(self, nc=80, ch=()): # detection layer
     def forward(self, x):
         """Concatenates and returns predicted bounding boxes and class probabilities."""
         shape = x[0].shape # BCHW
+
+        if self.export and self.format == 'rknn':
+            y = []
+            for i in range(self.nl):
+                y.append(self.cv2[i](x[i]))
+                cls = torch.sigmoid(self.cv3[i](x[i]))
+                cls_sum = torch.clamp(y[-1].sum(1, keepdim=True), 0, 1)
+                y.append(cls)
+                y.append(cls_sum)
+            return y
+
         for i in range(self.nl):
             x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
         if self.training:
