-
Notifications
You must be signed in to change notification settings - Fork 9
/
evaluate
executable file
·83 lines (67 loc) · 2.75 KB
/
evaluate
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python
"""
Script to evaluate a custom trained TopoFit model. If this code is
useful to you, please cite:
TopoFit: Rapid Reconstruction of Topologically-Correct Cortical Surfaces
Andrew Hoopes, Juan Eugenio Iglesias, Bruce Fischl, Douglas Greve, Adrian Dalca
Medical Imaging with Deep Learning. 2022.
"""
import os
import argparse
import numpy as np
import surfa as sf
import torch
import topofit
# Command-line interface for the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('--subjs', nargs='+', required=True, help='subject(s) to evaluate')
parser.add_argument('--hemi', required=True, help='hemisphere to evaluate (`lh` or `rh`)')
parser.add_argument('--model', required=True, help='model file (.pt) to load')
# The output surface is written as <subj>/surf/<hemi>.white.<suffix>.
parser.add_argument('--suffix', default='topofit',
                    help='suffix of the output surface filename, <hemi>.white.<suffix> (default is topofit)')
parser.add_argument('--gpu', default='0', help='GPU device ID (default is 0)')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU')
parser.add_argument('--vol', help='Input volume (norm.mgz)', default='norm.mgz')
parser.add_argument('--xhemi', action='store_true', help='Xhemi')
args = parser.parse_args()

# Echo the input options that are forwarded to the subject data loader.
print(f'Input volume is {args.vol}')
print(f'Xhemi {args.xhemi}')
# Sanity check on inputs: only the left ('lh') or right ('rh') hemisphere is supported.
if args.hemi not in ('lh', 'rh'):
    # Raising SystemExit with a message prints it to stderr and exits with
    # status 1. Unlike the bare `exit()` builtin, it does not depend on the
    # `site` module being loaded.
    raise SystemExit("error: hemi must be 'lh' or 'rh'")
# Configure the compute device. CUDA_VISIBLE_DEVICES only takes effect if
# CUDA has not yet been initialized in this process, so it is set before any
# CUDA query below.
if args.cpu:
    # Hide every GPU so nothing can accidentally run on CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    device = torch.device('cpu')
else:
    # NOTE(review): PyTorch's reproducibility notes recommend benchmark=False
    # when deterministic results are wanted; with benchmark=True cuDNN may pick
    # different algorithms between runs. Kept as-is to preserve behavior.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    # Restrict torch to the requested GPU; it then appears as device 'cuda'.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Fail early with an actionable message instead of an opaque CUDA error
    # later on when the model is moved to the device.
    if not torch.cuda.is_available():
        raise SystemExit(f'error: no usable CUDA device for --gpu {args.gpu} (use --cpu to run on CPU)')
    device = torch.device('cuda')
topofit.utils.set_device(device)
# Build the network and move its parameters to the selected device.
print('Configuring model')
model = topofit.model.SurfNet().to(device)

# Restore the trained weights. map_location remaps tensors saved on another
# device (e.g. a GPU) onto the device chosen above.
# NOTE(review): depending on the torch version's `weights_only` default,
# torch.load may unpickle arbitrary objects, so only load trusted checkpoints.
print(f'Loading model weights from {args.model}')
checkpoint = torch.load(args.model, map_location=device)
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)

# Switch to inference behaviour; eval() is shorthand for train(mode=False).
model.eval()
# Evaluation loop: predict and save a white-matter surface for each subject.
for subj in args.subjs:
    # load subject data
    data = topofit.io.load_subject_data(subj, args.hemi, vol=args.vol, xhemi=args.xhemi)

    # predict surface; gradients are not needed at inference time
    with torch.no_grad():
        input_image = data['input_image'].to(device)
        input_vertices = data['input_vertices'].to(device)
        result, topology = model(input_image, input_vertices)
        vertices = result['pred_vertices'].cpu().numpy()
        faces = topology['faces'].cpu().numpy()

    # build the mesh in the cropped voxel space, then convert it to the
    # geometry of the original input volume
    surf = sf.Mesh(vertices, faces, space='vox', geometry=data['cropped_geometry'])
    surf = surf.convert(geometry=data['input_geometry'])

    # write surface to <subj>/surf/<hemi>.white.<suffix>, creating the
    # surf directory first in case the subject does not have one yet
    filename = os.path.join(subj, 'surf', f'{args.hemi}.white.{args.suffix}')
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    surf.save(filename)
    print(f'Saved white-matter surface to {filename}')