-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
74 lines (60 loc) · 2.43 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import time
import os
import torch
import random
import PIL.Image as Image
import torch.utils.data as data
from pdb import set_trace as bp
class dataset(data.Dataset):
    """Map-style image dataset driven by an annotation table.

    Each item is read from ``Config.rawdata_root`` joined with the matching
    ``ImageName`` entry, converted to RGB, optionally augmented (training
    only) and passed through ``totensor``. Labels stored in the annotation
    are shifted down by one before they are returned.

    Args:
        Config: object exposing ``rawdata_root``, ``numcls`` and ``dataset``.
            ``module`` is read as well, but only when ``is_train`` is truthy.
        anno: table whose ``'ImageName'`` and ``'label'`` entries provide
            ``.tolist()`` (presumably a pandas DataFrame -- confirm against
            the caller).
        common_aug: optional callable applied to the loaded image when
            ``is_train`` is truthy.
        totensor: optional callable applied to every loaded image. ``None``
            leaves the image exactly as loaded.
        is_train: enables ``common_aug`` and, when ``Config.module`` is
            ``'OEL'`` or ``'LIO'``, eager loading of ``positive_image_list``.
    """

    # Values of Config.module for which positive images are preloaded.
    _POSITIVE_MODULES = ('OEL', 'LIO')

    def __init__(self, Config, anno, common_aug=None, totensor=None, is_train=True):
        self.root_path = Config.rawdata_root
        self.numcls = Config.numcls
        self.dataset = Config.dataset
        self.paths = anno['ImageName'].tolist()
        self.labels = anno['label'].tolist()
        self.common_aug = common_aug
        self.totensor = totensor
        self.cfg = Config
        self.is_train = is_train
        # Truthiness (not ``== True``) keeps this in step with __getitem__.
        # ``Config.module`` is only touched in training mode.
        if is_train and Config.module in self._POSITIVE_MODULES:
            print('load and store positive_image_list for OEL module')
            self.positive_image_list = self.get_positive_images(self.paths, self.labels)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.paths)

    def __getitem__(self, item):
        """Return ``(image, zero_based_label, relative_path)`` for index ``item``."""
        rel_path = self.paths[item]
        img = self.pil_loader(os.path.join(self.root_path, rel_path))
        if self.is_train and self.common_aug is not None:
            img = self.common_aug(img)
        img = self._apply_totensor(img)
        # Stored labels are shifted down by one on the way out.
        label = self.labels[item] - 1
        return img, label, rel_path

    def _apply_totensor(self, img):
        """Run ``self.totensor`` on ``img``; return ``img`` untouched if it is None."""
        return img if self.totensor is None else self.totensor(img)

    def pil_loader(self, imgpath):
        """Open the file at ``imgpath`` and return it as an RGB PIL image."""
        with open(imgpath, 'rb') as f, Image.open(f) as img:
            return img.convert('RGB')

    def get_positive_images(self, imgpath_list, label_list):
        """Load every image up front and group the results by class.

        Args:
            imgpath_list: image paths relative to ``self.root_path``.
            label_list: labels aligned with ``imgpath_list``; each is shifted
                down by one to form the dictionary key.

        Returns:
            dict mapping each class index in ``range(self.numcls)`` to the
            list of processed images of that class (empty list if none).

        Raises:
            KeyError: if a shifted label falls outside ``range(self.numcls)``.
        """
        # Original author's note: this eager load takes a few minutes (2~4 min).
        start = time.time()
        positive_image_list = {cls_idx: [] for cls_idx in range(self.numcls)}
        for rel_path, label in zip(imgpath_list, label_list):
            img = self.pil_loader(os.path.join(self.root_path, rel_path))
            # ``common_aug`` is deliberately not applied to the stored images;
            # only the ``totensor`` step runs here.
            positive_image_list[label - 1].append(self._apply_totensor(img))
        print('Time check for positive_image_list load and store:', time.time() - start)
        return positive_image_list
def collate_fn(batch):
    """Merge dataset samples into one batch.

    Each sample contributes its first element (image tensor), its second
    element (label) and its last element (image name).

    Returns:
        tuple of (images stacked along a new leading dimension,
        list of labels, list of image names).
    """
    # Materialise once so the input is traversed a single time.
    samples = tuple(batch)
    stacked = torch.stack([entry[0] for entry in samples], 0)
    labels = [entry[1] for entry in samples]
    names = [entry[-1] for entry in samples]
    return stacked, labels, names