-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar_input.py
99 lines (83 loc) · 3.7 KB
/
cifar_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import tensorflow as tf
import numpy as np
import cifar10
import cifar100
# Shape of a single input image as (height, width, channels) = (32, 32, 3).
IMG_HEIGHT = 32
IMG_WIDTH = 32
IMG_DEPTH = 3
def per_image_standardization(image_np):
    '''
    Standardize every image of a batch to zero mean and unit variance.

    Each image ``image_np[i]`` is replaced by ``(x - mean) / adjusted_std``
    where ``adjusted_std = max(std, 1 / sqrt(elements per image))``. The
    floor protects against a division by zero on constant images.

    Args:
        image_np: array-like whose first axis indexes the images.

    Returns:
        A floating-point array of the same shape. A floating-point ndarray
        is standardized in place and returned as the same object; any other
        input is first copied to float32, because writing the scaled values
        back into an integer array would truncate them.

    Ref: https://www.tensorflow.org/api_docs/python/tf/image/per_image_standardization
    '''
    images = np.asarray(image_np)
    if not np.issubdtype(images.dtype, np.floating):
        # Integer pixels (e.g. uint8) cannot hold the standardized values.
        images = images.astype(np.float32)

    num_images = images.shape[0]
    if num_images == 0:
        return images

    # Use adjusted standard deviation here, in case the std == 0.
    # The floor depends on the real element count of one image, not on a
    # hard-coded 32 * 32 * 3.
    min_std = 1.0 / np.sqrt(images.size // num_images)

    # Loop per image (instead of one vectorized expression) to avoid
    # allocating dataset-sized temporaries.
    for i in range(num_images):
        image = images[i, ...]
        mean = np.mean(image)
        std = max(np.std(image), min_std)
        image[...] = (image - mean) / std
    return images
def random_flip_left_right(image, axis):
    '''
    Mirror ``image`` along ``axis`` with probability 1/2.

    One integer is drawn from {0, 1}; a 0 reverses the image along ``axis``
    (via ``np.flip``), a 1 returns the image unchanged.

    Ref: https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right
    '''
    coin = np.random.randint(low=0, high=2)
    return np.flip(image, axis) if coin == 0 else image
def random_crop_and_flip(batch_data, padding_size=2):
    '''
    Randomly crop each padded image to IMG_HEIGHT x IMG_WIDTH, then randomly
    mirror it left-to-right.

    Args:
        batch_data: batch of padded images indexed as
            (image, row, column, channel). Each image is assumed to measure
            (IMG_HEIGHT + 2 * padding_size) x (IMG_WIDTH + 2 * padding_size),
            i.e. the output of ``padding`` -- confirm for other callers.
        padding_size: number of pixels that were added on every side.

    Returns:
        float64 array of shape
        (len(batch_data), IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH).

    Ref: https://www.tensorflow.org/api_docs/python/tf/image/random_crop
    '''
    cropped_batch = np.zeros((len(batch_data), IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH))

    # Valid crop offsets are 0 .. 2 * padding_size inclusive. randint's upper
    # bound is exclusive, hence the "+ 1"; without it the last offset is
    # never sampled and padding_size=0 raises ValueError.
    offset_bound = 2 * padding_size + 1

    for i, padded_image in enumerate(batch_data):
        row_offset = np.random.randint(low=0, high=offset_bound)
        col_offset = np.random.randint(low=0, high=offset_bound)
        crop = padded_image[row_offset:row_offset + IMG_HEIGHT,
                            col_offset:col_offset + IMG_WIDTH, :]
        # axis=1 of a single (row, column, channel) image is the column axis,
        # so the flip mirrors left <-> right.
        cropped_batch[i, ...] = random_flip_left_right(image=crop, axis=1)
    return cropped_batch
def padding(data, padding_size=2):
    '''
    Zero-pad the two middle axes of a 4-D batch.

    ``padding_size`` zeros are added before and after axes 1 and 2; the
    first and last axes are left untouched. Used to prepare images for
    random cropping.

    Ref: https://www.tensorflow.org/api_docs/python/tf/image/random_crop
    '''
    untouched = (0, 0)
    both_sides = (padding_size, padding_size)
    return np.pad(
        data,
        pad_width=(untouched, both_sides, both_sides, untouched),
        mode='constant',
        constant_values=0,
    )
def load_data(dataset=10, is_tune=False):
    '''
    Load CIFAR-10 or CIFAR-100 and standardize every image.

    Args:
        dataset: 10 to load via ``cifar10.load_data()``, 100 to load via
            ``cifar100.load_data()``.
        is_tune: when true, the first 5000 training examples replace the
            test split (as a hold-out set) and are removed from the
            training split.

    Returns:
        ``(train_data, train_labels), (test_data, test_labels)``. Labels are
        passed through ``np.squeeze``; data is passed through
        ``per_image_standardization``.

    Raises:
        ValueError: if ``dataset`` is neither 10 nor 100.
    '''
    # Fail early with a clear message instead of an UnboundLocalError later.
    if dataset not in (10, 100):
        raise ValueError(
            'Unsupported dataset: %r (expected 10 or 100)' % (dataset,))

    if dataset == 10:
        (train_data, train_labels), (test_data, test_labels) = cifar10.load_data()
    else:
        (train_data, train_labels), (test_data, test_labels) = cifar100.load_data()

    if is_tune:
        # Carve the hold-out set from the head of the training split.
        num_held_out = 5000
        test_data = train_data[:num_held_out]
        test_labels = train_labels[:num_held_out]
        train_data = train_data[num_held_out:]
        train_labels = train_labels[num_held_out:]

    # (N, 1) --> (N,)
    train_labels = np.squeeze(train_labels)
    test_labels = np.squeeze(test_labels)

    # per image standardization
    test_data = per_image_standardization(test_data)
    train_data = per_image_standardization(train_data)

    print ('Load dataset: [CIFAR%d], is_tune: [%s], is_preprocessed: [%s]'%(dataset, is_tune, 'True'))
    print ('Train_data: {}, Test_data: {}'.format(train_data.shape, test_data.shape))
    return (train_data, train_labels), (test_data, test_labels)
def generate_augment_train_batch(train_data, train_labels, train_batch_size, is_tune=False):
    '''
    Draw one contiguous random batch and augment it (pad, random crop,
    random left-right flip).

    Args:
        train_data: training images; the first axis indexes the examples.
        train_labels: labels aligned with ``train_data``.
        train_batch_size: number of consecutive examples in the batch.
        is_tune: kept for backward compatibility only. The sampling range is
            taken from ``len(train_data)``, so it no longer needs to be told
            whether a hold-out split was removed.

    Returns:
        ``(batch_data, batch_label)``: the augmented images and the labels
        of the same slice.

    Raises:
        ValueError: if ``train_batch_size`` exceeds ``len(train_data)``.

    Ref: https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_input.py
    '''
    num_examples = len(train_data)
    if train_batch_size > num_examples:
        raise ValueError(
            'train_batch_size (%d) exceeds the number of training examples (%d)'
            % (train_batch_size, num_examples))

    # randint's upper bound is exclusive; "+ 1" keeps the last valid start
    # offset reachable, so the final example can be drawn as well.
    offset = np.random.randint(low=0, high=num_examples - train_batch_size + 1)
    batch_slice = slice(offset, offset + train_batch_size)

    batch_data = train_data[batch_slice, ...]
    batch_data = padding(batch_data)
    batch_data = random_crop_and_flip(batch_data, padding_size=2)
    # No standardization here: load_data in this file already standardizes
    # the data it returns.
    batch_label = train_labels[batch_slice]
    return batch_data, batch_label
# check if reshape is right
#(X, Y), (_, _) = load_data(dataset=100)
#b_X, b_y = generate_augment_train_batch(X, Y, 128)
#
# plot an image
#import matplotlib.pyplot as plt
#from scipy.misc import toimage
#plt.imshow(toimage(X[1]))