-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
77 lines (55 loc) · 2.76 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import PIL
import numpy as np
from torchvision.datasets import CIFAR10
class CIFAR10Sup(CIFAR10):
    """
    CIFAR10 subclass that keeps only a labelled subset of samples for supervised training.

    A permutation of all sample indices is drawn with
    ``np.random.RandomState(random_seed)``; the first ``sup_num`` permuted indices
    form this subset. CIFAR10Unsup / CIFAR10Val draw the same permutation for the
    same ``random_seed``, so with matching arguments the three subsets are disjoint
    as long as ``sup_num + val_num`` does not exceed the dataset length.

    Args:
        root, train, target_transform, download: forwarded unchanged to ``CIFAR10``.
        transform: sequence of transforms; only ``transform[0]`` is applied to the
            images. ``None`` means no image transform.
        sup_num (int): number of labelled samples to keep.
        val_num (int): not used by this class; accepted so all split classes share
            one constructor signature.
        random_seed (int): seed for the index permutation.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, sup_num=4000, val_num=1000, random_seed=89):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        idx = np.random.RandomState(seed=random_seed).permutation(len(self))
        keep = idx[:sup_num]
        self.data = self.data[keep]
        self.targets = np.array(self.targets)[keep]
        # The caller passes a sequence of transforms and only the first one is used
        # here (it overrides the sequence stored by the base constructor). Guard the
        # default ``None``, which previously raised TypeError on ``None[0]``.
        self.transform = transform[0] if transform is not None else None
class CIFAR10Unsup(CIFAR10):
    """
    CIFAR10 subclass holding the unlabelled pool used for unsupervised training.

    A permutation of all sample indices is drawn with
    ``np.random.RandomState(random_seed)``. The pool is every permuted index after
    the first ``sup_num`` (the supervised subset) and before the last ``val_num``
    (the validation subset). Labels are discarded.

    Each sample is returned twice, under two different sets of transformations.

    Args:
        root, train, target_transform, download: forwarded unchanged to ``CIFAR10``.
        transform: sequence of at least two transforms; ``transform[0]`` produces
            the first view and ``transform[1]`` the second (augmented) view. ``None``
            (or a ``None`` entry) means the PIL image is returned untransformed.
        sup_num (int): number of leading permuted indices reserved for supervision.
        val_num (int): number of trailing permuted indices reserved for validation.
        random_seed (int): seed for the index permutation.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, sup_num=4000, val_num=1000, random_seed=89):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        idx = np.random.RandomState(seed=random_seed).permutation(len(self))
        # Explicit end index: ``idx[sup_num:-val_num]`` is empty when val_num == 0
        # because ``-0 == 0``. Clamp at 0 so an oversized val_num cannot turn the end
        # into a negative (from-the-back) index.
        end = max(len(idx) - val_num, 0)
        self.data = self.data[idx[sup_num:end]]
        self.targets = None
        if transform is None:
            # Previously ``None[0]`` raised TypeError; treat as "no transforms".
            transform = (None, None)
        self.transform_unsup = transform[0]
        self.transform_unsup_aug = transform[1]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (img_unsup, img_unsup_aug) obtained by applying the two sets of
            transformations to self.data[index]. A view whose transform is None is
            the untransformed PIL image.
        """
        img = PIL.Image.fromarray(self.data[index])
        # Fall back to the raw image so a missing transform no longer leaves the
        # local unbound (previously an UnboundLocalError at the return).
        if self.transform_unsup is not None:
            img_unsup = self.transform_unsup(img)
        else:
            img_unsup = img
        if self.transform_unsup_aug is not None:
            img_unsup_aug = self.transform_unsup_aug(img)
        else:
            img_unsup_aug = img
        return img_unsup, img_unsup_aug
class CIFAR10Val(CIFAR10):
    """
    CIFAR10 subclass that keeps a labelled subset of samples (val_num) for validation.

    A permutation of all sample indices is drawn with
    ``np.random.RandomState(random_seed)``; the last ``val_num`` permuted indices
    form this subset. CIFAR10Sup / CIFAR10Unsup draw the same permutation for the
    same ``random_seed``, so with matching arguments the three subsets are disjoint
    as long as ``sup_num + val_num`` does not exceed the dataset length.

    Args:
        root, train, target_transform, download: forwarded unchanged to ``CIFAR10``.
        transform: sequence of transforms; only ``transform[0]`` is applied to the
            images. ``None`` means no image transform.
        sup_num (int): not used by this class; accepted so all split classes share
            one constructor signature.
        val_num (int): number of validation samples to keep.
        random_seed (int): seed for the index permutation.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, sup_num=4000, val_num=1000, random_seed=89):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        idx = np.random.RandomState(seed=random_seed).permutation(len(self))
        # Explicit start index: ``idx[-val_num:]`` returns the WHOLE dataset when
        # val_num == 0 because ``-0 == 0``. Clamping at 0 keeps the original
        # "take everything" behaviour when val_num exceeds the dataset length.
        start = max(len(idx) - val_num, 0)
        keep = idx[start:]
        self.data = self.data[keep]
        self.targets = np.array(self.targets)[keep]
        # Only the first transform of the passed sequence is used (it overrides the
        # sequence stored by the base constructor); guard the default ``None``.
        self.transform = transform[0] if transform is not None else None