add complete code for STARK-Lightning
MasterBin-IIAU committed Jul 24, 2021
1 parent 169a068 commit 75c05a9
Showing 14 changed files with 91 additions and 463 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -33,7 +33,7 @@ STARK is implemented purely based on the PyTorch.
## What's new
**July 24, 2021**
- We release an extremely fast version of STARK called **STARK-Lightning** :zap: . It can run at **200~300 FPS** on an RTX TITAN GPU.
-Besides, its performance can beat DiMP50, while the model size is even less than that of SiamFC!
+Besides, its performance can beat DiMP50, while the model size is even less than that of SiamFC! More details can be found at [STARK_Lightning_En.md](lib/tutorials/STARK_Lightning_En.md)/[Chinese tutorial](lib/tutorials/STARK_Lightning_Ch.md)
**July 23, 2021**
- STARK is accepted by ICCV2021

@@ -94,6 +94,8 @@ python tracking/train.py --script stark_st2 --config baseline --save_dir . --mod
# STARK-ST101
python tracking/train.py --script stark_st1 --config baseline_R101 --save_dir . --mode multiple --nproc_per_node 8 # STARK-ST101 Stage1
python tracking/train.py --script stark_st2 --config baseline_R101 --save_dir . --mode multiple --nproc_per_node 8 --script_prv stark_st1 --config_prv baseline_R101 # STARK-ST101 Stage2
+# STARK-Lightning
+python tracking/train.py --script stark_lightning_X_trt --config baseline_rephead_4_lite_search5 --save_dir . --mode multiple --nproc_per_node 8 # STARK-Lightning
```
(Optionally) Debugging training with a single GPU
```
6 changes: 3 additions & 3 deletions lib/models/stark/backbone_X.py
@@ -10,7 +10,6 @@
from lib.models.stark.repvgg import get_RepVGG_func_by_name
import os
from .swin_transformer import build_swint
-from .lighttrack.childnet import build_subnet


class FrozenBatchNorm2d(torch.nn.Module):
@@ -135,8 +134,9 @@ def build_backbone_x_cnn(cfg, phase='train'):
ckpt_new[k_new] = v
ckpt = ckpt_new
missing_keys, unexpected_keys = backbone.body.load_state_dict(ckpt, strict=False)
-print("missing keys:", missing_keys)
-print("unexpected keys:", unexpected_keys)
+if is_main_process():
+    print("missing keys:", missing_keys)
+    print("unexpected keys:", unexpected_keys)

"""freeze some layers"""
if cfg.MODEL.BACKBONE.TYPE != "LightTrack":
4 changes: 2 additions & 2 deletions lib/test/tracker/stark_lightning_X_trt.py
@@ -11,7 +11,7 @@
from lib.models.stark.repvgg import repvgg_model_convert
# for onnxruntime
from lib.test.tracker.stark_utils import PreprocessorX_onnx
-# import onnxruntime
+import onnxruntime
import multiprocessing


@@ -170,7 +170,7 @@ def map_box_back(self, pred_box: list, resize_factor: float):


def get_tracker_class():
-use_onnx = False
+use_onnx = True
if use_onnx:
print("Using onnx model")
return STARK_LightningXtrt_onnx
5 changes: 0 additions & 5 deletions lib/train/actors/__init__.py
@@ -1,10 +1,5 @@
from .base_actor import BaseActor
from .stark_s import STARKSActor
from .stark_st import STARKSTActor
from .stark_s_plus import STARKSPLUSActor
from .stark_s_plus_sp import STARKSPLUSSPActor
from .stark_st_plus_sp import STARKSTPLUSSPActor
from .stark_st_plus_sp_debug import STARKSTPLUSSPActor_debug
from .stark_lightningX import STARKLightningXActor
from .stark_lightningXtrt import STARKLightningXtrtActor
from .stark_lightningXtrt_distill import STARKLightningXtrtdistillActor
37 changes: 37 additions & 0 deletions lib/tutorials/STARK_Lightning_Ch.md
@@ -0,0 +1,37 @@
# STARK-Lightning Chinese Tutorial
**Preface** [ONNXRUNTIME](https://github.com/microsoft/onnxruntime) is an open-source library from Microsoft for accelerating network inference. In this tutorial, we show how to export a trained model to ONNX format
and use ONNXRUNTIME to further accelerate inference. The accelerated STARK-Lightning can run at 200~300 FPS on an RTX TITAN! Let's get started.
## Install onnx and onnxruntime
To run inference with onnxruntime on a GPU
```
pip install onnx onnxruntime-gpu==1.6.0
```
- The onnxruntime-gpu version must match the CUDA and CUDNN versions on the machine. For the version correspondence, see https://www.onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html
  On my computer, the CUDA version is 10.2 and the CUDNN version is 8.0.3, so I installed onnxruntime-gpu==1.6.0

To run on the CPU only
```
pip install onnx onnxruntime
```
## ONNX Model Conversion and Inference Test
Download the trained PyTorch model weights [STARK_Lightning](https://drive.google.com/file/d/18xxbMKCjWi6Gvn5T4o2w5jIbwd3AWN55/view?usp=sharing)

Convert the trained PyTorch model to onnx format and test it with onnxruntime
```
python tracking/ORT_lightning_X_trt_backbone_bottleneck_pe.py # for the template branch
python tracking/ORT_lightning_X_trt_complete.py # for the search region branch
```
- The conversion runs fine in a terminal, but in PyCharm it fails with an error that libcudnn8.so cannot be found, so run the remaining steps in a terminal.

Test the converted model on LaSOT (multi-GPU inference is supported)
- First set use_onnx = True in lib/test/tracker/stark_lightning_X_trt.py, then run
```
python tracking/test.py stark_lightning_X_trt baseline_rephead_4_lite_search5 --threads 8 --num_gpus 2
```
Here num_gpus is the number of GPUs to use and threads is the number of processes; we usually set threads to four times the number of GPUs.
To run the videos one at a time, use the following command
```
python tracking/test.py stark_lightning_X_trt baseline_rephead_4_lite_search5 --threads 0 --num_gpus 1
```
- Evaluate the tracking metrics
```python tracking/analysis_results_ITP.py --script stark_lightning_X_trt --config baseline_rephead_4_lite_search5```
39 changes: 39 additions & 0 deletions lib/tutorials/STARK_Lightning_En.md
@@ -0,0 +1,39 @@
# STARK-Lightning Tutorial
**Introduction** [ONNXRUNTIME](https://github.com/microsoft/onnxruntime) is an open-source library by Microsoft for accelerating network inference. In this tutorial, we show how to export the trained model to ONNX format
and use ONNXRUNTIME to further accelerate inference. The accelerated STARK-Lightning can run at 200~300 FPS on an RTX TITAN GPU! Let's get started.
## Install onnx and onnxruntime
For inference on GPU
```
pip install onnx onnxruntime-gpu==1.6.0
```
- The version of onnxruntime-gpu must be compatible with the CUDA and CUDNN versions on the machine. For more details, please refer to https://www.onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html
  For example, on my computer the CUDA version is 10.2 and the CUDNN version is 8.0.3, so I chose onnxruntime-gpu==1.6.0

For inference on CPU only
```
pip install onnx onnxruntime
```
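
As a quick sanity check after installation, the sketch below lists the execution providers that the installed onnxruntime build reports. It relies only on `onnxruntime.get_available_providers()`; the helper name `available_onnxruntime_providers` is made up for this example and is not part of the repository.

```python
def available_onnxruntime_providers():
    """Return the execution providers reported by the installed onnxruntime.

    Returns an empty list when onnxruntime cannot be imported.
    """
    try:
        import onnxruntime
    except ImportError:
        return []
    return list(onnxruntime.get_available_providers())


providers = available_onnxruntime_providers()
print("providers:", providers)
# A working onnxruntime-gpu install is expected to report
# "CUDAExecutionProvider"; a CPU-only install normally does not.
print("GPU inference available:", "CUDAExecutionProvider" in providers)
```

If the CUDA provider is missing even though onnxruntime-gpu is installed, a mismatch between the package version and the local CUDA/CUDNN versions is one likely cause.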
## ONNX Conversion and Inference
Download the trained PyTorch checkpoint [STARK_Lightning](https://drive.google.com/file/d/18xxbMKCjWi6Gvn5T4o2w5jIbwd3AWN55/view?usp=sharing)

Export the trained PyTorch model to onnx format, then test it with onnxruntime
```
python tracking/ORT_lightning_X_trt_backbone_bottleneck_pe.py # for the template branch
python tracking/ORT_lightning_X_trt_complete.py # for the search region branch
```
- The conversion runs successfully in a terminal. However, running it in PyCharm raises a "libcudnn8.so is not found" error,
so please run these two commands in a terminal.

Evaluate the converted onnx model on LaSOT (multi-GPU inference is supported).
- Set ```use_onnx=True``` in lib/test/tracker/stark_lightning_X_trt.py, then run
```
python tracking/test.py stark_lightning_X_trt baseline_rephead_4_lite_search5 --threads 8 --num_gpus 2
```
Here ```num_gpus``` is the number of GPUs to use and ```threads``` is the number of processes. We usually set ```threads``` to four times ```num_gpus```.
To run the sequences one by one, use the following command
```
python tracking/test.py stark_lightning_X_trt baseline_rephead_4_lite_search5 --threads 0 --num_gpus 1
```
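
The rule of thumb for ```threads``` (about four times ```num_gpus```) can be written as a small helper that assembles the command line. This is only a sketch: `build_test_command` is a hypothetical name, not a function in the repository, and the factor of four is the heuristic stated in this tutorial, not a requirement.

```python
def build_test_command(script, config, num_gpus, threads=None):
    """Assemble the tracking/test.py command line as a list of arguments.

    When threads is None, the tutorial's heuristic of four processes
    per GPU is used. Pass threads=0 to run the sequences one by one.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if threads is None:
        threads = 4 * num_gpus
    return [
        "python", "tracking/test.py", script, config,
        "--threads", str(threads),
        "--num_gpus", str(num_gpus),
    ]


parallel_cmd = build_test_command(
    "stark_lightning_X_trt", "baseline_rephead_4_lite_search5", num_gpus=2)
sequential_cmd = build_test_command(
    "stark_lightning_X_trt", "baseline_rephead_4_lite_search5",
    num_gpus=1, threads=0)
print(" ".join(parallel_cmd))
print(" ".join(sequential_cmd))
```

The returned list can be passed directly to `subprocess.run` if the evaluation is launched from another script.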
- Evaluate the tracking results
```python tracking/analysis_results_ITP.py --script stark_lightning_X_trt --config baseline_rephead_4_lite_search5```

70 changes: 0 additions & 70 deletions tracking/ORT_TRT_complete_run.py

This file was deleted.

5 changes: 4 additions & 1 deletion tracking/ORT_lightning_X_trt_backbone_bottleneck_pe.py
@@ -18,7 +18,7 @@
def parse_args():
parser = argparse.ArgumentParser(description='Parse args for training')
parser.add_argument('--script', type=str, default='stark_lightning_X_trt', help='script name')
-parser.add_argument('--config', type=str, default='baseline_rephead', help='yaml configure file name')
+parser.add_argument('--config', type=str, default='baseline_rephead_4_lite_search5', help='yaml configure file name')
args = parser.parse_args()
return args

@@ -114,6 +114,9 @@ def to_numpy(tensor):
ort_inputs = {'img_z': to_numpy(img_z),
'mask_z': to_numpy(mask_z)}
# print(onnxruntime.get_device())
+# warmup
+for i in range(10):
+    ort_outs = ort_session.run(None, ort_inputs)
s = time.time()
for i in range(N):
ort_outs = ort_session.run(None, ort_inputs)
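
The warm-up loop in this hunk runs the session ten times before the clock starts, presumably so that one-time start-up costs of the first calls do not distort the measured speed. A minimal, library-free sketch of the same warm-up-then-time pattern follows; `benchmark` and its parameters are illustrative names, and the stand-in workload takes the place of the real `ort_session.run` call.

```python
import time


def benchmark(fn, n_warmup=10, n_iters=100):
    """Return the average seconds per call of fn, excluding warm-up calls."""
    # Warm-up: run fn without timing it.
    for _ in range(n_warmup):
        fn()
    # Timed loop: only these calls contribute to the reported average.
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    elapsed = time.perf_counter() - start
    return elapsed / n_iters


def stand_in_workload():
    # Placeholder for something like ort_session.run(None, ort_inputs).
    return sum(range(1000))


avg_seconds = benchmark(stand_in_workload, n_warmup=10, n_iters=100)
print(f"average: {avg_seconds * 1e3:.4f} ms per call")
if avg_seconds > 0:
    print(f"throughput: {1.0 / avg_seconds:.0f} calls per second")
```

`time.perf_counter` is used here instead of `time.time` because it is intended for measuring short durations; that is a choice made for this sketch, not something taken from the script.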
3 changes: 2 additions & 1 deletion tracking/ORT_lightning_X_trt_complete.py
@@ -20,7 +20,7 @@
def parse_args():
parser = argparse.ArgumentParser(description='Parse args for training')
parser.add_argument('--script', type=str, default='stark_lightning_X_trt', help='script name')
-parser.add_argument('--config', type=str, default='baseline_rephead', help='yaml configure file name')
+parser.add_argument('--config', type=str, default='baseline_rephead_4_lite_search5', help='yaml configure file name')
args = parser.parse_args()
return args

@@ -112,6 +112,7 @@ def to_numpy(tensor):
sz_x = cfg.TEST.SEARCH_SIZE
hw_z = cfg.DATA.TEMPLATE.FEAT_SIZE ** 2
c = cfg.MODEL.HIDDEN_DIM
+print(bs, sz_x, hw_z, c)
img_x, mask_x, feat_vec_z, mask_vec_z, pos_vec_z = get_data(bs=bs, sz_x=sz_x, hw_z=hw_z, c=c)
torch_outs = torch_model(img_x, mask_x, feat_vec_z, mask_vec_z, pos_vec_z)
torch.onnx.export(torch_model, # model being run
2 changes: 1 addition & 1 deletion tracking/analysis_results_ITP.py
@@ -23,6 +23,6 @@ def parse_args():
trackers = []
trackers.extend(trackerlist(args.script, args.config, "None", None, args.config))

-dataset = get_dataset('lasot_lmdb')
+dataset = get_dataset('lasot')

print_results(trackers, dataset, 'LaSOT', merge_results=True, plot_types=('success', 'prec', 'norm_prec'))
81 changes: 0 additions & 81 deletions tracking/others_onnx/ORT_lightning_X_trt_backbone.py

This file was deleted.
