-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataloader_regressor.py
63 lines (47 loc) · 1.76 KB
/
dataloader_regressor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from torchvision import transforms
from image_augment_pairs import *
class RoomDataset(Dataset):
    """Dataset of (image, height-array) pairs addressed by a list file.

    The list file holds one sample name per line.  For a sample ``name`` the
    image is read from ``<img_dir>/<name>.png`` with ``cv2.imread`` and the
    label from ``<label_dir>/<name>.npy`` with ``np.load``; the label is then
    resized to ``label_size`` with ``cv2.resize``.

    Args:
        file_path: Path to the text file listing sample names.
        train: Whether the dataset is used for training.  Augmentation is
            applied only when both ``train`` and ``augment`` are true.
        augment: Whether to run the augmentation helpers on each sample.
        img_dir: Directory containing the ``.png`` images.
        label_dir: Directory containing the ``.npy`` label arrays.
        label_size: Target size passed to ``cv2.resize`` for the label.
    """

    def __init__(self, file_path, train=True, augment=False,
                 img_dir='../../data/image',
                 label_dir='../../data/height_arr',
                 label_size=(64, 64)):
        super().__init__()
        self.file_path = file_path
        self.augment = augment
        self.train = train
        self.img_list = []
        self.label_list = []
        with open(file_path) as f:
            # Strip whitespace instead of slicing off the last character so
            # a final line without a trailing newline keeps its full name,
            # and skip blank lines so they do not become empty sample names.
            self.list = [line.strip() for line in f if line.strip()]
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.label_size = tuple(label_size)

    def __len__(self):
        """Return the number of sample names read from the list file."""
        return len(self.list)

    def _to_tensor(self, array):
        """Convert an HW or HWC numpy array into a float CHW tensor.

        A 2-D array gets a trailing channel axis of size 1 before the axes
        are permuted, so the result always has three dimensions.

        Raises:
            TypeError: If ``array`` is not a ``numpy.ndarray``.
            ValueError: If ``array`` is neither 2-D nor 3-D.
        """
        if not isinstance(array, np.ndarray):
            raise TypeError(
                'expected a numpy.ndarray, got %s' % type(array).__name__)
        if array.ndim == 2:
            array = np.expand_dims(array, axis=2)
        elif array.ndim != 3:
            raise ValueError(
                'expected a 2-D or 3-D array, got %d dimensions' % array.ndim)
        # put it from HWC to CHW format
        return torch.from_numpy(array).permute(2, 0, 1).float()

    def __getitem__(self, index):
        """Load, optionally augment, and return sample ``index``.

        Returns:
            A tuple ``(image, label)`` of float tensors in CHW layout.

        Raises:
            FileNotFoundError: If the image cannot be read by ``cv2.imread``.
            ValueError: If the resized label is not 2-D.
        """
        name = self.list[index]
        image_path = os.path.join(self.img_dir, name + '.png')
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError('could not read image: %s' % image_path)
        label = np.load(os.path.join(self.label_dir, name + '.npy'))
        label = cv2.resize(label, self.label_size)
        if label.ndim != 2:
            raise ValueError(
                'expected a 2-D label, got shape %s' % (label.shape,))
        if self.train and self.augment:
            # NOTE(review): the return values of these helpers are discarded.
            # If they return new arrays rather than modifying their arguments
            # in place, augmentation has no effect -- confirm against
            # image_augment_pairs.
            # random rotations
            random_rotation(image, label)
            # random h-flips
            horizontal_flip(image, label)
            # random v-flips
            vertical_flip(image, label)
        return self._to_tensor(image), self._to_tensor(label)