-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
116 lines (92 loc) · 4.14 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from utils.io import ls
class TLessDataset(Dataset):
    """Dataset class for loading the texture-less data from memory.

    ``__init__`` walks the layout::

        path/<object>/<sequence>/*.png   # input images
        path/<object>/<sequence>/*.npz   # label files holding 'dmap' and 'nmap'

    Images and labels are paired by position within each sequence, so a
    sequence is only used when it has as many ``.png`` as ``.npz`` files.
    """

    def __init__(self, path, transform=None):
        """
        Args:
            path (string): Path to the dataset.
            transform (callable, optional): Optional transform applied to the
                image returned by ``__getitem__``. Default: None.
        """
        print("Loading texture-less dataset from:", path)
        self.transform = transform
        self.images = []
        self.labels = []
        for o in sorted(os.listdir(path)):
            obj_dir = os.path.join(path, o)
            if not os.path.isdir(obj_dir):
                continue
            # Sort so the sample order is deterministic; os.listdir order is
            # arbitrary and file-system dependent.
            for s in sorted(os.listdir(obj_dir)):
                seq_dir = os.path.join(obj_dir, s)
                if not os.path.isdir(seq_dir):
                    continue
                # NOTE(review): assumes ``ls`` returns file names (they are joined
                # onto seq_dir below) in a consistent order, so the i-th image
                # corresponds to the i-th label — confirm in utils.io.
                images = ls(seq_dir, '.png')
                labels = ls(seq_dir, '.npz')
                if len(images) != len(labels):
                    # A count mismatch would shift every later image/label pair
                    # out of alignment, so leave this sequence out.
                    print("Skipping sequence with mismatched image/label counts:", seq_dir)
                    continue
                self.images += [os.path.join(seq_dir, p) for p in images]
                self.labels += [os.path.join(seq_dir, p) for p in labels]

    def __len__(self):
        """Return the size of dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Get the item at index idx.

        Returns:
            tuple: ``(data, (dmap, nmap, mask))`` where ``data`` is the image
            read by ``cv2.imread`` (passed through ``transform`` if set),
            ``dmap`` and ``nmap`` are the label arrays cast to float32, and
            ``mask`` is the boolean array ``dmap >= 1``.

        Raises:
            OSError: If the image file cannot be read.
        """
        image_path = self.images[idx]
        data = cv2.imread(image_path)
        if data is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        # Open the .npz once and close it afterwards: the NpzFile returned by
        # np.load keeps its file handle open until closed.
        with np.load(self.labels[idx]) as label:
            dmap = label['dmap'].astype(np.float32)
            nmap = label['nmap'].astype(np.float32)
        # NOTE(review): threshold 1 kept from the original; presumably marks
        # pixels with valid depth — confirm the depth units.
        mask = dmap >= 1
        # Apply transformation if any
        if self.transform:
            data = self.transform(data)
        return data, (dmap, nmap, mask)
class TransparentDataset(Dataset):
    """Dataset class for loading the transparent data from memory.

    ``__init__`` walks the layout::

        <object>/<sequence>/images/*.png   # one image per (label, environment)
        <object>/<sequence>/*.npz          # label files holding 'dmap' and 'nmap'

    Each label file is repeated once per environment so that it lines up with
    the images rendered from it. A sequence is only used when the resulting
    image and label counts agree.
    """

    def __init__(self, path, single_object=False, envs=None, seqs=None, transform=None,
                 num_envs=5):
        """
        Args:
            path (string): Path to the dataset.
            single_object (bool): If True, only load data from one object. If False, 'path' is the path to the
                directory containing all objects. Default: False.
            envs (list): List of world environments to load. If None, load all environments. Default: None.
                The environment of an image is the part of its file name after the last '_' and before the
                first '.'.
            seqs (list): List of sequences to load. If None, load all sequences. Default: None.
            transform (callable, optional): Optional transform to be applied on a sample.
            num_envs (int): Number of environment images per label file when ``envs`` is None.
                Default: 5.
        """
        print("Loading transparent dataset from:", path)
        self.transform = transform
        self.images = []
        self.labels = []
        # Build the set once: it gives O(1) membership tests, and its size (not
        # len(envs)) is the true number of images kept per label when envs
        # contains duplicates.
        env_set = None if envs is None else set(envs)
        copies_per_label = num_envs if env_set is None else len(env_set)
        if single_object:
            objects = [path]
        else:
            objects = [os.path.join(path, o) for o in sorted(os.listdir(path))
                       if os.path.isdir(os.path.join(path, o))]
        for o in objects:
            sequences = [os.path.join(o, s) for s in sorted(os.listdir(o))
                         if os.path.isdir(os.path.join(o, s)) and (seqs is None or s in seqs)]
            for s in sequences:
                images_path = os.path.join(s, 'images')
                images = ls(images_path, '.png')
                if env_set is not None:
                    # Keep only the specified environments (basename is portable,
                    # unlike splitting on '/').
                    images = [i for i in images
                              if os.path.basename(i).split('_')[-1].split('.')[0] in env_set]
                labels_path = s
                # NOTE(review): the repetition below assumes ``ls`` orders images
                # so each label's environment renderings are consecutive and in
                # the same order as the label files — confirm in utils.io.
                labels = [y for y in ls(labels_path, '.npz') for _ in range(copies_per_label)]
                if len(images) != len(labels):
                    print("Skipping sequence with mismatched image/label counts:", s)
                    continue
                self.images += [os.path.join(images_path, x) for x in images]
                self.labels += [os.path.join(labels_path, y) for y in labels]

    def __len__(self):
        """Return the size of dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Get the item at index idx.

        Returns:
            tuple: ``(data, (dmap, nmap, mask))`` where ``data`` is the image
            read by ``cv2.imread`` (passed through ``transform`` if set),
            ``dmap`` and ``nmap`` are the label arrays cast to float32, and
            ``mask`` is the boolean array ``dmap >= 1``.

        Raises:
            OSError: If the image file cannot be read.
        """
        image_path = self.images[idx]
        data = cv2.imread(image_path)
        if data is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        # Open the .npz once and close it afterwards: the NpzFile returned by
        # np.load keeps its file handle open until closed.
        with np.load(self.labels[idx]) as label:
            dmap = label['dmap'].astype(np.float32)
            nmap = label['nmap'].astype(np.float32)
        # NOTE(review): threshold 1 kept from the original; presumably marks
        # pixels with valid depth — confirm the depth units.
        mask = dmap >= 1
        # Apply transformation if any
        if self.transform:
            data = self.transform(data)
        return data, (dmap, nmap, mask)