-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_model.py
159 lines (136 loc) · 5.82 KB
/
run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import inspect
import pathlib
import sys
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader

from common.args import Args
from common.subsample import MaskFunc
from common.utils import save_reconstructions
from data import transforms
from data.mri_data import SliceData
from models.unet.unet_model import UnetModel, CustomUnetModel
class DataTransform:
    """
    Data Transformer for running U-Net models on a test dataset.

    Each call converts one raw k-space slice into a normalized, clamped,
    zero-filled magnitude image that can be fed to the network.
    """

    def __init__(self, resolution, which_challenge, mask_func=None):
        """
        Args:
            resolution (int): Side length of the square center crop applied to
                the zero-filled image.
            which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset.
            mask_func (common.subsample.MaskFunc, optional): Callable that builds an
                undersampling mask of the appropriate shape. When None, the
                k-space is used unmasked.

        Raises:
            ValueError: If ``which_challenge`` is not one of the two supported values.
        """
        supported = ('singlecoil', 'multicoil')
        if which_challenge not in supported:
            raise ValueError('Challenge should either be "singlecoil" or "multicoil"')
        self.resolution = resolution
        self.which_challenge = which_challenge
        self.mask_func = mask_func

    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.ndarray): k-space measurements.
            target: Target image; not used by this transform.
            attrs (dict): Acquisition metadata from the HDF5 file; not used by
                this transform.
            fname (str): Name of the input file. Must be a string: its
                characters are converted with ``ord`` to build the mask seed.
            slice (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Normalized zero-filled input image, clamped to [-6, 6]
                mean (float): Mean of the zero-filled image
                std (float): Standard deviation of the zero-filled image
                fname (str): Name of the input file (passed through)
                slice (int): Serial number of the slice (passed through)
        """
        kspace_tensor = transforms.to_tensor(kspace)
        if self.mask_func is None:
            masked = kspace_tensor
        else:
            # The seed depends only on the file name, so every slice of the
            # same file is masked with the same seed.
            file_seed = tuple(map(ord, fname))
            masked, _ = transforms.apply_mask(kspace_tensor, self.mask_func, file_seed)

        # Zero-filled reconstruction: inverse FFT, center crop, then magnitude.
        crop_shape = (self.resolution, self.resolution)
        spatial = transforms.ifft2(masked)
        cropped = transforms.complex_center_crop(spatial, crop_shape)
        magnitude = transforms.complex_abs(cropped)

        if self.which_challenge == 'multicoil':
            # Combine the per-coil images with root-sum-of-squares.
            magnitude = transforms.root_sum_of_squares(magnitude)

        normalized, mean, std = transforms.normalize_instance(magnitude)
        # Bound outliers of the normalized image to the range [-6, 6].
        return normalized.clamp(-6, 6), mean, std, fname, slice
def create_data_loaders(args):
    """Build the inference DataLoader for the requested challenge and split.

    Args:
        args: Parsed command-line arguments. Reads ``mask_kspace``,
            ``center_fractions`` and ``accelerations`` (the latter two only
            when masking), ``data_path``, ``challenge``, ``data_split``,
            ``resolution`` and ``batch_size``.

    Returns:
        torch.utils.data.DataLoader: Loader over the slices found under
        ``<data_path>/<challenge>_<data_split>``, without shuffling.
    """
    # Masking is optional: with --mask-kspace unset, the raw k-space is used.
    mask_func = (
        MaskFunc(args.center_fractions, args.accelerations)
        if args.mask_kspace
        else None
    )
    split_dir = args.data_path / f'{args.challenge}_{args.data_split}'
    dataset = SliceData(
        root=split_dir,
        transform=DataTransform(args.resolution, args.challenge, mask_func),
        sample_rate=1.,
        challenge=args.challenge,
    )
    return DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
    )
def load_model(checkpoint_file, device=None):
    """Rebuild the U-Net stored in a training checkpoint.

    Args:
        checkpoint_file (pathlib.Path or str): Checkpoint holding the training
            argument namespace under ``'args'`` and the model state dict under
            ``'model'``.
        device (str or torch.device, optional): Device to place the model on.
            Defaults to the ``device`` recorded in the checkpoint's training
            args, which matches the previous behaviour.

    Returns:
        torch.nn.Module: A ``CustomUnetModel`` with the saved weights loaded,
        wrapped in ``torch.nn.DataParallel`` when the checkpoint's
        ``data_parallel`` flag is set.
    """
    # Load tensors onto the CPU first so a checkpoint saved on a GPU can still
    # be opened on a host without that GPU. load_state_dict below copies the
    # weights onto whatever device the model lives on.
    load_kwargs = {'map_location': 'cpu'}
    # torch>=2.6 defaults to weights_only=True, which refuses to unpickle the
    # argument namespace stored under 'args'. Only pass the flag when this
    # torch version supports it.
    # NOTE(security): this unpickles arbitrary objects; only load trusted checkpoints.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(checkpoint_file, **load_kwargs)
    args = checkpoint['args']
    if device is None:
        device = args.device

    # The checkpoint must match CustomUnetModel's architecture, because
    # load_state_dict is strict by default.
    model = CustomUnetModel(1, 1, args.num_chans, args.num_pools, args.drop_prob).to(device)
    print(model)
    if args.data_parallel:
        # Wrap before loading, presumably because the weights were saved from
        # the wrapped model and so carry its 'module.' key prefix.
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])
    return model
def run_unet(args, model, data_loader):
    """Run the residual U-Net over every batch and regroup outputs per file.

    Args:
        args: Namespace-like object. Only ``args.device`` is read; it sets
            where input batches are sent.
        model (torch.nn.Module): Network whose output is treated as a residual
            and added back onto its (normalized) input.
        data_loader (iterable): Yields ``(image, mean, std, fnames, slices)``
            batches, in the order produced by ``DataTransform``.

    Returns:
        dict: Maps each file name to a numpy array stacking that file's
        de-normalized reconstructions in ascending slice order.
    """
    model.eval()
    reconstructions = defaultdict(list)
    with torch.no_grad():
        for image, mean, std, fnames, slices in data_loader:
            image = image.unsqueeze(1).to(args.device)
            # Residual learning: the network predicts a correction to its input.
            recons = (model(image) + image).squeeze(1).to('cpu')
            for i in range(recons.shape[0]):
                # Undo the per-instance normalization applied to the input.
                recons[i] = recons[i] * std[i] + mean[i]
                reconstructions[fnames[i]].append((slices[i].numpy(), recons[i].numpy()))

    # Sort on the slice index only. Comparing whole (slice, ndarray) tuples
    # falls through to comparing arrays when two slice indices tie, and that
    # raises ValueError.
    return {
        fname: np.stack([pred for _, pred in sorted(slice_preds, key=lambda item: item[0])])
        for fname, slice_preds in reconstructions.items()
    }
def main(args):
    """Reconstruct every slice of the requested split and save the results.

    Args:
        args: Parsed command-line arguments (see ``create_arg_parser``).
            ``args.device`` selects where both the model and the inputs run.
    """
    data_loader = create_data_loaders(args)
    # load_model places the network on the device recorded in the checkpoint,
    # while run_unet sends inputs to ``args.device``. Move the model to the
    # command-line device so the two always agree.
    model = load_model(args.checkpoint).to(args.device)
    reconstructions = run_unet(args, model, data_loader)
    save_reconstructions(reconstructions, args.out_dir)
def create_arg_parser():
    """Build the command-line parser for running a trained U-Net.

    Extends the project's shared ``Args`` parser, which presumably supplies the
    common options read elsewhere in this script (``data_path``,
    ``challenge``, ``resolution``, ...), with inference-specific flags.

    Returns:
        Args: Parser ready for ``parse_args``.
    """
    parser = Args()
    parser.add_argument('--mask-kspace', action='store_true',
                        help='Whether to apply a mask (set to True for val data and False '
                             'for test data)')
    parser.add_argument('--data-split', choices=['val', 'test_v2', 'challenge'], required=True,
                        help='Which data partition to run on: "val", "test_v2" or "challenge"')
    parser.add_argument('--checkpoint', type=pathlib.Path, required=True,
                        help='Path to the U-Net model')
    parser.add_argument('--out-dir', type=pathlib.Path, required=True,
                        help='Path to save the reconstructions to')
    parser.add_argument('--batch-size', default=16, type=int, help='Mini-batch size')
    parser.add_argument('--device', type=str, default='cuda', help='Which device to run on')
    return parser
if __name__ == '__main__':
    # Script entry point: parse the command line (minus the program name) and run.
    cli_args = create_arg_parser().parse_args(sys.argv[1:])
    main(cli_args)