Skip to content

Commit 7ed9cee

Browse files
committed
CLI improvements
1 parent 65ec5b3 commit 7ed9cee

File tree

4 files changed

+344
-120
lines changed

4 files changed

+344
-120
lines changed

jstabl/jstabl_control

+216
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
import argparse
5+
import os
6+
from tqdm import tqdm
7+
import SimpleITK as sitk
8+
from urllib.request import urlopen
9+
10+
# pytorch
11+
import torch
12+
import torch.nn
13+
from torch.utils.data import DataLoader
14+
from torchvision.transforms import Compose
15+
16+
# TorchIO
17+
import torchio
18+
from torchio import ImagesDataset, Image, Subject, DATA
19+
from torchio.transforms import (
20+
ZNormalization,
21+
CenterCropOrPad,
22+
ToCanonical,
23+
Resample
24+
)
25+
26+
import jstabl
27+
from jstabl.networks.UNetModalityMeanTogether import Generic_UNet
28+
from jstabl.utilities.sampling import GridSampler, GridAggregator
29+
30+
31+
# Define training and patches sampling parameters
32+
patch_size = (128,128,128)
33+
34+
35+
MODALITIES = ['T1']
36+
37+
38+
def multi_inference(paths_dict,
39+
model,
40+
transformation,
41+
device,
42+
opt):
43+
print("[INFO] Loading model.")
44+
subjects_dataset_inf = ImagesDataset(paths_dict, transform=transformation)
45+
46+
border = (0,0,0)
47+
48+
print("[INFO] Starting Inference.")
49+
for index, batch in enumerate(subjects_dataset_inf):
50+
original_shape = batch['T1'][DATA].shape[1:]
51+
reference = sitk.ReadImage(opt.t1)
52+
53+
new_shape = []
54+
for i,dim in enumerate(original_shape):
55+
new_dim = dim if dim>patch_size[i] else patch_size[i]
56+
new_shape.append(new_dim)
57+
58+
batch_pad = CenterCropOrPad(tuple(new_shape))(batch)
59+
affine_pad = batch_pad['T1']['affine']
60+
61+
62+
data = batch_pad['T1'][DATA]
63+
64+
sampler = GridSampler(data, opt.window_size, border)
65+
aggregator = GridAggregator(data, border)
66+
loader = DataLoader(sampler, batch_size=1)
67+
68+
with torch.no_grad():
69+
for batch_elemt in tqdm(loader):
70+
locations = batch_elemt['location']
71+
input_tensor = {'T1':batch_elemt['image'][:,:1,...].to(device)}
72+
#input_tensor['all'] = batch_elemt['image'].to(device)
73+
labels = 0.0
74+
for fold in range(1,4):
75+
path_model = os.path.join(list(jstabl.__path__)[0], f"./pretrained/glioma_{fold}.pth")
76+
model.load_state_dict(torch.load(path_model, map_location=device))
77+
logits,_ = model(input_tensor)
78+
labels+= torch.nn.Softmax(1)(logits)
79+
labels = labels.argmax(dim=1, keepdim=True)
80+
outputs = labels
81+
aggregator.add_batch(outputs, locations)
82+
83+
output = aggregator.output_array.astype(float)
84+
output = torchio.utils.nib_to_sitk(output, affine_pad)
85+
output = sitk.Resample(
86+
output,
87+
reference,
88+
sitk.Transform(),
89+
sitk.sitkNearestNeighbor,
90+
)
91+
sitk.WriteImage(output, opt.res)
92+
93+
94+
def main():
95+
opt = parsing_data()
96+
97+
if torch.cuda.is_available():
98+
print("[INFO] GPU available.")
99+
else:
100+
print("[INFO] GPU isn't available. Using CPU instead.")
101+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
102+
103+
print("[INFO] Reading data.")
104+
assert os.path.exists(opt.t1), 'T1 scan not found'
105+
106+
if opt.res is None:
107+
opt.res = opt.t1.replace('.nii', '_seg.nii')
108+
109+
output_folder = os.path.dirname(opt.res)
110+
if not os.path.exists(output_folder) and len(output_folder)>0:
111+
os.makedirs(output_folder)
112+
113+
if opt.preprocess:
114+
try:
115+
from MRIPreprocessor.mri_preprocessor import Preprocessor
116+
output_folder = os.path.dirname(opt.res)
117+
ppr = Preprocessor({
118+
'T1':opt.t1},
119+
output_folder=output_folder,
120+
reference='T1')
121+
ppr.run_pipeline()
122+
t1 = os.path.join(output_folder,'skullstripping', 'T1.nii.gz')
123+
124+
except ImportError:
125+
raise ImportError('Please install MRIPreprocessor. Run: ' '\n'
126+
'\t pip install git+https://github.com/ReubenDo/MRIPreprocessor#egg=MRIPreprocessor')
127+
else:
128+
t1 = opt.t1
129+
130+
# filing paths
131+
paths_dict = [Subject(
132+
Image('T1', t1, torchio.INTENSITY))]
133+
134+
135+
transform_inference = (
136+
ToCanonical(),
137+
ZNormalization(),
138+
Resample(1),
139+
)
140+
transform_inference = Compose(transform_inference)
141+
142+
# MODEL
143+
norm_op_kwargs = {'eps': 1e-5, 'affine': True}
144+
net_nonlin = torch.nn.LeakyReLU
145+
net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
146+
147+
print("[INFO] Building model.")
148+
model= Generic_UNet(input_modalities=['T1', 'all'],
149+
base_num_features=32,
150+
num_classes=10,
151+
num_pool=4,
152+
num_conv_per_stage=2,
153+
feat_map_mul_on_downscale=2,
154+
conv_op=torch.nn.Conv3d,
155+
norm_op=torch.nn.InstanceNorm3d,
156+
norm_op_kwargs=norm_op_kwargs,
157+
nonlin=net_nonlin,
158+
nonlin_kwargs=net_nonlin_kwargs,
159+
convolutional_pooling=False,
160+
convolutional_upsampling=False,
161+
final_nonlin=lambda x: x,
162+
input_features={'T1':1, 'all':4})
163+
164+
for fold in range(1,4):
165+
path_model = os.path.join(list(jstabl.__path__)[0], f"./pretrained/glioma_{fold}.pth")
166+
if not os.path.isfile(path_model):
167+
url = f"https://zenodo.org/record/4040853/files/glioma_{fold}.pth?download=1"
168+
print("Downloading", url, "...")
169+
data = urlopen(url).read()
170+
with open(path_model, 'wb') as f:
171+
f.write(data)
172+
173+
174+
model.to(device)
175+
model.eval()
176+
177+
multi_inference(paths_dict, model, transform_inference, device, opt)
178+
print(f"[INFO] Inference done. Segmentation saved here: {opt.res}")
179+
print("Have a good day!")
180+
181+
182+
def parsing_data():
183+
parser = argparse.ArgumentParser(
184+
description='Joint Tissue and Glioma segmentation designed for controls')
185+
186+
187+
parser.add_argument('-t1',
188+
type=str,
189+
required=True,
190+
help='Filename of the T1 scan')
191+
192+
parser.add_argument('-res',
193+
type=str,
194+
default=None,
195+
help='Filename of the OUTPUT segmentation')
196+
197+
parser.add_argument('--preprocess',
198+
action='store_true',
199+
help='Tag to use for preprocessing the data (co-registration + skull-stripping)')
200+
201+
202+
parser.add_argument('--window_size',
203+
type=tuple,
204+
default=patch_size,
205+
help='Patch size for patch-based inference')
206+
207+
opt = parser.parse_args()
208+
209+
return opt
210+
211+
212+
if __name__ == '__main__':
213+
main()
214+
215+
216+

0 commit comments

Comments
 (0)