-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataloader_classifier.py
51 lines (36 loc) · 1.47 KB
/
dataloader_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from image_augment_pairs import *
class RoomDataset(Dataset):
    """Dataset of paired image / room-label arrays stored as ``.npy`` files.

    Every file in ``file_path`` whose name contains ``'image'`` is a sample.
    In training mode its label is the file with the same name but with
    ``'image'`` replaced by ``'room'``.

    Args:
        file_path: Directory holding the saved arrays.
        train: If True, items are ``(image_tensor, label_tensor)``; otherwise
            items are ``(image_tensor, image_file_name)`` and no labels are read.
        augment: If True (and ``train`` is True), apply the paired
            augmentations from ``image_augment_pairs`` to every sample.
    """

    def __init__(self, file_path, train=True, augment=False):
        self.file_path = file_path
        self.augment = augment
        self.train = train
        # Sorted so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.img_list = sorted(f for f in os.listdir(file_path) if 'image' in f)
        if self.train:
            self.label_list = [f.replace('image', 'room') for f in self.img_list]

    def __len__(self):
        """Return the number of image files found in ``file_path``."""
        return len(self.img_list)

    def _to_tensor(self, array, is_label=False):
        """Convert a NumPy array to a torch tensor.

        Labels are cast to int64 (``long``); anything else to float32.

        Raises:
            TypeError: If ``array`` is not a ``np.ndarray``.
        """
        if not isinstance(array, np.ndarray):
            raise TypeError(f"expected np.ndarray, got {type(array).__name__}")
        # torch.from_numpy rejects arrays with negative strides (e.g. views
        # produced by flipping), so take a C-ordered copy only in that case.
        if any(stride < 0 for stride in array.strides):
            array = array.copy()
        tensor = torch.from_numpy(array)
        return tensor.long() if is_label else tensor.float()

    @staticmethod
    def _apply_pair_augmentation(augment_fn, image, label):
        """Run one paired augmentation and return the resulting (image, label).

        If ``augment_fn`` returns a 2-tuple it is taken as the new pair;
        otherwise the (possibly in-place modified) inputs are returned.
        """
        # NOTE(review): the original code ignored the return value of the
        # helpers from image_augment_pairs; confirm whether they return new
        # arrays or mutate in place -- both conventions are handled here.
        result = augment_fn(image, label)
        if isinstance(result, tuple) and len(result) == 2:
            return result
        return image, label

    def __getitem__(self, index):
        """Load sample ``index`` and return it as tensors.

        Returns ``(image, label)`` in training mode, ``(image, file_name)``
        otherwise. The image tensor has the same axis order as the stored
        array; the label is returned as int64.
        """
        file_name = self.img_list[index]
        # Move axis 0 to the end for the augmentation helpers and back again
        # on return. Presumably the stored image is (C, H, W) -- TODO confirm.
        image = np.load(os.path.join(self.file_path, file_name)).transpose(1, 2, 0)
        if not self.train:
            return self._to_tensor(image).permute(2, 0, 1), file_name

        # assumes the label is a 2-D (H, W) class map -- verify against data.
        label = np.load(os.path.join(self.file_path, self.label_list[index]))
        if self.augment:
            # random rotations, random h-flips, random v-flips, random crops
            for augment_fn in (random_rotation, horizontal_flip,
                               vertical_flip, scale_augmentation):
                image, label = self._apply_pair_augmentation(augment_fn, image, label)
        return self._to_tensor(image).permute(2, 0, 1), self._to_tensor(label, is_label=True)