-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorF_Img.py
64 lines (53 loc) · 2.07 KB
/
TensorF_Img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# aiden | Analyzing the integration with TensorFlow: training the machine on image data from Docker.
# Third-party imports: plotting, array handling and TensorFlow.
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Short alias for the contrib "learn" API used by the rest of the script.
# NOTE(review): tf.contrib is a TensorFlow 1.x namespace -- confirm the
# installed TensorFlow version provides it before running.
learn = tf.contrib.learn
# Restrict TensorFlow logging to messages at ERROR level.
tf.logging.set_verbosity(tf.logging.ERROR)

# Import the dataset
mnist = learn.datasets.load_dataset('mnist')
data = mnist.train.images
labels = np.asarray(mnist.train.labels, dtype=np.int32)
test_data = mnist.test.images
test_labels = np.asarray(mnist.test.labels, dtype=np.int32)

# Limit the size of the training set for a faster experiment; the test
# arrays are not truncated.
max_examples = 10000
data = data[:max_examples]
labels = labels[:max_examples]
def display(i):
    """
    Display an example digit from the test set in its own figure.

    The image is read from the module-level ``test_data`` array, reshaped
    to 28x28 and drawn with the reversed gray colormap. The title shows
    the example number and its entry in ``test_labels``. Nothing appears
    on screen until ``plt.show()`` is called.

    :param i: example number (not the label)
    """
    img = test_data[i]
    # Open a fresh figure so successive calls do not draw on top of each
    # other on the same axes (only the last image would remain visible).
    plt.figure()
    plt.title('Example %d. Label: %d' % (i, test_labels[i]))
    plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r)
# Preview a few test digits: examples 0, 1 and 8 (labels 7, 2 and 5).
for example_index in (0, 1, 8):
    display(example_index)
# fit a linear classifier
feature_columns = learn.infer_real_valued_columns_from_input(data)
classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10)
classifier.fit(data, labels, batch_size=100, steps=1000)

# Evaluate the linear classifier on the test set a single time and reuse
# the returned metrics, rather than running a second, discarded evaluation.
evaluation = classifier.evaluate(test_data, test_labels)
print(evaluation["accuracy"])
# classify some examples
# The prediction printouts below are commented out, so this section only
# draws the two digits they refer to.
# According to the original note, this one will be classified correctly:
# print("Predicted %d, Label: %d" % (classifier.predict(test_data[0]), test_labels[0]))
display(0)
#
# According to the original note, this one will be classified incorrectly:
# print("Predicted %d, Label: %d" % (classifier.predict(test_data[8]), test_labels[8]))
display(8)
# visualize learned weights
# NOTE(review): the "Ftrl_1" suffix in this variable name suggests an
# optimizer slot rather than the weight variable itself -- verify the
# intended name against classifier.get_variable_names().
weights = classifier.get_variable_value("linear//weight/d/linear//weight/part_0/Ftrl_1")
f, axes = plt.subplots(2, 5, figsize=(10, 4))
axes = axes.reshape(-1)  # flatten the 2x5 grid so panels can be indexed 0..9
# One panel per row of weights.T, each drawn as a 28x28 image.
for i, a in enumerate(axes):
    a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)
    a.set_title(i)
    a.set_xticks(())  # ticks be gone
    a.set_yticks(())
plt.show()