-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen_masks.py
121 lines (86 loc) · 3.09 KB
/
gen_masks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
from tqdm import tqdm
import numpy as np
from scipy import ndimage
import torch
from torch.utils.data import DataLoader
from data.dataset import ISPRS_Dataset
from models.Unet import UNET
# Output locations for the generated masks and the checkpoint used to produce them.
configs = dict(
    path_to_weakly_masks='data/preprocessed/weakly_masks',
    path_to_weak_mask_erosion='data/preprocessed/weakly_masks_erosion',
    model_path='checkpoints/baseline_Unet.pth',
)
def makedirs(dirs: str):
    """
    Creates a directory tree (including intermediate directories) if it does not exist yet.

    Parameters:
        dirs: directory tree to be created
    Returns:
        None
    Raises:
        FileExistsError: if ``dirs`` already exists but is not a directory.
    """
    # exist_ok=True replaces the racy "os.path.exists() then os.makedirs()" pattern:
    # another process creating the directory in between no longer causes an error.
    os.makedirs(dirs, exist_ok=True)
def erosion(img: np.ndarray, dilation_iterations: int = 10,
            erosion_iterations: int = 20, num_channels: int = 4) -> np.ndarray:
    """
    Corrects borders of masks by marking an uncertain band around each object with -1.

    For each of the first ``num_channels`` channels the mask is first binary-dilated
    ``dilation_iterations`` times, then the dilated mask is binary-eroded
    ``erosion_iterations`` times. In the result:
        1  -> pixel survives the erosion (confident foreground)
        -1 -> pixel is in the dilated mask but not in the eroded one (border band)
        0  -> pixel is outside the dilated mask
    Channels at index >= ``num_channels`` are returned unchanged.

    Parameters:
        img: input mask, indexed by channel along axis 0; not modified
        dilation_iterations: number of dilation steps (default 10)
        erosion_iterations: number of erosion steps applied to the dilated mask (default 20)
        num_channels: how many leading channels to process (default 4)
    Returns:
        Mask with corrected borders, same shape and dtype as ``img``.
    """
    # NOTE(review): the model in this script outputs 5 channels but only the first 4
    # are processed by default; presumably the last one is intentionally excluded — confirm.
    # NOTE: scipy treats iterations < 1 as "repeat until the result stops changing".
    img = img.copy()
    for channel in range(num_channels):
        img[channel] = ndimage.binary_dilation(
            img[channel], iterations=dilation_iterations).astype(int)
    dilated = img.copy()
    for channel in range(num_channels):
        img[channel] = ndimage.binary_erosion(
            img[channel], iterations=erosion_iterations).astype(int)
    # eroded - (dilated - eroded): 1 inside the eroded core, -1 on the removed band.
    return img - (dilated - img)
def get_masks(model, dataloader, device, path_to_weakly_masks, path_to_weak_mask_erosion):
    """
    Generates masks for the given dataloader and corrects them with image-level labels.

    For every sample the model output is thresholded at 0 (positive -> 1, otherwise 0),
    entries whose image-level label equals 0 are zeroed, and two files are written
    with ``np.save`` under the sample's ``file_name``: the corrected mask and its
    ``erosion`` counterpart.

    Parameters:
        model: model that produces masks; its raw output is thresholded at 0
        dataloader: yields dict batches with keys 'file_name', 'img' and 'label'
        device: device the images are moved to before inference
        path_to_weakly_masks: directory for the thresholded, label-corrected masks
        path_to_weak_mask_erosion: directory for the masks after ``erosion``
    Returns:
        None
    """
    makedirs(path_to_weakly_masks)
    makedirs(path_to_weak_mask_erosion)
    model.eval()
    for batch in tqdm(dataloader, desc='Gen masks'):
        file_names = batch['file_name']
        imgs = batch['img'].to(device=device, dtype=torch.float32)
        # Convert explicitly: indexing a NumPy array with a torch tensor depends on
        # implicit conversion whose index interpretation has changed between NumPy
        # versions. Assumes labels live on the CPU (as yielded by the dataloader).
        img_label = np.asarray(batch['label'])
        with torch.no_grad():
            logits = model(imgs).cpu().numpy()
        # Positive logits become 1, everything else (including NaN) becomes 0.
        pred_masks = (logits > 0).astype(logits.dtype)
        # Correct the mask with the image-level labels.
        # NOTE(review): presumably labels are (batch, classes) so this zeroes whole
        # channels of classes absent from the image — confirm against ISPRS_Dataset.
        pred_masks[img_label == 0] = 0
        for file_name, mask in zip(file_names, pred_masks):
            np.save(os.path.join(path_to_weakly_masks, file_name), mask)
            np.save(os.path.join(path_to_weak_mask_erosion, file_name), erosion(mask))
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNET(in_channels=3, out_channels=5)
    # map_location is required for the CPU fallback above: without it, a checkpoint
    # saved from a GPU run fails to deserialize on a machine without CUDA.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(configs['model_path'], map_location=device))
    model.to(device)
    dataset = ISPRS_Dataset('data/preprocessed', 'data/preprocessed/metadata.csv', 'weak_train')
    # shuffle=False keeps a deterministic processing order over the dataset.
    dataloader = DataLoader(dataset, shuffle=False, batch_size=3)
    get_masks(
        model=model,
        dataloader=dataloader,
        device=device,
        path_to_weakly_masks=configs['path_to_weakly_masks'],
        path_to_weak_mask_erosion=configs['path_to_weak_mask_erosion']
    )