Skip to content

Commit f41f19b

Browse files
committed
Added.
1 parent 2cc8624 commit f41f19b

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

conversion/torch2oonx.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import torch
2+
import torchio as tio
3+
from torch.utils.data import DataLoader
4+
from torchio.data import GridSampler, GridAggregator
5+
from rhtorch.utilities.config import UserConfig
6+
from rhtorch.utilities.modules import (
7+
recursive_find_python_class,
8+
find_best_checkpoint
9+
)
10+
import numpy as np
11+
from pathlib import Path
12+
import argparse
13+
import nibabel as nib
14+
import sys
15+
from tqdm import tqdm
16+
17+
18+
def infer_data_from_model(model, subject, ps=None, po=None, bs=1, GPU=True):
19+
"""Infer a full volume given a trained model for 1 patient
20+
21+
Args:
22+
model (torch.nn.Module): trained pytorch model
23+
subject (torchio.Subject): Subject instance from TorchIO library
24+
ps (list, optional): Patch size (from config). Defaults to None.
25+
po (int or list, optional): Patch overlap. Defaults to None.
26+
bs (int, optional): batch_size (from_config). Defaults to 1.
27+
28+
Returns:
29+
[np.ndarray]: Full volume inferred from model
30+
"""
31+
grid_sampler = GridSampler(subject, ps, po)
32+
patch_loader = DataLoader(grid_sampler, batch_size=bs)
33+
aggregator = GridAggregator(grid_sampler, overlap_mode='average')
34+
with torch.no_grad():
35+
for patches_batch in patch_loader:
36+
patch_x, _ = model.prepare_batch(patches_batch)
37+
if GPU:
38+
patch_x = patch_x.to('cuda')
39+
locations = patches_batch[tio.LOCATION]
40+
patch_y = model(patch_x)
41+
aggregator.add_batch(patch_y, locations)
42+
return aggregator.get_output_tensor()
43+
44+
45+
if __name__ == '__main__':
46+
parser = argparse.ArgumentParser(
47+
description='Infer new data from input model.')
48+
parser.add_argument("-c", "--config",
49+
help="Config file of saved model",
50+
type=str, default='config.yaml')
51+
parser.add_argument("--checkpoint",
52+
help="Choose specific checkpoint that overwrites the config",
53+
type=str, default=None)
54+
parser.add_argument("-o", "--onnx",
55+
help="Output onnx path",
56+
type=str, default='model.onnx')
57+
58+
args = parser.parse_args()
59+
60+
# load configs in inference mode
61+
user_configs = UserConfig(args, mode='infer')
62+
model_dir = user_configs.rootdir
63+
configs = user_configs.hparams
64+
project_dir = Path(configs['project_dir'])
65+
model_name = configs['model_name']
66+
data_shape_in = configs['data_shape_in']
67+
patch_size = configs['patch_size']
68+
channels_in = data_shape_in[0]
69+
70+
input_sample = torch.randn([1,channels_in,]+patch_size)
71+
72+
# load the model
73+
module_name = recursive_find_python_class(configs['module'])
74+
model = module_name(configs, data_shape_in)
75+
76+
if args.checkpoint is None:
77+
# Load the final (best) model
78+
if 'best_model' in configs:
79+
ckpt_path = Path(configs['best_model'])
80+
epoch_suffix = ''
81+
# Not done training. Load the most recent (best) ckpt
82+
else:
83+
ckpt_path = find_best_checkpoint(project_dir.joinpath('trained_models', model_name, 'checkpoints'))
84+
epoch_suffix = None
85+
else:
86+
ckpt_path = args.checkpoint
87+
ckpt = torch.load(ckpt_path)
88+
model.load_state_dict(ckpt['state_dict'])
89+
90+
# Export the model
91+
torch.onnx.export(model, # model being run
92+
input_sample, # model input (or a tuple for multiple inputs)
93+
args.onnx, # where to save the model (can be a file or file-like object)
94+
export_params=True, # store the trained parameter weights inside the model file
95+
opset_version=10, # the ONNX version to export the model to
96+
do_constant_folding=True, # whether to execute constant folding for optimization
97+
input_names = ['input'], # the model's input names
98+
output_names = ['output'], # the model's output names
99+
dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes
100+
'output' : {0 : 'batch_size'}})

0 commit comments

Comments
 (0)