-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
125 lines (108 loc) · 3.59 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
from typing import Any, Callable, Optional

import torchvision
from PIL import Image
from torch.utils import data
from torchvision.datasets.folder import default_loader

from transforms import build_transform
class CIFAR10Dataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose items also carry a sample index.

    Each item is an ``(img, target, index)`` triple. For the test split
    (``self.train`` false) the returned index is shifted by 50000, so that
    indices stay distinct when the train and test splits are concatenated.
    NOTE(review): the shift assumes the train split holds exactly 50000
    samples; confirm before reusing this class with other data.
    """

    def __getitem__(self, index: int):
        # Shift applied to test-split indices (see class docstring).
        test_index_offset = 50000

        raw_img = self.data[index]
        target = self.targets[index]
        img = Image.fromarray(raw_img)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        sample_index = index if self.train else index + test_index_offset
        return img, target, sample_index
class CIFAR100Dataset(CIFAR10Dataset):
    """CIFAR-100 counterpart of ``CIFAR10Dataset``.

    Only the download/archive metadata below is overridden. Item access,
    including the ``(img, target, index)`` triples and the test-split index
    shift, is inherited unchanged from ``CIFAR10Dataset``.
    """

    # Archive to download and its checksum.
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"

    # Folder name used under ``root`` (presumably the extracted archive dir).
    base_folder = "cifar-100-python"

    # [file name, md5] pairs for the train and test splits.
    train_list = [["train", "16019d7e3df5f24257cddd939b257f8d"]]
    test_list = [["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"]]

    # Metadata file; "key" names the entry holding the label names
    # (presumably the fine-grained labels, per the key name).
    meta = dict(
        filename="meta",
        key="fine_label_names",
        md5="7973b15100ade9c7d40fb424638fde48",
    )
# Lowercase, dot-prefixed file extensions treated as images by
# CustomImageFolder when no ``is_valid_file`` predicate is given.
IMG_EXTENSIONS = (
    # JPEG
    ".jpg", ".jpeg",
    # Other common raster formats
    ".png", ".ppm", ".bmp", ".pgm",
    # TIFF
    ".tif", ".tiff",
    # WebP
    ".webp",
)
class CustomImageFolder(torchvision.datasets.DatasetFolder):
    """Image-folder dataset whose items also carry their sample index.

    Files are filtered by ``IMG_EXTENSIONS`` unless an ``is_valid_file``
    predicate is supplied, in which case no extension list is passed to
    ``DatasetFolder`` (only one of the two filters is ever given).

    Each item is a ``(sample, target, index)`` triple, where ``index`` is the
    position in ``self.samples``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Use the default extension whitelist only when no custom predicate.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            root,
            loader,
            extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # Same alias torchvision's ImageFolder exposes, for callers using .imgs.
        self.imgs = self.samples

    def __getitem__(self, index: int):
        file_path, label = self.samples[index]
        image = self.loader(file_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label, index
def build_dataset(type, args):
    """Build the dataset selected by ``args.dataset``.

    Args:
        type: Split name. ``"train"`` makes ``build_transform`` build the
            training transform; any other value selects the non-training one.
            (Shadows the builtin ``type``; name kept for caller compatibility.)
        args: Namespace providing ``dataset`` (``"CIFAR-10"``, ``"CIFAR-100"``
            or ``"MIAD"``), ``data_path`` (root directory, str or path-like)
            and whatever ``build_transform`` reads.

    Returns:
        - CIFAR-10: a ``ConcatDataset`` of the train and test splits, both
          using the same transform.
        - CIFAR-100: a ``ConcatDataset`` of the train split only.
        - MIAD: a ``CustomImageFolder`` rooted at ``data_path``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    is_train = type == "train"
    transform = build_transform(is_train, args)
    root = args.data_path

    if args.dataset == "CIFAR-10":
        # os.path.join instead of string "+": the old concatenation produced
        # e.g. "./dataCIFAR-10" when data_path had no trailing separator.
        cifar10_root = os.path.join(root, "CIFAR-10")
        # Train then test; CIFAR10Dataset shifts test indices so that sample
        # indices remain unique across the concatenation.
        dataset = data.ConcatDataset(
            [
                CIFAR10Dataset(
                    root=cifar10_root,
                    train=split_is_train,
                    download=True,
                    transform=transform,
                )
                for split_is_train in (True, False)
            ]
        )
    elif args.dataset == "CIFAR-100":
        # NOTE(review): unlike CIFAR-10, only the train split is included.
        # Confirm whether the test split should be concatenated as well.
        dataset = data.ConcatDataset(
            [
                CIFAR100Dataset(
                    root=os.path.join(root, "CIFAR-100"),
                    train=True,
                    download=True,
                    transform=transform,
                ),
            ]
        )
    elif args.dataset == "MIAD":
        # The same folder serves every split (the former train/eval branches
        # both resolved to ``root``).
        dataset = CustomImageFolder(root=root, transform=transform)
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    return dataset