-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_generator.py
122 lines (100 loc) · 4.09 KB
/
data_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
from omegaconf import DictConfig
from utils.general_utils import get_data_paths
from utils.images_utils import prepare_image, prepare_mask
class DataGenerator(Dataset):
    """
    Map-style dataset of images and (optionally) their corresponding masks.

    Every image (and mask) is read, preprocessed and converted to a tensor once,
    in ``__init__``, and kept in memory. ``__getitem__`` returns one pre-built
    sample; batching is left to ``torch.utils.data.DataLoader``.

    Paths are resolved by ``get_data_paths``. Per the original design there are
    two options: a directory path or a list. In case of a directory, it should
    contain the relative path of the images/mask folder from the project root.
    In case of a list, every element should contain an absolute path for each
    image and mask.

    Because this dataset is also used for prediction, ``MASK_PATH`` may be set
    to None (or the string "none", any case) when masks are not available; items
    are then 1-tuples ``(image,)``.
    """

    def __init__(self, cfg: "DictConfig", mode: str):
        """
        Resolve the image/mask paths for ``mode`` and preload every sample.

        :param cfg: project configuration. Reads ``SEED``,
            ``HYPER_PARAMETERS.BATCH_SIZE``, ``DATASET[mode].MASK_PATH``,
            ``PREPROCESS_DATA.*`` and ``OUTPUT.CLASSES``.
        :param mode: key used to index ``cfg.DATASET`` and
            ``cfg.PREPROCESS_DATA.SHUFFLE``.
        """
        self.cfg = cfg
        self.mode = mode
        # Kept as an attribute for existing users; batching itself is done by
        # the DataLoader, not by this dataset.
        self.batch_size = self.cfg.HYPER_PARAMETERS.BATCH_SIZE
        # Seeds NumPy's *global* RNG so the shuffle in on_epoch_end is reproducible.
        np.random.seed(cfg.SEED)
        # Masks are optional: a None path or the literal string "none"
        # (case-insensitive) disables them.
        mask_path = cfg.DATASET[mode].MASK_PATH
        self.mask_available = not (
            mask_path is None or str(mask_path).lower() == "none"
        )
        data_paths = get_data_paths(cfg, mode, self.mask_available)
        self.images_paths = data_paths[0]
        if self.mask_available:
            self.mask_paths = data_paths[1]
        self.on_epoch_end()
        self.__data_generation(self.indexes)

    def __len__(self):
        """
        Return the number of samples in the dataset.

        ``DataLoader`` samples indices from ``range(len(dataset))`` and forms
        the batches itself, so this must be the sample count. Returning
        ``num_images // batch_size`` here, as before, meant only that many
        samples were ever visited per epoch.
        """
        return len(self.images_paths)

    def on_epoch_end(self):
        """
        Rebuild ``self.indexes`` (0..N-1), shuffled when configured.

        This is called from ``__init__`` to fix the order in which samples are
        preloaded. ``DataLoader`` never calls it. Because samples are already
        materialised, calling it again only reorders ``self.indexes``, not the
        stored tensors.
        """
        self.indexes = np.arange(len(self.images_paths))
        if self.cfg.PREPROCESS_DATA.SHUFFLE[self.mode].VALUE:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """
        Return one preprocessed sample.

        :param index: position in the preloaded sample list.
        :return: ``(image, mask)`` when masks are available, otherwise the
            1-tuple ``(image,)``.
        """
        image = self.batch_images[index]
        if self.mask_available:
            return image, self.batch_masks[index]
        return (image,)

    def __data_generation(self, indexes):
        """
        Load and preprocess every sample listed in ``indexes``, in that order.

        Fills ``self.batch_images`` and, when masks are available,
        ``self.batch_masks`` with one tensor per sample.
        """
        preprocess_cfg = self.cfg.PREPROCESS_DATA
        resize = preprocess_cfg.RESIZE
        self.batch_images = []
        if self.mask_available:
            self.batch_masks = []
            # NOTE(review): the extra class is presumably background — confirm.
            num_classes = self.cfg.OUTPUT.CLASSES + 1

        for index in indexes:
            index = int(index)

            # Prepare the image for the model by resizing and preprocessing it.
            image = prepare_image(
                self.images_paths[index],
                resize,
                preprocess_cfg.IMAGE_PREPROCESSING_TYPE,
            )
            # Move the last axis first: presumably (H, W, C) -> (C, H, W).
            self.batch_images.append(
                torch.from_numpy(image).permute(2, 0, 1).float()
            )

            if self.mask_available:
                # Prepare the mask by resizing (and optionally normalizing) it.
                mask = prepare_mask(
                    self.mask_paths[index],
                    resize,
                    preprocess_cfg.NORMALIZE_MASK,
                )
                # One-hot encode a (H, W) label map into (H, W, K), then move
                # the class axis first -> (K, H, W). F.one_hot requires every
                # label to lie in [0, num_classes).
                mask = torch.from_numpy(mask).long()
                mask = F.one_hot(mask, num_classes=num_classes)
                self.batch_masks.append(mask.permute(2, 0, 1).float())
def get_data_loader(cfg: "DictConfig", mode: str, **loader_kwargs):
    """
    Build a ``DataLoader`` over a ``DataGenerator`` for ``mode``.

    Constructing the ``DataGenerator`` loads and preprocesses every sample
    up front.

    :param cfg: project configuration. ``HYPER_PARAMETERS.BATCH_SIZE`` sets the
        batch size and ``PREPROCESS_DATA.SHUFFLE[mode].VALUE`` toggles shuffling.
    :param mode: key passed through to ``DataGenerator``.
    :param loader_kwargs: optional extra keyword arguments forwarded unchanged
        to ``torch.utils.data.DataLoader`` (e.g. ``num_workers``,
        ``pin_memory``, ``drop_last``). They must not repeat ``batch_size`` or
        ``shuffle``, which come from ``cfg``.
    :return: the configured ``DataLoader``.
    """
    dataset = DataGenerator(cfg=cfg, mode=mode)
    return DataLoader(
        dataset,
        batch_size=cfg.HYPER_PARAMETERS.BATCH_SIZE,
        shuffle=cfg.PREPROCESS_DATA.SHUFFLE[mode].VALUE,
        **loader_kwargs,
    )