-
Notifications
You must be signed in to change notification settings - Fork 3
/
data.py
44 lines (34 loc) · 1.04 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from torchvision import datasets, transforms
def _permutate_image_pixels(image, permutation):
    """Reorder the spatial pixels of an image tensor by a fixed permutation.

    Args:
        image: Tensor whose ``size()`` unpacks to (channels, height, width).
        permutation: Index sequence of length ``height * width`` giving the
            new order of the flattened pixel positions, or ``None`` to leave
            the image untouched.

    Returns:
        ``image`` itself when ``permutation`` is ``None``; otherwise a new
        tensor of shape (channels, height, width) in which the flattened
        pixel positions of every channel are reordered by ``permutation``.
    """
    if permutation is None:
        return image

    c, h, w = image.size()
    # Flatten only the spatial axes so each row holds one channel's pixels;
    # indexing the columns then applies the same permutation per channel.
    flat = image.reshape(c, -1)
    permuted = flat[:, permutation]
    # Return the reshaped tensor: the previous code called view() and threw
    # the result away, so callers received an (h * w, c) tensor instead.
    return permuted.reshape(c, h, w)
def get_dataset(name, train=True, download=True, permutation=None):
    """Instantiate the registered dataset for ``name`` with its transforms.

    Args:
        name: Key looked up in AVAILABLE_DATASETS and AVAILABLE_TRANSFORMS.
        train: Passed through to the dataset constructor as ``train``.
        download: Passed through to the dataset constructor as ``download``.
        permutation: Pixel permutation handed to _permutate_image_pixels for
            every sample; ``None`` leaves samples unpermuted.

    Returns:
        An instance of the registered dataset class rooted at
        ``./datasets/<name>``.

    Raises:
        KeyError: If ``name`` is not a registered dataset.
    """
    dataset_class = AVAILABLE_DATASETS[name]

    def _permute(sample):
        # Runs last, after the dataset's own registered transforms.
        return _permutate_image_pixels(sample, permutation)

    pipeline = list(AVAILABLE_TRANSFORMS[name])
    pipeline.append(transforms.Lambda(_permute))
    dataset_transform = transforms.Compose(pipeline)

    root = './datasets/{name}'.format(name=name)
    return dataset_class(
        root,
        train=train,
        download=download,
        transform=dataset_transform,
    )
# Registry of dataset constructors, keyed by the name accepted by get_dataset.
AVAILABLE_DATASETS = dict(
    mnist=datasets.MNIST,
)
# Per-dataset preprocessing steps, keyed by dataset name and applied in list
# order; get_dataset appends the pixel-permutation step after these.
AVAILABLE_TRANSFORMS = {
    'mnist': [
        # NOTE(review): the tensor -> PIL round trip ahead of Pad looks like a
        # workaround for a Pad that only accepted PIL images. Confirm against
        # the pinned torchvision version before simplifying, since the round
        # trip may quantise values.
        transforms.ToTensor(),
        transforms.ToPILImage(),
        # Presumably grows 28x28 MNIST digits to 32x32, matching the 'size'
        # of 32 in DATASET_CONFIGS. TODO confirm.
        transforms.Pad(2),
        transforms.ToTensor(),
        # Single-channel (mean,), (std,); presumably the MNIST dataset
        # statistics. Verify before reusing for other data.
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
}
# Per-dataset settings keyed by dataset name: 'size', 'channels' and
# 'classes' (presumably image side length, channel count and label count).
DATASET_CONFIGS = {
    'mnist': dict(size=32, channels=1, classes=10),
}