-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataloader.py
68 lines (55 loc) · 2.55 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from torchvision import transforms
class RoomDataset(Dataset):
    """Dataset of paired image/label ``.npy`` arrays stored in one directory.

    Every file in ``file_path`` whose name contains ``'image'`` is treated as
    an input; its label file name is obtained by replacing ``'image'`` with
    ``'room'``.  Both arrays are indexed channels-last (``arr[:, :, c]``).

    Args:
        file_path: Directory that holds the ``.npy`` files.
        train: Augmentation is only applied when this is true.
        augment: Enable the random rotation / flip / crop augmentations
            (only has an effect together with ``train``).
    """

    def __init__(self, file_path, train=True, augment=False):
        self.file_path = file_path
        self.augment = augment
        self.train = train
        # os.listdir returns names in arbitrary order; sort so that a given
        # index always maps to the same sample.
        self.img_list = sorted(
            name for name in os.listdir(self.file_path) if 'image' in name
        )
        self.label_list = [
            name.replace('image', 'room') for name in self.img_list
        ]

    def __len__(self):
        """Return the number of image files found in ``file_path``."""
        return len(self.img_list)

    @staticmethod
    def _pil2np(img):
        """Return ``img`` as an ndarray; ndarrays are passed through as-is.

        Uses an ndarray check instead of ``PIL.Image`` so that no PIL import
        is needed (``np.asarray`` converts PIL images).
        """
        if not isinstance(img, np.ndarray):
            img = np.asarray(img)
        return img

    @staticmethod
    def _np2pil(img):
        """Convert an ndarray to a PIL image, casting to ``uint8`` first.

        Anything that is not an ndarray is returned unchanged.
        """
        if isinstance(img, np.ndarray):
            if img.dtype != np.uint8:
                img = img.astype(np.uint8)
            img = F.to_pil_image(img)
        return img

    @classmethod
    def _per_channel(cls, arr, n_channels, op):
        """Apply ``op`` to the first ``n_channels`` channels and restack them."""
        return np.dstack(
            [op(cls._np2pil(arr[:, :, c])) for c in range(n_channels)]
        )

    @classmethod
    def _apply_to_pair(cls, image, label, op):
        """Apply the same per-channel ``op`` to the image and to its label."""
        return (
            cls._per_channel(image, 3, op),
            cls._per_channel(label, label.shape[2], op),
        )

    def __getitem__(self, index):
        """Load sample ``index`` and optionally augment it.

        Returns:
            ``(image, label)`` as NumPy arrays.  The image has its last axis
            reversed (presumably a BGR <-> RGB swap).  Augmented samples pass
            through 8-bit PIL images, so they come back as ``uint8``.
        """
        image = np.load(os.path.join(self.file_path, self.img_list[index]))
        # The reversed view has a negative stride, which torch.from_numpy
        # (and therefore the default DataLoader collate) rejects; make a
        # contiguous copy.
        image = np.ascontiguousarray(image[:, :, ::-1])
        label = np.load(os.path.join(self.file_path, self.label_list[index]))

        if self.train and self.augment:
            # Channels-last layout: the first two axes are the spatial size.
            height, width = image.shape[:2]

            # random rotations
            if np.random.randint(2) == 0:
                # Plain int: F.rotate may reject NumPy scalars on some
                # torchvision versions -- TODO confirm.
                ang = int(np.random.choice([90, -90]))
                # NOTE(review): F.rotate is called without expand, so
                # non-square inputs are cropped/padded -- confirm intended.
                image, label = self._apply_to_pair(
                    image, label, lambda im: F.rotate(im, ang)
                )
            # random h-flips
            if np.random.randint(2) == 0:
                image, label = self._apply_to_pair(image, label, F.hflip)
            # random v-flips
            if np.random.randint(2) == 0:
                image, label = self._apply_to_pair(image, label, F.vflip)
            # random crops
            if np.random.randint(2) == 0:
                # get_params only needs the spatial size, so one channel is
                # enough (a full multi-channel label may not convert to PIL).
                i, j, h, w = transforms.RandomCrop.get_params(
                    self._np2pil(label[:, :, 0]),
                    output_size=(height // 2, width // 2),
                )
                # NOTE(review): resized_crop uses its default interpolation
                # for the label as well -- confirm this suits the label data.
                image, label = self._apply_to_pair(
                    image,
                    label,
                    lambda im: F.resized_crop(im, i, j, h, w, (height, width)),
                )
        return image, label