-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlenet_mnist.py
112 lines (95 loc) · 4.44 KB
/
lenet_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# import the necessary packages
from pyimagesearch.cnn.networks.lenet import LeNet
from sklearn.model_selection import train_test_split
from keras.datasets import mnist
from keras.optimizers import SGD
from keras.utils import np_utils
from keras import backend as K
import numpy as np
import argparse
import cv2
import tensorflow as tf
# Turn on memory growth for every GPU that TensorFlow reports; when no
# GPU is present the list is empty and the loop does nothing.
gpuConfig = tf.config.experimental
physical_devices = gpuConfig.list_physical_devices('GPU')
for gpu in physical_devices:
    gpuConfig.set_memory_growth(gpu, True)
# construct the argument parser and parse the arguments
#   -s / --save-model : positive value => write the weights to --weights
#   -l / --load-model : positive value => read the weights from --weights
#   -w / --weights    : path of the weights file used by the two flags above
ap = argparse.ArgumentParser()
ap.add_argument('-s', '--save-model', type=int, default=-1,
                help='(optional) whether or not model should be saved to disk')
ap.add_argument('-l', '--load-model', type=int, default=-1,
                help='(optional) whether or not pre-trained model should be loaded')
ap.add_argument('-w', '--weights', type=str,
                help='(optional) path to weights file')
args = vars(ap.parse_args())
# saving and loading both need a weights path; reject the combination
# up front with a usage error instead of failing after training (save)
# or silently running with weights that were never loaded (load)
if (args['save_model'] > 0 or args['load_model'] > 0) and not args['weights']:
    ap.error('--weights is required when --save-model or --load-model '
             'is enabled')
# fetch the MNIST dataset (the first run of this script downloads the
# 55MB archive, which may take a minute)
print('[INFO] downloading MNIST...')
(trainData, trainLabels), (testData, testLabels) = mnist.load_data()

# give every 28 x 28 digit an explicit single-channel axis, placed
# according to the backend's image data format:
#   "channels first": num_samples x depth x rows x columns
#   "channels last" : num_samples x rows x columns x depth
if K.image_data_format() == 'channels_first':
    sampleShape = (1, 28, 28)
else:
    sampleShape = (28, 28, 1)
trainData = trainData.reshape((trainData.shape[0],) + sampleShape)
testData = testData.reshape((testData.shape[0],) + sampleShape)

# convert to float32 and divide by 255 so the data lies in [0, 1]
trainData = trainData.astype('float32') / 255.0
testData = testData.astype('float32') / 255.0

# one-hot encode the labels: each label becomes a 10-entry vector (one
# entry per MNIST class) holding `1` at the label's index and `0`
# everywhere else
numClasses = 10
trainLabels = np_utils.to_categorical(trainLabels, numClasses)
testLabels = np_utils.to_categorical(testLabels, numClasses)
# initialize the optimizer and model
print('[INFO] compiling model...')
# `learning_rate` is the supported keyword for the step size; `lr` is a
# deprecated alias
opt = SGD(learning_rate=1e-2)
# only hand a weights path to LeNet when --load-model is enabled;
# otherwise pass None
weightsPath = args['weights'] if args['load_model'] > 0 else None
model = LeNet.build(numChannels=1, imgRows=28, imgCols=28,
                    numClasses=10, weightsPath=weightsPath)
# categorical cross-entropy matches the one-hot encoded labels
model.compile(loss="categorical_crossentropy", optimizer=opt,
              metrics=['accuracy'])
# only train and evaluate the model if we *are not* loading a
# pre-existing model; weights are loaded when --load-model > 0, so the
# complementary test is `<= 0` (this way `--load-model 0` trains too
# instead of leaving the network neither loaded nor trained)
if args['load_model'] <= 0:
    print('[INFO] training...')
    model.fit(trainData, trainLabels, batch_size=128, epochs=20,
              verbose=1)
    # show the accuracy on the testing set
    print('[INFO] evaluating...')
    (loss, accuracy) = model.evaluate(testData, testLabels,
                                      batch_size=128, verbose=1)
    # the accuracy metric is a fraction in [0, 1]; scale it by 100 so
    # the printed number agrees with the trailing percent sign
    print(f'[INFO] accuracy: {accuracy * 100:.2f}%')
# check to see if the model should be saved to file
if args['save_model'] > 0:
    # --weights has no default, so it is None unless the user supplied
    # it; report that clearly rather than passing None to save_weights
    if not args['weights']:
        raise ValueError('--save-model was requested but no --weights '
                         'path was given')
    print('[INFO] dumping weights to file...')
    model.save_weights(args['weights'], overwrite=True)
# the data layout does not change while we iterate, so query it once
channelsFirst = K.image_data_format() == 'channels_first'
# randomly select a few testing digits; sample without replacement so
# the same digit is never shown twice
for i in np.random.choice(len(testLabels), size=10, replace=False):
    # classify the digit; np.newaxis adds the batch dimension that
    # predict() expects, and argmax picks the most probable class
    probs = model.predict(testData[np.newaxis, i])
    prediction = probs.argmax(axis=1)
    # undo the [0, 1] scaling to get 8-bit pixels back; with
    # 'channels_first' ordering the single channel is index 0
    if channelsFirst:
        image = (testData[i][0] * 255).astype('uint8')
    # otherwise we are using 'channels_last' ordering
    else:
        image = (testData[i] * 255).astype('uint8')
    # merge the channels into one image (three copies of the grayscale
    # digit) so the prediction can be drawn in color
    image = cv2.merge([image] * 3)
    # resize the image from a 28 x 28 image to a 96 x 96 image so we
    # can better see it
    image = cv2.resize(image, (96, 96), interpolation=cv2.INTER_LINEAR)
    # draw the predicted digit in green, report it next to the true
    # label, then show the image and wait for a key press
    cv2.putText(image, str(prediction[0]), (5, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 0), 2)
    print(f'[INFO] Predicted: {prediction[0]} , Actual: {np.argmax(testLabels[i])}')
    cv2.imshow('Digit', image)
    cv2.waitKey(0)