1
+ import torch
2
+ import torchio as tio
3
+ from torch .utils .data import DataLoader
4
+ from torchio .data import GridSampler , GridAggregator
5
+ from rhtorch .utilities .config import UserConfig
6
+ from rhtorch .utilities .modules import (
7
+ recursive_find_python_class ,
8
+ find_best_checkpoint
9
+ )
10
+ import numpy as np
11
+ from pathlib import Path
12
+ import argparse
13
+ import nibabel as nib
14
+ import sys
15
+ from tqdm import tqdm
16
+
17
+
18
+ def infer_data_from_model (model , subject , ps = None , po = None , bs = 1 , GPU = True ):
19
+ """Infer a full volume given a trained model for 1 patient
20
+
21
+ Args:
22
+ model (torch.nn.Module): trained pytorch model
23
+ subject (torchio.Subject): Subject instance from TorchIO library
24
+ ps (list, optional): Patch size (from config). Defaults to None.
25
+ po (int or list, optional): Patch overlap. Defaults to None.
26
+ bs (int, optional): batch_size (from_config). Defaults to 1.
27
+
28
+ Returns:
29
+ [np.ndarray]: Full volume inferred from model
30
+ """
31
+ grid_sampler = GridSampler (subject , ps , po )
32
+ patch_loader = DataLoader (grid_sampler , batch_size = bs )
33
+ aggregator = GridAggregator (grid_sampler , overlap_mode = 'average' )
34
+ with torch .no_grad ():
35
+ for patches_batch in patch_loader :
36
+ patch_x , _ = model .prepare_batch (patches_batch )
37
+ if GPU :
38
+ patch_x = patch_x .to ('cuda' )
39
+ locations = patches_batch [tio .LOCATION ]
40
+ patch_y = model (patch_x )
41
+ aggregator .add_batch (patch_y , locations )
42
+ return aggregator .get_output_tensor ()
43
+
44
+
45
+ if __name__ == '__main__' :
46
+ parser = argparse .ArgumentParser (
47
+ description = 'Infer new data from input model.' )
48
+ parser .add_argument ("-c" , "--config" ,
49
+ help = "Config file of saved model" ,
50
+ type = str , default = 'config.yaml' )
51
+ parser .add_argument ("--checkpoint" ,
52
+ help = "Choose specific checkpoint that overwrites the config" ,
53
+ type = str , default = None )
54
+ parser .add_argument ("-o" , "--onnx" ,
55
+ help = "Output onnx path" ,
56
+ type = str , default = 'model.onnx' )
57
+
58
+ args = parser .parse_args ()
59
+
60
+ # load configs in inference mode
61
+ user_configs = UserConfig (args , mode = 'infer' )
62
+ model_dir = user_configs .rootdir
63
+ configs = user_configs .hparams
64
+ project_dir = Path (configs ['project_dir' ])
65
+ model_name = configs ['model_name' ]
66
+ data_shape_in = configs ['data_shape_in' ]
67
+ patch_size = configs ['patch_size' ]
68
+ channels_in = data_shape_in [0 ]
69
+
70
+ input_sample = torch .randn ([1 ,channels_in ,]+ patch_size )
71
+
72
+ # load the model
73
+ module_name = recursive_find_python_class (configs ['module' ])
74
+ model = module_name (configs , data_shape_in )
75
+
76
+ if args .checkpoint is None :
77
+ # Load the final (best) model
78
+ if 'best_model' in configs :
79
+ ckpt_path = Path (configs ['best_model' ])
80
+ epoch_suffix = ''
81
+ # Not done training. Load the most recent (best) ckpt
82
+ else :
83
+ ckpt_path = find_best_checkpoint (project_dir .joinpath ('trained_models' , model_name , 'checkpoints' ))
84
+ epoch_suffix = None
85
+ else :
86
+ ckpt_path = args .checkpoint
87
+ ckpt = torch .load (ckpt_path )
88
+ model .load_state_dict (ckpt ['state_dict' ])
89
+
90
+ # Export the model
91
+ torch .onnx .export (model , # model being run
92
+ input_sample , # model input (or a tuple for multiple inputs)
93
+ args .onnx , # where to save the model (can be a file or file-like object)
94
+ export_params = True , # store the trained parameter weights inside the model file
95
+ opset_version = 10 , # the ONNX version to export the model to
96
+ do_constant_folding = True , # whether to execute constant folding for optimization
97
+ input_names = ['input' ], # the model's input names
98
+ output_names = ['output' ], # the model's output names
99
+ dynamic_axes = {'input' : {0 : 'batch_size' }, # variable length axes
100
+ 'output' : {0 : 'batch_size' }})
0 commit comments