-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
convert_votenet_checkpoints.py
153 lines (122 loc) · 4.97 KB
/
convert_votenet_checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import tempfile

import torch
from mmengine import Config
from mmengine.runner import load_state_dict

from mmdet3d.registry import MODELS
def parse_args(argv=None):
    """Parse command-line arguments of the converter.

    Args:
        argv (list[str], optional): Arguments to parse. Defaults to None,
            in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: ``checkpoint`` (path of the old checkpoint) and
        ``out`` (path the converted checkpoint is written to).
    """
    parser = argparse.ArgumentParser(
        description='MMDet3D upgrade model version(before v0.6.0) of VoteNet')
    parser.add_argument('checkpoint', help='checkpoint file')
    # The converted checkpoint is always saved to --out, so require it up
    # front instead of failing in torch.save after the whole conversion ran.
    parser.add_argument(
        '--out', required=True, help='path of the output checkpoint file')
    return parser.parse_args(argv)
def parse_config(config_strings):
    """Parse config from strings and upgrade it to the current VoteNet API.

    Args:
        config_strings (str): Python source of the model config, e.g. the
            value stored under ``checkpoint['meta']['config']``.

    Returns:
        Config: model config whose backbone and bbox_head settings are
        rewritten for the current module interfaces.
    """
    # Config.fromfile needs a real ``.py`` file on disk. Writing it inside a
    # temporary directory guarantees the file is removed afterwards, also
    # when parsing raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = os.path.join(tmp_dir, 'config.py')
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(config_strings)
        config = Config.fromfile(config_path)

    backbone = config.model.backbone
    bbox_head = config.model.bbox_head

    # Update backbone config.
    # NOTE(review): the old ``pool_mod`` value is discarded; the default
    # ``sa_cfg`` added below always uses pool_mod='max'.
    if 'pool_mod' in backbone:
        backbone.pop('pool_mod')
    if 'sa_cfg' not in backbone:
        backbone['sa_cfg'] = dict(
            type='PointSAModule',
            pool_mod='max',
            use_xyz=True,
            normalize_xyz=True)

    # Update bbox_head config.
    if 'type' not in bbox_head.vote_aggregation_cfg:
        bbox_head.vote_aggregation_cfg['type'] = 'PointSAModule'
    if 'pred_layer_cfg' not in bbox_head:
        bbox_head['pred_layer_cfg'] = dict(
            in_channels=128, shared_conv_channels=(128, 128), bias=True)
    if 'feat_channels' in bbox_head:
        bbox_head.pop('feat_channels')
    # Old configs stored this under the misspelled key 'vote_moudule_cfg'.
    if 'vote_moudule_cfg' in bbox_head:
        bbox_head['vote_module_cfg'] = bbox_head.pop('vote_moudule_cfg')
    # With use_xyz set, remove 3 channels from the first aggregation MLP
    # layer -- presumably because the current PointSAModule adds the xyz
    # channels itself; TODO confirm against PointSAModule.
    if bbox_head.vote_aggregation_cfg.use_xyz:
        bbox_head.vote_aggregation_cfg.mlp_channels[0] -= 3

    return config
def _get_num_classes(cfg):
    """Return the number of semantic classes predicted by the VoteNet head.

    ScanNet and SUN RGB-D keep their fixed class counts; for any other
    dataset ``model.bbox_head.num_classes`` from the config is used.

    Args:
        cfg (Config): The (upgraded) model config.

    Returns:
        int: Number of classes.

    Raises:
        NotImplementedError: If the class count cannot be determined.
    """
    known_num_classes = {'ScanNetDataset': 18, 'SUNRGBDDataset': 10}
    dataset_type = cfg.get('dataset_type')
    if dataset_type in known_num_classes:
        return known_num_classes[dataset_type]
    num_classes = cfg.model.bbox_head.get('num_classes')
    if num_classes is None:
        raise NotImplementedError(
            'Cannot determine the number of classes for dataset type '
            f'{dataset_type!r}')
    return num_classes


def _convert_state_dict(orig_ckpt, num_classes):
    """Map an old (before v0.6.0) VoteNet state dict to the current keys.

    Args:
        orig_ckpt (dict): ``state_dict`` of the old checkpoint. It is not
            modified; a shallow copy is converted and returned.
        num_classes (int): Number of semantic classes of the head.

    Returns:
        dict: The converted state dict.
    """
    converted_ckpt = orig_ckpt.copy()

    rename_prefix = {
        'bbox_head.conv_pred.0': 'bbox_head.conv_pred.shared_convs.layer0',
        'bbox_head.conv_pred.1': 'bbox_head.conv_pred.shared_convs.layer1'
    }
    del_keys = [
        'bbox_head.conv_pred.0.bn.num_batches_tracked',
        'bbox_head.conv_pred.1.bn.num_batches_tracked'
    ]
    # The old ``conv_out`` layer is split along dim 0: its first 2 rows plus
    # its last ``num_classes`` rows become ``conv_cls``, the rows in between
    # become ``conv_reg``. Slices are ``(start, end)``; ``end=None`` means
    # "up to the last row" (a -1 sentinel would clash with -num_classes
    # when num_classes == 1).
    cls_slices = [(0, 2), (-num_classes, None)]
    reg_slices = [(2, -num_classes)]
    extract_keys = {
        'bbox_head.conv_pred.conv_cls.weight':
        ('bbox_head.conv_pred.conv_out.weight', cls_slices),
        'bbox_head.conv_pred.conv_cls.bias':
        ('bbox_head.conv_pred.conv_out.bias', cls_slices),
        'bbox_head.conv_pred.conv_reg.weight':
        ('bbox_head.conv_pred.conv_out.weight', reg_slices),
        'bbox_head.conv_pred.conv_reg.bias':
        ('bbox_head.conv_pred.conv_out.bias', reg_slices),
    }

    # Delete some useless keys; tolerate checkpoints that lack them.
    for key in del_keys:
        converted_ckpt.pop(key, None)

    # Rename keys with specific prefix. Collect first, then mutate, so the
    # dict is not changed while it is being iterated.
    rename_keys = {}
    for old_key in converted_ckpt:
        for prefix, new_prefix in rename_prefix.items():
            if prefix in old_key:
                rename_keys[old_key.replace(prefix, new_prefix)] = old_key
    for new_key, old_key in rename_keys.items():
        converted_ckpt[new_key] = converted_ckpt.pop(old_key)

    # Extract weights and rename the keys.
    for new_key, (old_key, slices) in extract_keys.items():
        source = orig_ckpt[old_key]
        parts = [source[start:end] for start, end in slices]
        converted_ckpt[new_key] = torch.cat(parts, 0)
        converted_ckpt.pop(old_key, None)

    return converted_ckpt


def main():
    """Convert keys in checkpoints for VoteNet.

    There can be some breaking changes during the development of mmdetection3d,
    and this tool is used for upgrading checkpoints trained with old versions
    (before v0.6.0) to the latest one.
    """
    args = parse_args()
    # NOTE: torch.load unpickles the file -- only convert checkpoints from
    # trusted sources. map_location='cpu' lets checkpoints saved on GPU be
    # converted on machines without CUDA.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    cfg = parse_config(checkpoint['meta']['config'])
    num_classes = _get_num_classes(cfg)

    # Build the model; it is only used to validate the converted weights.
    # mmengine's model builder takes extra constructor arguments only via
    # ``default_args`` (used for keys cfg.model does not already set);
    # passing train_cfg/test_cfg as keyword arguments raises a TypeError.
    model = MODELS.build(
        cfg.model,
        default_args=dict(
            train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg')))

    converted_ckpt = _convert_state_dict(
        checkpoint['state_dict'], num_classes)

    # Check the converted checkpoint by loading it into the model.
    load_state_dict(model, converted_ckpt, strict=True)
    checkpoint['state_dict'] = converted_ckpt
    torch.save(checkpoint, args.out)


if __name__ == '__main__':
    main()