forked from joewong00/3D-CNN-Segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
88 lines (63 loc) · 3.29 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import torch
import logging
import os
import nibabel as nib
from residual3dunet.model import UNet3D, ResidualUNet3D
# from residual3dunet.res3dunetmodel import ResidualUNet3D
from torch.nn import DataParallel
from utils.segmentation_statistics import SegmentationStatistics
from utils.utils import load_checkpoint, read_data_as_numpy, numpy_to_nii, visualize2d, plot_sidebyside, plot_overlapped, preprocess, predict
def get_args(argv=None):
    """Parse the command-line options for mask prediction.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for programmatic use and testing).

    Returns:
        argparse.Namespace with ``network``, ``model``, ``input``, ``mask``,
        ``viz``, ``no_save``, ``no_cuda`` and ``mask_threshold``.
    """
    # Test settings
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--network', '-u', default='Unet3D',
                        help='Specify the network (Unet3D / ResidualUnet3D)')
    parser.add_argument('--model', '-m', default='model.pt', metavar='FILE',
                        help='Specify the path to the file in which the model is stored (default:model.pt)')
    parser.add_argument('--input', '-i', metavar='INPUT', required=True,
                        help='Path to the image file (format: nii.gz)')
    parser.add_argument('--mask', '-l', metavar='MASK', default=None,
                        help='Path to the ground truth of the input image (if available) (default:None)')
    # ``store_true`` with ``default=True`` can never become False, so pair
    # ``--viz`` with an explicit ``--no-viz`` that writes the same destination.
    parser.add_argument('--viz', '-v', dest='viz', action='store_true', default=True,
                        help='Visualize the output (default:True)')
    parser.add_argument('--no-viz', dest='viz', action='store_false',
                        help='Do not visualize the output')
    parser.add_argument('--no-save', '-n', action='store_true',
                        help='Do not save the output masks')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA testing (default: False)')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white (default: 0.5)')
    return parser.parse_args(argv)
def _build_model(network, device):
    """Create the requested single-channel 3D U-Net variant on *device*.

    Args:
        network: Architecture name, matched case-insensitively against
            ``Unet3D`` and ``ResidualUnet3D``.
        device: ``torch.device`` the model is moved to.

    Returns:
        The model, constructed with ``in_channels=1, out_channels=1, testing=True``.

    Raises:
        ValueError: if *network* names neither supported architecture.
    """
    architectures = {"unet3d": UNet3D, "residualunet3d": ResidualUNet3D}
    key = network.casefold()
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown name fall through silently.
    if key not in architectures:
        raise ValueError('Network must be either (Unet3D / ResidualUnet3D)')
    return architectures[key](in_channels=1, out_channels=1, testing=True).to(device)


def _save_mask(prediction, filename):
    """Write *prediction* as NIfTI to ``output/Mask_<filename>`` and return that path."""
    # makedirs(exist_ok=True) avoids the check-then-create race of exists()+mkdir().
    os.makedirs('output', exist_ok=True)
    out_path = f"output/Mask_{filename}"
    nib.save(numpy_to_nii(prediction), out_path)
    logging.info(f'\nMask saved to {out_path}')
    return out_path


def _evaluate(data, prediction, mask_path):
    """Plot *prediction* against the ground truth at *mask_path* and print statistics."""
    # The ground truth gets the same preprocessing as the input image.
    target = preprocess(read_data_as_numpy(mask_path), rotate=True, to_tensor=False)
    plot_overlapped(data, prediction, target)
    plot_sidebyside(data, prediction, target)
    # NOTE(review): (3, 2, 1) is passed straight to SegmentationStatistics;
    # presumably the voxel spacing — confirm against that class.
    stat = SegmentationStatistics(prediction.astype(bool), target.astype(bool), (3, 2, 1))
    stat.print_table()


def main():
    """Segment one volume, then optionally save, visualize and evaluate the mask.

    Options come from the command line (see ``get_args``). The script:

    * builds the requested network and loads the checkpoint from ``--model``;
    * predicts a mask for ``--input`` using ``--mask-threshold``;
    * writes it to ``output/Mask_<input filename>`` unless ``--no-save``;
    * shows a 2D visualization when ``args.viz`` is true;
    * if ``--mask`` is given, plots prediction vs. ground truth and prints
      segmentation statistics.

    Raises:
        ValueError: if ``--network`` is not Unet3D or ResidualUnet3D.
    """
    # The root logger defaults to WARNING, which hid every logging.info()
    # progress message below; show them on the console.
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    args = get_args()
    filename = os.path.basename(args.input)
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Specify network
    model = _build_model(args.network, device)
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    # If using multiple gpu
    if torch.cuda.device_count() > 1 and use_cuda:
        # NOTE(review): the wrap happens before load_checkpoint, so weights are
        # loaded into the DataParallel wrapper — confirm load_checkpoint copes
        # with the "module." key prefix for checkpoints saved either way.
        model = DataParallel(model)
    load_checkpoint(args.model, model, device=device)
    logging.info('Model loaded!')

    logging.info(f'\nPredicting image {filename} ...')
    data = preprocess(read_data_as_numpy(args.input), rotate=True, to_tensor=False)
    prediction = predict(model, data, args.mask_threshold, device)

    if not args.no_save:
        _save_mask(prediction, filename)

    if args.viz:
        visualize2d(prediction)

    # Evaluation statistics
    if args.mask is not None:
        _evaluate(data, prediction, args.mask)
# Script entry point: run the prediction pipeline only when this file is
# executed directly, never when it is imported as a module.
if __name__ == "__main__":
    main()