forked from zhaofenqiang/Spherical_U-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_3hg.py
100 lines (80 loc) · 4.05 KB
/
predict_3hg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
'''
Author: HenryVarro666 1504517223@qq.com
Date: 2024-06-12 14:40:18
LastEditors: HenryVarro666 1504517223@qq.com
LastEditTime: 2024-07-25 16:29:02
FilePath: \Spherical_U-Net\predict.py
'''
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import argparse
import numpy as np
import os
from model import Unet_40k, Unet_160k
from sphericalunet.utils.vtk import read_vtk, write_vtk, resample_label
from sphericalunet.utils.interp_numpy import resampleSphereSurf
from torch.nn.functional import sigmoid
def inference(curv, sulc, model, device):
    """Run the network on per-vertex curvature and sulcal-depth features.

    The two feature tensors are concatenated along dim 1 and divided by fixed
    per-channel constants. A leading batch dimension is added, and the model
    runs without gradient tracking. Its output is passed through a sigmoid.

    Args:
        curv: curvature tensor; the script in this file passes shape (N, 1).
        sulc: sulcal-depth tensor with the same shape as ``curv``.
        model: callable network taking the normalised, batched features.
        device: torch device on which the normalisation constants are created.

    Returns:
        numpy array on the CPU holding the sigmoid of the model output.
        The shape is whatever the model returns for a batch of one.
    """
    feats = torch.cat((curv, sulc), 1)
    # Fixed per-channel scale constants for (curv, sulc). Presumably these
    # match the normalisation used at training time; TODO confirm.
    feat_max = torch.tensor([1.2, 13.7], device=device)
    # Cast to the default float dtype (the dtype of feat_max, and of freshly
    # constructed model weights). Otherwise float64 inputs, e.g. from numpy
    # resampling, would promote the division to float64 and clash with the
    # float32 weights.
    feats = feats.to(feat_max.dtype) / feat_max
    with torch.no_grad():
        prediction = model(feats.unsqueeze(0))  # add a leading batch dimension
        prediction = torch.sigmoid(prediction)
    return prediction.cpu().numpy()
def _default_output_path(in_file):
    """Derive the default output filename from the input filename.

    Only a trailing ``.vtk`` extension is replaced by ``.parc.vtk``. Any other
    name gets ``.parc.vtk`` appended. This way the default output can never be
    the input file itself, and ``.vtk`` inside directory names is left alone.
    """
    suffix = '.vtk'
    if in_file.endswith(suffix):
        return in_file[:-len(suffix)] + '.parc.vtk'
    return in_file + '.parc.vtk'


def main():
    """Command-line entry point: predict a gyral-net map for one surface.

    Steps:
    1. Read the input VTK surface, which must carry 'curv' and 'sulc' arrays.
    2. If the surface does not have the template's vertex count, resample the
       features onto the template sphere.
    3. Run the network and map the result back onto the input vertices.
    4. Store the binary map ('gyralnet_prediction') and the sigmoid
       probability ('gyralnet_prediction_prob') on the surface and write it out.
    """
    parser = argparse.ArgumentParser(description='Predict the binary gyral-net map from the curv and sulc features of the input surface',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # NOTE(review): --hemisphere is still accepted for CLI compatibility, but
    # nothing below uses it.
    parser.add_argument('--hemisphere', '-hemi', default='left',
                        choices=['left', 'right'],
                        help="Specify the hemisphere for parcellation, left or right.")
    parser.add_argument('--level', '-l', default='7',
                        choices=['7', '8'],
                        help="Specify the level of the surfaces' resolution. Generally, level 7 with 40962 vertices is sufficient, level 8 with 163842 vertices is more accurate but slower.")
    parser.add_argument('--input', '-i', metavar='INPUT',
                        help='filename of input surface')
    parser.add_argument('--output', '-o', default='[input].parc.vtk', metavar='OUTPUT',
                        help='Filename of output surface.')
    parser.add_argument('--device', default='GPU', choices=['GPU', 'CPU'],
                        help='the device for running the model.')
    args = parser.parse_args()

    in_file = args.input
    if not in_file:
        parser.error('Input filename is required')
    out_file = args.output
    if out_file == '[input].parc.vtk':
        out_file = _default_output_path(in_file)
    level = args.level
    device = torch.device('cuda:0' if args.device == 'GPU' else 'cpu')

    # (2, 1): presumably two input channels (curv, sulc) and one output map.
    model = Unet_40k(2, 1) if level == '7' else Unet_160k(2, 1)
    model_path = f'trained_models_4/Unet_{"40k_1_final.pkl" if level == "7" else "160k_curv_sulc.pkl"}'
    n_vertices = 40962 if level == '7' else 163842
    model.to(device)
    # map_location lets weights saved from a GPU run be loaded with --device CPU.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    template = read_vtk(f'neigh_indices/sphere_{n_vertices}_rotated_0.vtk')
    orig_surf = read_vtk(in_file)
    missing = [key for key in ('curv', 'sulc') if key not in orig_surf]
    if missing:
        raise KeyError(f'{in_file} is missing required per-vertex arrays: {missing}')

    needs_resample = len(orig_surf['curv']) != n_vertices
    if needs_resample:
        # Interpolate (sulc, curv) from the input sphere onto the template sphere.
        sucu = resampleSphereSurf(orig_surf['vertices'], template['vertices'],
                                  np.concatenate((orig_surf['sulc'][:, np.newaxis],
                                                  orig_surf['curv'][:, np.newaxis]), axis=1))
        sulc, curv = sucu[:, 0], sucu[:, 1]
    else:
        curv, sulc = orig_surf['curv'][:n_vertices], orig_surf['sulc'][:n_vertices]

    # Contiguous float32 columns of shape (N, 1). Resampled features may come
    # back as strided float64 views.
    curv = torch.from_numpy(np.ascontiguousarray(curv, dtype=np.float32)).unsqueeze(1).to(device)
    sulc = torch.from_numpy(np.ascontiguousarray(sulc, dtype=np.float32)).unsqueeze(1).to(device)

    pred = inference(curv, sulc, model, device)
    # Drop the singleton batch/channel dimensions -> one value per vertex.
    pred_prob = pred.squeeze()
    if needs_resample:
        # The network ran on the template mesh. Map the probabilities back onto
        # the input mesh so the stored arrays match orig_surf's vertex count.
        # NOTE(review): assumes resample_label(fix_vertices, target_vertices,
        # values) copies values from the nearest fix vertex; confirm against
        # sphericalunet.utils.vtk.
        pred_prob = np.asarray(resample_label(template['vertices'], orig_surf['vertices'],
                                              pred_prob)).reshape(-1)
    pred = np.array(pred_prob > 0.5, dtype=np.int32)

    orig_surf['gyralnet_prediction'] = pred
    orig_surf['gyralnet_prediction_prob'] = pred_prob
    write_vtk(orig_surf, out_file)


if __name__ == "__main__":
    main()