Allow Torch half models #3130

Open — wants to merge 6 commits into `main`
6 changes: 6 additions & 0 deletions docs/en/user_guides/how_to_deploy.md
@@ -32,6 +32,12 @@ For example:

```shell
python tools/misc/publish_model.py ./epoch_10.pth ./epoch_10_publish.pth
```

To save the model as float16 (half precision), add `--float16`:

```shell
python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} --float16
```

The script will automatically simplify the model, save the simplified model to the specified path, and add a timestamp to the filename, for example, `./epoch_10_publish-21815b2c_20230726.pth`.
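As an illustration of what `--float16` does, here is a minimal sketch using a dummy in-memory checkpoint rather than a real training run (the key names are hypothetical); it mirrors the conversion the script performs:

```python
import torch

# Dummy checkpoint with the same layout publish_model.py expects:
# a `meta` dict and a `state_dict` of float32 tensors.
checkpoint = {
    'meta': {},
    'state_dict': {'backbone.weight': torch.randn(4, 4)},
}

# Record the flag in `meta` and cast every tensor to half precision,
# as the --float16 branch of the script does.
checkpoint['meta']['float16'] = True
for key in checkpoint['state_dict'].keys():
    checkpoint['state_dict'][key] = checkpoint['state_dict'][key].half()

assert checkpoint['state_dict']['backbone.weight'].dtype == torch.float16
```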

## Deployment with MMDeploy
6 changes: 6 additions & 0 deletions docs/zh_cn/user_guides/how_to_deploy.md
@@ -32,6 +32,12 @@ python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE}

```shell
python tools/misc/publish_model.py ./epoch_10.pth ./epoch_10_publish.pth
```

To save the model as float16 (half precision), add `--float16`:

```shell
python tools/misc/publish_model.py ${IN_FILE} ${OUT_FILE} --float16
```

The script will automatically simplify the model, save the simplified model to the specified path, and append a timestamp to the filename, for example `./epoch_10_publish-21815b2c_20230726.pth`.

## Deployment with MMDeploy
3 changes: 3 additions & 0 deletions mmpose/models/pose_estimators/base.py
```diff
@@ -158,6 +158,9 @@ def forward(self,
             if self.metainfo is not None:
                 for data_sample in data_samples:
                     data_sample.set_metainfo(self.metainfo)
+            param = next(self.backbone.parameters())
+            if param.is_cuda and param.dtype == torch.float16:
+                inputs = inputs.half()
             return self.predict(inputs, data_samples)
         elif mode == 'tensor':
             return self._forward(inputs)
```
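The added check can be exercised in isolation. The sketch below uses a plain `nn.Linear` on CPU as a stand-in for the backbone, so only the dtype half of the `is_cuda and dtype` condition is shown (the real check also requires the parameter to be on GPU):

```python
import torch
from torch import nn

# Stand-in backbone whose parameters have been cast to float16.
backbone = nn.Linear(8, 8).half()
inputs = torch.randn(2, 8)  # incoming data is float32

# Same idea as the check in base.py: inspect one parameter's dtype
# and cast the inputs to match before running prediction.
param = next(backbone.parameters())
if param.dtype == torch.float16:
    inputs = inputs.half()

assert inputs.dtype == torch.float16
```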
25 changes: 23 additions & 2 deletions tools/misc/publish_model.py
```diff
@@ -20,12 +20,21 @@ def parse_args():
         type=str,
         default=['meta', 'state_dict'],
         help='keys to save in published checkpoint (default: meta state_dict)')
+    parser.add_argument(
+        '--float16',
+        action='store_true',
+        default=False,
+        help='whether to save the model as float16')
     args = parser.parse_args()
     return args
 
 
-def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):
+def process_checkpoint(in_file,
+                       out_file,
+                       save_keys=['meta', 'state_dict'],
+                       float16=False):
     checkpoint = torch.load(in_file, map_location='cpu')
+    checkpoint['meta']['float16'] = float16
 
     # only keep `meta` and `state_dict` for smaller file size
     ckpt_keys = list(checkpoint.keys())
@@ -41,6 +50,17 @@ def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):
     # if it is necessary to remove some sensitive data in checkpoint['meta'],
     # add the code here.
 
+    if float16:
+        if 'meta' not in save_keys:
+            raise ValueError(
+                'Key `meta` must be in save_keys to save model as float16. '
+                'Change float16 to False or add `meta` in save_keys.')
+        print_log('Saving model as float16.', logger='current')
+        for key in checkpoint['state_dict'].keys():
+            checkpoint['state_dict'][key] = checkpoint['state_dict'][key].half()
+
     if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
         torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
     else:
@@ -58,7 +78,8 @@ def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']):
 
 def main():
     args = parse_args()
-    process_checkpoint(args.in_file, args.out_file, args.save_keys)
+    process_checkpoint(args.in_file, args.out_file, args.save_keys,
+                       args.float16)
 
 
 if __name__ == '__main__':
```
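Putting the pieces together, here is a self-contained sketch of the publish-with-float16 flow using dummy weights and temporary files (not the real script), which also demonstrates the storage saving the option buys:

```python
import os
import tempfile

import torch

# Dummy float32 state_dict standing in for a trained model.
state_dict = {'w': torch.randn(256, 256)}

with tempfile.TemporaryDirectory() as tmp:
    fp32_path = os.path.join(tmp, 'fp32.pth')
    fp16_path = os.path.join(tmp, 'fp16.pth')

    # Publish without conversion.
    torch.save({'meta': {'float16': False}, 'state_dict': state_dict},
               fp32_path)

    # Publish with the float16 conversion applied to every tensor.
    half_sd = {k: v.half() for k, v in state_dict.items()}
    torch.save({'meta': {'float16': True}, 'state_dict': half_sd}, fp16_path)

    # float16 weights take roughly half the storage of float32 ones.
    assert os.path.getsize(fp16_path) < os.path.getsize(fp32_path)
```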