-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathisao.py
58 lines (50 loc) · 1.88 KB
/
isao.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from torch.utils.data import Dataset, DataLoader
import glob
from PIL import Image
import numpy as np
import os
from torchvision import transforms as transforms
class Isao(Dataset):
    """Image-folder dataset yielding images together with one-hot class labels.

    Every ``*.jpg`` found recursively under ``data_dir`` is a sample. A
    sample's class is the top-level sub-folder of ``data_dir`` that contains
    it. The label string is everything after the first ``'-'`` in that folder
    name, e.g. ``'n01440764-tench'`` -> ``'tench'``. A folder name without a
    ``'-'`` yields the empty label ``''``.

    Args:
        data_dir: Root directory whose immediate sub-folders are the classes.
        use_label: Stored on the instance; not read anywhere in this class.
        resize: Optional size forwarded to ``transforms.Resize`` before
            tensor conversion. ``None`` skips resizing.
    """

    def __init__(self, data_dir, use_label, resize=None):
        self.data_dir = data_dir
        self.use_label = use_label  # NOTE(review): unused inside this class
        # Sorted so sample indices are reproducible; glob order is
        # filesystem-dependent. Label lookup is separator-agnostic (see
        # get_label), so no path rewriting is needed here.
        self.files = sorted(
            glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True)
        )
        # The argument is only used to select the "all labels" branch.
        self.labels = self.get_label(self.files)
        # Row i is the one-hot vector for self.labels[i].
        self.label_map = np.eye(len(self.labels))
        steps = [transforms.ToTensor()]
        if resize is not None:
            steps.insert(0, transforms.Resize(resize))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        """Return the number of image files found under ``data_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        Keys:
            ``'img'``: ``self.transform`` applied to the opened image.
            ``'label_name'``: one-element list holding the label string.
            ``'label_one_hot'``: the matching row of ``self.label_map``.

        Raises:
            ValueError: if the image's folder label is not in ``self.labels``
                (e.g. a ``.jpg`` placed directly in ``data_dir``).
        """
        filepath = self.files[idx]
        # The context manager closes the file handle. The transform runs
        # inside it because PIL reads pixel data lazily.
        with Image.open(filepath) as img:
            tensor = self.transform(img)
        label = self.get_label(filepath)
        label_one_hot = self.label_map[self.labels.index(label)]
        return {'img': tensor, 'label_name': [label], 'label_one_hot': label_one_hot}

    def __str__(self):
        """List each label next to its one-hot vector, one per line."""
        return '\n'.join(
            f'{label} : {str(one_hot)}'
            for label, one_hot in zip(self.labels, self.label_map)
        )

    def get_label(self, files):
        """Return every class label, or the label of a single image path.

        Args:
            files: A list, whose contents are ignored, to request all class
                labels. Otherwise a path to an image located under
                ``data_dir``.

        Returns:
            For a list: the labels of all sub-directories of ``data_dir``, in
            sorted folder-name order so the mapping is reproducible. For a
            path: the label of the ``data_dir`` sub-folder containing it.
        """
        if isinstance(files, list):
            folders = sorted(
                entry
                for entry in os.listdir(self.data_dir)
                if os.path.isdir(os.path.join(self.data_dir, entry))
            )
            return [self._label_from_folder(folder) for folder in folders]
        # The first component of the path relative to data_dir is the class
        # folder, regardless of how deep data_dir itself is.
        relative = os.path.relpath(files, self.data_dir)
        top_folder = relative.split(os.sep, 1)[0]
        return self._label_from_folder(top_folder)

    @staticmethod
    def _label_from_folder(folder):
        """Drop everything up to and including the first '-' of a folder name."""
        return folder.partition('-')[2]