-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
executable file
·136 lines (110 loc) · 4.12 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
from PIL import Image
import json
import numpy as np
import torchvision.transforms as transforms
import os
import random
def identity(x):
    """No-op target transform: return *x* unchanged (default for label transforms)."""
    return x
class SubDataset:
    """Dataset over the images of one class *cl*.

    Each item is ``(transform(image), target_transform(cl))``. If the class
    has fewer than *min_size* images, the path list is oversampled
    (cyclically repeated) so a DataLoader can always draw a full batch.

    Args:
        sub_meta: list of image file paths belonging to this class
                  (assumed to be full paths — TODO confirm against the
                  JSON meta files that produce them).
        cl: the class label; returned (transformed) as the target.
        transform: image transform applied to the PIL image.
        target_transform: transform applied to the class label.
        min_size: minimum number of entries to keep after oversampling.
    """

    def __init__(self, sub_meta, cl, transform=transforms.ToTensor(), target_transform=identity, min_size=50):
        self.sub_meta = sub_meta
        self.cl = cl
        self.transform = transform
        self.target_transform = target_transform
        if len(self.sub_meta) < min_size:
            # Cycle through the available images until min_size entries exist;
            # no need for the numpy round-trip the original used.
            self.sub_meta = [self.sub_meta[i % len(self.sub_meta)]
                             for i in range(min_size)]

    def __getitem__(self, i):
        # sub_meta entries are already complete paths; the original wrapped
        # them in a single-argument os.path.join, which is a no-op.
        img = Image.open(self.sub_meta[i]).convert('RGB')
        img = self.transform(img)
        target = self.target_transform(self.cl)
        return img, target

    def __len__(self):
        return len(self.sub_meta)
# Full class
class SimpleDataset:
    """Flat (non-episodic) dataset over a JSON meta file.

    The JSON file must contain parallel lists under the keys
    ``'image_names'`` (image file paths) and ``'image_labels'``
    (integer class labels). Each item is
    ``(transform(image), target_transform(label))``.

    Args:
        data_file: path to the JSON meta file.
        transform: image transform applied to the PIL image.
        target_transform: transform applied to the label.
    """

    def __init__(self, data_file, transform, target_transform=identity):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, i):
        # image_names entries are complete paths; the original wrapped them
        # in a single-argument os.path.join, which is a no-op.
        img = Image.open(self.meta['image_names'][i]).convert('RGB')
        img = self.transform(img)
        target = self.target_transform(self.meta['image_labels'][i])
        return img, target

    def __len__(self):
        return len(self.meta['image_names'])
# Few-shot Multi-domain
class MultiSetDataset:
    """Episodic dataset spanning several domains (one JSON meta file each).

    One per-class DataLoader is built for every class of every domain;
    indexing the dataset by a (global) class index returns a freshly
    shuffled batch of images from that class.
    """

    def __init__(self, data_files, batch_size, transform):
        self.cl_list = np.array([])
        self.sub_dataloader = []
        self.n_classes = []   # number of classes per domain, in file order
        for data_file in data_files:
            with open(data_file, 'r') as handle:
                meta = json.load(handle)
            domain_classes = np.unique(meta['image_labels']).tolist()
            self.cl_list = np.concatenate((self.cl_list, domain_classes))
            # Group image paths by their class label.
            grouped = {label: [] for label in domain_classes}
            for path, label in zip(meta['image_names'], meta['image_labels']):
                grouped[label].append(path)
            for label in domain_classes:
                class_set = SubDataset(grouped[label], label,
                                       transform=transform,
                                       min_size=batch_size)
                self.sub_dataloader.append(
                    torch.utils.data.DataLoader(class_set,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=0,
                                                pin_memory=False))
            self.n_classes.append(len(domain_classes))

    def __getitem__(self, i):
        # A fresh iterator per call yields a newly shuffled batch of class i.
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.cl_list)

    def lens(self):
        return self.n_classes
class MultiEpisodicBatchSampler(object):
    """Batch sampler for multi-domain episodic training.

    Each of the *n_episodes* episodes is assigned one domain (domains are
    cycled through as evenly as possible, in shuffled order) and yields
    *n_way* distinct class indices from that domain, offset into the global
    class index space (domains are laid out consecutively).
    """

    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes   # list: class count per domain
        self.n_way = n_way
        self.n_episodes = n_episodes
        self.n_domains = len(n_classes)

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        # Spread episodes over domains round-robin, then shuffle the order.
        schedule = [ep % self.n_domains for ep in range(self.n_episodes)]
        random.shuffle(schedule)
        for domain in schedule:
            # Classes of earlier domains occupy the lower global indices.
            offset = sum(self.n_classes[:domain])
            yield torch.randperm(self.n_classes[domain])[:self.n_way] + offset
# Few-shot Single-domain
class SetDataset:
    """Episodic dataset for a single domain (one JSON meta file).

    One per-class DataLoader is built for every class; indexing the dataset
    by class index returns a freshly shuffled batch of that class's images.

    Args:
        data_file: path to a JSON meta file with parallel lists under
                   ``'image_names'`` and ``'image_labels'``.
        batch_size: images drawn per class per episode (n_support + n_query).
        transform: image transform applied to each PIL image.
    """

    def __init__(self, data_file, batch_size, transform):
        with open(data_file, 'r') as f:
            self.meta = json.load(f)
        self.cl_list = np.unique(self.meta['image_labels']).tolist()
        # Group image paths by their class label.
        self.sub_meta = {}
        for cl in self.cl_list:
            self.sub_meta[cl] = []
        for x, y in zip(self.meta['image_names'], self.meta['image_labels']):
            self.sub_meta[y].append(x)
        self.sub_dataloader = []
        for cl in self.cl_list:
            # Fix: pass min_size so a class with fewer than batch_size
            # images is oversampled and never yields a short batch —
            # consistent with MultiSetDataset. The max() keeps the
            # original floor of 50 for small batch sizes.
            sub_dataset = SubDataset(self.sub_meta[cl], cl,
                                     transform=transform,
                                     min_size=max(50, batch_size))
            self.sub_dataloader.append(
                torch.utils.data.DataLoader(sub_dataset,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=0,
                                            pin_memory=False))

    def __getitem__(self, i):
        # A fresh iterator per call yields a newly shuffled batch of class i.
        return next(iter(self.sub_dataloader[i]))

    def __len__(self):
        return len(self.cl_list)
class EpisodicBatchSampler(object):
    """Batch sampler for single-domain episodic training.

    Yields *n_episodes* batches; each batch is *n_way* distinct class
    indices drawn uniformly at random from ``range(n_classes)``.
    """

    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for _episode in range(self.n_episodes):
            # A random permutation truncated to n_way gives distinct classes.
            perm = torch.randperm(self.n_classes)
            yield perm[:self.n_way]