-
Notifications
You must be signed in to change notification settings - Fork 37
/
datasets.py
111 lines (71 loc) · 3.29 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import cv2
import numpy as np
import os.path as osp
import scipy.io as sio
from torch.utils.data import Dataset
# Per-channel means subtracted from every image in FluxSegmentationDataset.
# Images are loaded with cv2.imread and the subtraction broadcasts over the
# last (channel) axis, so the entries follow OpenCV's BGR channel order.
# NOTE(review): these look like the standard ImageNet (VGG/Caffe) BGR means — confirm.
IMAGE_MEAN = np.asarray((103.939, 116.779, 123.675), dtype=np.float32)
class FluxSegmentationDataset(Dataset):
    """Segmentation dataset that also yields a per-pixel direction ("flux") field.

    Image names are read one per line from ``datasets/<dataset>/<mode>.txt``.
    Each item loads ``datasets/<dataset>/images/<name>.jpg`` and the matching
    label map from ``datasets/<dataset>/labels/``. For every labelled region
    it then computes two things:

    * a vector from each region pixel's nearest outside pixel to the pixel
      itself;
    * a weight of ``1 / sqrt(region area)`` for every pixel of the region.

    Args:
        dataset: ``'PascalContext'`` (labels in ``<name>.mat`` under the key
            ``'LabelMap'``) or ``'BSDS500'`` (labels as grayscale ``<name>.png``).
        mode: split name, used to pick the ``<mode>.txt`` list file.
    """

    # Datasets whose label format __getitem__ knows how to load.
    _SUPPORTED_DATASETS = ('PascalContext', 'BSDS500')

    def __init__(self, dataset='PascalContext', mode='train'):
        self.dataset = dataset
        self.mode = mode
        file_dir = osp.join('datasets', self.dataset, self.mode + '.txt')
        # Horizontal-flip augmentation is only enabled for PascalContext training.
        self.random_flip = self.dataset == 'PascalContext' and mode == 'train'
        with open(file_dir, 'r') as f:
            self.image_names = f.read().splitlines()
        self.dataset_length = len(self.image_names)

    def __len__(self):
        return self.dataset_length

    @staticmethod
    def _compute_flux(label, categories):
        """Return ``(weight_matrix, direction_field)`` for a zero-padded label map.

        ``label`` must carry a one-pixel border of value 0, and 0 must not be
        one of ``categories``. This guarantees every region has at least one
        outside pixel. Outputs have the padded shape: ``(H, W)`` and
        ``(2, H, W)``, both float32.
        """
        padded_h, padded_w = label.shape
        weight_matrix = np.zeros((padded_h, padded_w), dtype=np.float32)
        direction_field = np.zeros((2, padded_h, padded_w), dtype=np.float32)
        # (row, col) coordinate of every pixel, shape (2, H, W); loop-invariant.
        grid = np.indices((padded_h, padded_w)).astype(float)
        for category in categories:
            region = (label == category).astype(np.uint8)
            inside = region > 0
            weight_matrix[inside] = 1. / np.sqrt(region.sum())
            # With DIST_LABEL_PIXEL each zero (outside) pixel gets its own label,
            # and each region pixel gets the label of its nearest outside pixel.
            _, nearest_label = cv2.distanceTransformWithLabels(
                region, cv2.DIST_L2, cv2.DIST_MASK_PRECISE,
                labelType=cv2.DIST_LABEL_PIXEL)
            # Keep only the labels owned by outside pixels. np.argwhere lists them
            # in row-major order, so row k is the (row, col) of label k + 1.
            # NOTE(review): relies on OpenCV numbering outside pixels in raster
            # order, exactly as the original implementation did.
            outside_labels = np.where(inside, 0, nearest_label)
            outside_coords = np.argwhere(outside_labels > 0)
            nearest = outside_coords[nearest_label - 1]   # (H, W, 2)
            nearest = np.moveaxis(nearest, -1, 0)         # (2, H, W)
            diff = grid - nearest
            direction_field[:, inside] = diff[:, inside]
        return weight_matrix, direction_field

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A 7-tuple:

            * ``image``: mean-subtracted float32, channel axis moved first.
            * ``vis_image``: copy of the (possibly flipped) image as returned
              by ``cv2.imread``.
            * ``gt_mask``: float32 label map with every label shifted by +1.
            * ``direction_field``: float32, shape ``(2, H, W)``.
            * ``weight_matrix``: float32, shape ``(H, W)``.
            * ``dataset_length``: the number of samples in the dataset.
            * ``image_name``: the sample name; for BSDS500 only the part after
              the last ``'/'``.

        Raises:
            ValueError: unsupported ``dataset``, or label/image size mismatch.
            FileNotFoundError: the image or PNG label could not be read.
            RuntimeError: the shifted label map contains 0.
        """
        # Drawn on every call, even when flipping is disabled, so the global
        # NumPy RNG advances exactly as before.
        # NOTE(review): with DataLoader num_workers > 0 the global NumPy RNG may
        # be identically seeded in every worker; consider a worker_init_fn.
        flip = np.random.randint(0, 2)
        if self.dataset not in self._SUPPORTED_DATASETS:
            raise ValueError('unsupported dataset: %r' % (self.dataset,))
        image_name = self.image_names[index]
        do_flip = self.random_flip and flip

        image_path = osp.join('datasets', self.dataset, 'images', image_name + '.jpg')
        image = cv2.imread(image_path, 1)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError('could not read image: ' + image_path)
        if do_flip:
            image = cv2.flip(image, 1)
        vis_image = image.copy()
        height, width = image.shape[:2]
        image = image.astype(np.float32)
        image -= IMAGE_MEAN
        image = image.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)

        if self.dataset == 'PascalContext':
            label_path = osp.join('datasets', self.dataset, 'labels', image_name + '.mat')
            label = sio.loadmat(label_path)['LabelMap']
        else:  # 'BSDS500'
            label_path = osp.join('datasets', self.dataset, 'labels', image_name + '.png')
            label = cv2.imread(label_path, 0)
            if label is None:
                raise FileNotFoundError('could not read label: ' + label_path)
        if do_flip:
            label = cv2.flip(label, 1)
        # Shift labels so 0 is free for the border. Widen first: an in-place
        # += 1 on uint8 would wrap 255 to 0 and be misreported below.
        label = label.astype(np.int32) + 1
        gt_mask = label.astype(np.float32)
        categories = np.unique(label)
        if 0 in categories:
            raise RuntimeError('invalid category')
        if label.shape != (height, width):
            raise ValueError('label %s has shape %r but image has shape %r'
                             % (label_path, label.shape, (height, width)))
        # One-pixel zero border: guarantees every region has an outside pixel.
        label = cv2.copyMakeBorder(label, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)

        weight_matrix, direction_field = self._compute_flux(label, categories)
        # Strip the border added above.
        weight_matrix = weight_matrix[1:-1, 1:-1]
        direction_field = direction_field[:, 1:-1, 1:-1]

        if self.dataset == 'BSDS500':
            image_name = image_name.split('/')[-1]
        return image, vis_image, gt_mask, direction_field, weight_matrix, self.dataset_length, image_name