-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUNetInferenceAgent.py
69 lines (48 loc) · 1.9 KB
/
UNetInferenceAgent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Contains class that runs inferencing
"""
import torch
import numpy as np
from networks.RecursiveUNet import UNet
from utils.utils import med_reshape
class UNetInferenceAgent:
    """
    Stores model and parameters and some methods to handle inferencing
    """

    def __init__(self, parameter_file_path='', model=None, device="cpu", patch_size=64):
        """
        Arguments:
            parameter_file_path {str} -- path to a saved state dict; it is loaded
                into the model only when the string is non-empty
            model {torch.nn.Module} -- model to run; when None, a
                UNet(num_classes=3) is constructed
            device {str or torch.device} -- device the model is moved to and
                that input slices are sent to
            patch_size {int} -- in-plane size that
                single_volume_inference_unpadded pads volumes up to
        """
        self.model = model
        self.patch_size = patch_size
        self.device = device

        if model is None:
            self.model = UNet(num_classes=3)

        if parameter_file_path:
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints.
            self.model.load_state_dict(torch.load(parameter_file_path, map_location=self.device))

        self.model.to(device)

    def single_volume_inference_unpadded(self, volume):
        """
        Runs inference on a single volume of arbitrary patch size,
        padding it to the conformant size first

        Axes 1 and 2 are zero-padded at their end up to ``self.patch_size``.
        Axes that are already at least that large are left unchanged. The
        padded volume goes through single_volume_inference, and the prediction
        is cropped back to the input's shape.

        Arguments:
            volume {Numpy array} -- 3D array representing the volume

        Returns:
            3D NumPy array with prediction mask, same shape as ``volume``

        Raises:
            ValueError -- if ``volume`` is not 3-dimensional
        """
        volume = np.asarray(volume)
        if volume.ndim != 3:
            raise ValueError(f"expected a 3D volume, got shape {volume.shape}")

        # Axis 0 is the slicing axis and is never padded. Only grow axes 1 and 2;
        # never shrink them.
        pad_width = [(0, 0)] + [
            (0, max(0, self.patch_size - size)) for size in volume.shape[1:]
        ]
        padded = np.pad(volume, pad_width, mode="constant")

        prediction = self.single_volume_inference(padded)
        # Padding was appended at the end, so the original data is the leading corner.
        return prediction[:, :volume.shape[1], :volume.shape[2]]

    def single_volume_inference(self, volume):
        """
        Runs inference on a single volume of conformant patch size

        Each slice along axis 0 is scaled by its own maximum and fed to the
        model as a (1, 1, H, W) float32 tensor. The slice's mask is the argmax
        over the model output's channel axis.

        Arguments:
            volume {Numpy array} -- 3D array representing the volume

        Returns:
            3D NumPy array with prediction mask (float64, same shape as
            ``volume``)
        """
        volume = np.asarray(volume)
        self.model.eval()

        prediction = np.zeros(volume.shape)

        # Inference only: without no_grad, every forward pass would build and
        # keep an autograd graph.
        with torch.no_grad():
            for idx, slc in enumerate(volume):
                slc = slc.astype(np.float32)
                max_val = slc.max()
                # Guard against 0/0 on empty (all-zero) slices. Dividing by a
                # float32 scalar also keeps the input float32 under NumPy 2
                # promotion rules.
                if max_val != 0:
                    slc = slc / max_val

                batch = torch.from_numpy(slc)[None, None].to(self.device)
                logits = self.model(batch)
                # NOTE(review): assumes the model returns (1, num_classes, H, W);
                # [0] drops only the batch axis.
                prediction[idx] = torch.argmax(logits[0], dim=0).cpu().numpy()

        return prediction