-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_cifar_10.py
77 lines (61 loc) · 2.38 KB
/
load_cifar_10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
CIFAR10/100 Parser - Returns images and labels
from pickled CIFAR data - https://www.cs.toronto.edu/~kriz/cifar.html
Theo Jaquenoud - @thjaquenoud
See:
https://github.com/dmezh/convmixer-tf
"""
import os
import pickle

import numpy as np
def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    The file is opened in binary mode and read with ``encoding="bytes"``,
    so 8-bit strings pickled by Python 2 come back as ``bytes`` objects
    (including dictionary keys such as ``b"data"``).

    Only call this on files from a trusted source: unpickling can execute
    arbitrary code.
    """
    with open(file, "rb") as handle:
        return pickle.load(handle, encoding="bytes")
def _to_channels_last(flat_images, as_float):
    """Reshape flat image rows to (N, 32, 32, 3), casting to float32 if asked."""
    images = flat_images.reshape((len(flat_images), 3, 32, 32))
    # Move the channel axis to the end; this is the same permutation that
    # np.rollaxis(images, 1, 4) performs.
    images = images.transpose(0, 2, 3, 1)
    return images.astype(np.float32) if as_float else images


def load_cifar_10_data(data_dir, negatives=False):
    """Load the pickled CIFAR-10 batches found in ``data_dir``.

    Reads ``batches.meta``, ``data_batch_1`` .. ``data_batch_5`` and
    ``test_batch`` via ``unpickle``.

    Args:
        data_dir: Directory containing the batch files (``str`` or any
            path-like object).
        negatives: If true, the image arrays are converted to ``float32``;
            otherwise they keep the dtype stored in the batch files.
            (Presumably named for allowing negative pixel values after later
            preprocessing -- confirm with callers.)

    Returns:
        A 7-tuple ``(train_data, train_filenames, train_labels, test_data,
        test_filenames, test_labels, label_names)``. The two ``*_data``
        arrays have shape ``(N, 32, 32, 3)``; every other element is a NumPy
        array built from the corresponding list in the pickled dictionaries.

    Raises:
        OSError: If one of the batch files cannot be opened.
        KeyError: If a batch dictionary lacks an expected key.
    """
    meta = unpickle(os.path.join(data_dir, "batches.meta"))
    label_names = np.array(meta[b"label_names"])

    # Training data: gather all five batches, then stack once instead of
    # re-copying the growing array on every iteration.
    train_batches = []
    train_filenames = []
    train_labels = []
    for batch_index in range(1, 6):
        batch_path = os.path.join(data_dir, "data_batch_{}".format(batch_index))
        batch = unpickle(batch_path)
        train_batches.append(batch[b"data"])
        train_filenames.extend(batch[b"filenames"])
        train_labels.extend(batch[b"labels"])
    train_data = _to_channels_last(np.vstack(train_batches), negatives)

    # Test data: a single batch file.
    test_batch = unpickle(os.path.join(data_dir, "test_batch"))
    test_data = _to_channels_last(test_batch[b"data"], negatives)

    return (
        train_data,
        np.array(train_filenames),
        np.array(train_labels),
        test_data,
        np.array(test_batch[b"filenames"]),
        np.array(test_batch[b"labels"]),
        label_names,
    )