Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TORCHSCRIPT export for NCNN #635

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion projects/easydeploy/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backend import MMYOLOBackend
from .backendwrapper import ORTWrapper, TRTWrapper
from .model import DeployModel

__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper']
__all__ = ['DeployModel', 'TRTWrapper', 'ORTWrapper', 'MMYOLOBackend']
43 changes: 30 additions & 13 deletions projects/easydeploy/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,34 @@
from mmyolo.models import RepVGGBlock
from mmyolo.models.dense_heads import (RTMDetHead, YOLOv5Head, YOLOv7Head,
YOLOXHead)
from mmyolo.models.layers import CSPLayerWithTwoConv
from ..backbone import DeployC2f, DeployFocus, GConvFocus, NcnnFocus
from ..backbone import DeployFocus, GConvFocus, NcnnFocus
from ..bbox_code import (rtmdet_bbox_decoder, yolov5_bbox_decoder,
yolox_bbox_decoder)
from ..nms import batched_nms, efficient_nms, onnx_nms
from .backend import MMYOLOBackend


class DeployModel(nn.Module):
transpose = False

def __init__(self,
baseModel: nn.Module,
backend: MMYOLOBackend,
postprocess_cfg: Optional[ConfigDict] = None):
super().__init__()
self.baseModel = baseModel
self.baseHead = baseModel.bbox_head
self.backend = backend
if postprocess_cfg is None:
self.with_postprocess = False
else:
self.with_postprocess = True
self.baseHead = baseModel.bbox_head
self.__init_sub_attributes()
self.detector_type = type(self.baseHead)
self.pre_top_k = postprocess_cfg.get('pre_top_k', 1000)
self.keep_top_k = postprocess_cfg.get('keep_top_k', 100)
self.iou_threshold = postprocess_cfg.get('iou_threshold', 0.65)
self.score_threshold = postprocess_cfg.get('score_threshold', 0.25)
self.backend = postprocess_cfg.get('backend', 1)
self.__switch_deploy()

def __init_sub_attributes(self):
Expand All @@ -47,21 +49,25 @@ def __init_sub_attributes(self):
self.num_classes = self.baseHead.num_classes

def __switch_deploy(self):
if self.backend in (MMYOLOBackend.HORIZONX3, MMYOLOBackend.NCNN,
MMYOLOBackend.TORCHSCRIPT):
self.transpose = True
for layer in self.baseModel.modules():
if isinstance(layer, RepVGGBlock):
layer.switch_to_deploy()
elif isinstance(layer, Focus):
# onnxruntime tensorrt8 tensorrt7
if self.backend in (1, 2, 3):
# onnxruntime openvino tensorrt8 tensorrt7
if self.backend in (MMYOLOBackend.ONNXRUNTIME,
MMYOLOBackend.OPENVINO,
MMYOLOBackend.TENSORRT8,
MMYOLOBackend.TENSORRT7):
self.baseModel.backbone.stem = DeployFocus(layer)
# ncnn
elif self.backend == 4:
elif self.backend == MMYOLOBackend.NCNN:
self.baseModel.backbone.stem = NcnnFocus(layer)
# switch focus to group conv
else:
self.baseModel.backbone.stem = GConvFocus(layer)
elif isinstance(layer, CSPLayerWithTwoConv):
setattr(layer, '__class__', DeployC2f)

def pred_by_feat(self,
cls_scores: List[Tensor],
Expand Down Expand Up @@ -129,11 +135,11 @@ def pred_by_feat(self,
self.score_threshold, self.pre_top_k, self.keep_top_k)

def select_nms(self):
if self.backend == 1:
if self.backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO):
nms_func = onnx_nms
elif self.backend == 2:
elif self.backend == MMYOLOBackend.TENSORRT8:
nms_func = efficient_nms
elif self.backend == 3:
elif self.backend == MMYOLOBackend.TENSORRT7:
nms_func = batched_nms
else:
raise NotImplementedError
Expand All @@ -147,4 +153,15 @@ def forward(self, inputs: Tensor):
if self.with_postprocess:
return self.pred_by_feat(*neck_outputs)
else:
return neck_outputs
outputs = []
if self.transpose:
for feats in zip(*neck_outputs):
if self.backend in (MMYOLOBackend.NCNN,
MMYOLOBackend.TORCHSCRIPT):
outputs.append(
torch.cat(
[feat.permute(0, 2, 3, 1) for feat in feats],
-1))
else:
outputs.append(torch.cat(feats, 1).permute(0, 2, 3, 1))
return tuple(outputs)
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist

from projects.easydeploy.model import DeployModel
from projects.easydeploy.model import DeployModel, MMYOLOBackend

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
Expand Down Expand Up @@ -42,7 +43,10 @@ def parse_args():
parser.add_argument(
'--opset', type=int, default=11, help='ONNX opset version')
parser.add_argument(
'--backend', type=int, default=1, help='Backend for export onnx')
'--backend',
type=str,
default='onnxruntime',
help='Backend for export onnx')
parser.add_argument(
'--pre-topk',
type=int,
Expand Down Expand Up @@ -77,7 +81,15 @@ def build_model_from_cfg(config_path, checkpoint_path, device):
def main():
args = parse_args()
mkdir_or_exist(args.work_dir)

backend = MMYOLOBackend(args.backend.lower())
if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
if not args.model_only:
print_log('Export ONNX with bbox decoder and NMS ...')
else:
args.model_only = True
print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
f'Set "args.model_only=True" default.')
if args.model_only:
postprocess_cfg = None
output_names = None
Expand All @@ -86,21 +98,22 @@ def main():
pre_top_k=args.pre_topk,
keep_top_k=args.keep_topk,
iou_threshold=args.iou_threshold,
score_threshold=args.score_threshold,
backend=args.backend)
score_threshold=args.score_threshold)
output_names = ['num_dets', 'boxes', 'scores', 'labels']
baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)

deploy_model = DeployModel(
baseModel=baseModel, postprocess_cfg=postprocess_cfg)
baseModel=baseModel, backend=backend, postprocess_cfg=postprocess_cfg)
deploy_model.eval()

fake_input = torch.randn(args.batch_size, 3,
*args.img_size).to(args.device)
# dry run
deploy_model(fake_input)

save_onnx_path = os.path.join(args.work_dir, 'end2end.onnx')
save_onnx_path = os.path.join(
args.work_dir,
os.path.basename(args.checkpoint).replace('pth', 'onnx'))
# export onnx
with BytesIO() as f:
torch.onnx.export(
Expand All @@ -115,7 +128,7 @@ def main():
onnx.checker.check_model(onnx_model)

# Fix tensorrt onnx output shape, just for view
if args.backend in (2, 3):
if backend in (MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT8):
shapes = [
args.batch_size, 1, args.batch_size, args.keep_topk, 4,
args.batch_size, args.keep_topk, args.batch_size,
Expand All @@ -130,9 +143,9 @@ def main():
onnx_model, check = onnxsim.simplify(onnx_model)
assert check, 'assert check failed'
except Exception as e:
print(f'Simplify failure: {e}')
print_log(f'Simplify failure: {e}')
onnx.save(onnx_model, save_onnx_path)
print(f'ONNX export success, save into {save_onnx_path}')
print_log(f'ONNX export success, save into {save_onnx_path}')


if __name__ == '__main__':
Expand Down
71 changes: 71 additions & 0 deletions projects/easydeploy/tools/export_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import argparse
import os
import warnings

import torch
from mmdet.apis import init_detector
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist

from projects.easydeploy.model import DeployModel, MMYOLOBackend

warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
warnings.filterwarnings(action='ignore', category=torch.jit.ScriptWarning)
warnings.filterwarnings(action='ignore', category=UserWarning)
warnings.filterwarnings(action='ignore', category=FutureWarning)
warnings.filterwarnings(action='ignore', category=ResourceWarning)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--work-dir', default='./work_dir', help='Path to save export model')
parser.add_argument(
'--img-size',
nargs='+',
type=int,
default=[640, 640],
help='Image size of height and width')
parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
args.img_size *= 2 if len(args.img_size) == 1 else 1
return args


def build_model_from_cfg(config_path, checkpoint_path, device):
model = init_detector(config_path, checkpoint_path, device=device)
model.eval()
return model


def main():
args = parse_args()
mkdir_or_exist(args.work_dir)

baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)

deploy_model = DeployModel(
baseModel=baseModel,
backend=MMYOLOBackend.TORCHSCRIPT,
postprocess_cfg=None)
deploy_model.eval()

fake_input = torch.randn(args.batch_size, 3,
*args.img_size).to(args.device)
# dry run
deploy_model(fake_input)

save_torchscript_path = os.path.join(
args.work_dir,
os.path.basename(args.checkpoint).replace('pth', 'torchscript'))
mod = torch.jit.trace(deploy_model, fake_input)
mod.save(save_torchscript_path)
print_log(f'TORCHSCRIPT export success, save into {save_torchscript_path}')


if __name__ == '__main__':
main()