-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
97 lines (79 loc) · 3.01 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import zipfile
import numpy as np
import os
# (PLEASE DO NOT CHANGE) Set random seed:
# Seeds NumPy's global RNG once, at import time, so the row permutations drawn
# by load_all_data(..., shuffle=True) are reproducible from run to run.
np.random.seed(1746)
# Data file names are assembled as PREFIX + stem + digit + ".txt" in load_data,
# e.g. "digit_train_0.txt" or "digit_test_9.txt".
PREFIX = "digit_"
TEST_STEM = "test_"
TRAIN_STEM = "train_"
def check_and_extract_zipfile(filename, data_dir):
    '''
    Extract the zip archive ``filename`` into ``data_dir`` unless that
    directory already exists and is non-empty.

    Inputs:
    - filename: path to the zip archive
    - data_dir: directory to extract into (created by extraction if missing)
    '''
    # Skip only when the directory exists AND already has contents; that is
    # taken to mean a previous extraction. A missing or empty directory still
    # needs the data, so extraction proceeds in those cases.
    if os.path.isdir(data_dir) and os.listdir(data_dir):
        return
    # The context manager closes the archive even if extraction raises.
    with zipfile.ZipFile(filename, 'r') as zip_f:
        zip_f.extractall(data_dir)
def load_data(data_dir, stem):
    """
    Loads data from either the training set or the test set and returns the
    pixel values and class labels.

    Inputs:
    - data_dir: directory containing the extracted digit text files
    - stem: file-name stem selecting the split (TRAIN_STEM or TEST_STEM);
      files read are PREFIX + stem + "<digit>.txt" for digits 0-9
    Returns:
    - data: float array of shape (N, D), one comma-separated file row per
      sample (presumably D == 64 for the a2digits files, as the original
      reshape assumed -- confirm against the data)
    - labels: float array of shape (N,) holding the digit (0-9) of each row,
      in the same order as ``data``
    """
    data = []
    labels = []
    for digit in range(10):
        path = os.path.join(data_dir, PREFIX + stem + str(digit) + ".txt")
        # ndmin=2 keeps a single-row file as shape (1, D) instead of (D,),
        # so shape[0] is always the number of samples.
        digits = np.loadtxt(path, delimiter=',', ndmin=2)
        data.append(digits)
        # Labels stay float64, matching the original np.ones(...) * i.
        labels.append(np.full(digits.shape[0], digit, dtype=float))
    # Concatenate instead of np.array(...): per-digit row counts may differ,
    # which would make np.array build a ragged array and break the reshape.
    return np.concatenate(data, axis=0), np.concatenate(labels)
def load_all_data(data_dir, shuffle=True):
    '''
    Read both the training and the test split from an extracted data directory.

    Inputs:
    - data_dir: directory that already holds the digit text files
    - shuffle: when true, the rows of each split are permuted (labels kept
      aligned with their samples) using NumPy's global random state; the
      training split draws from the RNG before the test split
    Returns four numpy arrays:
    - train_data
    - train_labels
    - test_data
    - test_labels
    Raises:
    - OSError if data_dir is not an existing directory
    '''
    if not os.path.isdir(data_dir):
        hint = 'Data directory {} does not exist. Try "load_all_data_from_zip" function first.'
        raise OSError(hint.format(data_dir))

    train = load_data(data_dir, TRAIN_STEM)
    test = load_data(data_dir, TEST_STEM)

    def permute_rows(split):
        # One shared permutation keeps every sample paired with its label.
        rows, targets = split
        order = np.random.permutation(rows.shape[0])
        return rows[order], targets[order]

    if shuffle:
        # Order matters for reproducibility: train is permuted before test.
        train = permute_rows(train)
        test = permute_rows(test)

    return train + test
def load_all_data_from_zip(zipfile, data_dir, shuffle=True):
    '''
    Make sure the archive is unpacked into ``data_dir``, then load every split.

    Inputs:
    - zipfile: string path to the a2digits zip archive
    - data_dir: directory the archive is (or will be) extracted into
    - shuffle: whether to randomly permute each split (True by default)
    Returns four numpy arrays, in this order:
    - train_data
    - train_labels
    - test_data
    - test_labels
    '''
    # NOTE: inside this function the ``zipfile`` parameter shadows the
    # ``zipfile`` module; here it is only a path forwarded to the extractor.
    check_and_extract_zipfile(zipfile, data_dir)
    splits = load_all_data(data_dir, shuffle=shuffle)
    return splits
def get_digits_by_label(digits, labels, query_label):
    '''
    Return all digits in the provided array which match the query label
    Input:
    - digits: numpy array (or array-like) of pixel values, one digit per row
    - labels: the corresponding digit labels (0-9), one per row of ``digits``
    - query_label: the digit label for all returned digits
    Returns:
    - Numpy array containing all digits matching the query label (empty along
      the first axis when nothing matches)
    Raises:
    - ValueError if ``digits`` and ``labels`` have different lengths
    '''
    # asarray is a no-op for ndarrays and lets plain lists work too.
    digits = np.asarray(digits)
    labels = np.asarray(labels)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if digits.shape[0] != labels.shape[0]:
        raise ValueError(
            'digits and labels must have the same length, got {} and {}'.format(
                digits.shape[0], labels.shape[0]))
    matching_indices = labels == query_label
    return digits[matching_indices]