-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain1.py
108 lines (75 loc) · 3.15 KB
/
main1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#THIS FILE CREATES FEATURE AND LABEL FILES.
import os
import re
import tensorflow as tf
import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
import pandas as pd
import sklearn
from sklearn import cross_validation
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.svm import SVC, LinearSVC
import matplotlib.pyplot as plt
import pickle
model_dir = 'imagenet'
images_dir = 'images/'
list_images = [images_dir+f for f in os.listdir(images_dir) if re.search('jpeg|JPEG|jpg|JPG', f)]
print(list_images)
# print(len(list_images))
#To use TensorFlow, you should define a graph that represents the description of computations. Then these computations will be executed within what is called sessions
def create_graph():
with gfile.FastGFile(os.path.join(
model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
def extract_features(list_images):
nb_features = 2048
features = np.empty((len(list_images), nb_features))
labels = []
create_graph()
with tf.Session() as sess:
next_to_last_tensor = sess.graph.get_tensor_by_name('pool_3:0')
for ind, image in enumerate(list_images):
# print(ind)
# if (ind % 100 == 0):
# print('Processing %s...' % (image))
image_data = gfile.FastGFile(image, 'rb').read()
predictions = sess.run(next_to_last_tensor,
{'DecodeJpeg/contents:0': image_data})
features[ind, :] = np.squeeze(predictions)
labels.append(re.split('_\d+', image.split('/')[1])[0])
return features, labels
features,labels = extract_features(list_images)
pickle.dump(features, open('features', 'wb'))
pickle.dump(labels, open('labels', 'wb'))
features = pickle.load(open('features'))
labels = pickle.load(open('labels'))
X_train, X_test, y_train, y_test = cross_validation.train_test_split(features, labels, test_size=0.2, random_state=42)
# print(X_train, y_train)
clf = LinearSVC(C=1.0, loss='l2', penalty='l2',multi_class='ovr')
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
def plot_confusion_matrix(y_true,y_pred):
cm_array = confusion_matrix(y_true,y_pred)
true_labels = np.unique(y_true)
pred_labels = np.unique(y_pred)
plt.imshow(cm_array[:-1,:-1], interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion matrix", fontsize=16)
cbar = plt.colorbar(fraction=0.046, pad=0.04)
cbar.set_label('Number of images', rotation=270, labelpad=30, fontsize=12)
xtick_marks = np.arange(len(true_labels))
ytick_marks = np.arange(len(pred_labels))
plt.xticks(xtick_marks, true_labels, rotation=90)
plt.yticks(ytick_marks,pred_labels)
plt.tight_layout()
plt.ylabel('True label', fontsize=14)
plt.xlabel('Predicted label', fontsize=14)
fig_size = plt.rcParams["figure.figsize"]
fig_size[0] = 12
fig_size[1] = 12
plt.rcParams["figure.figsize"] = fig_size
plt.show()
# print("Accuracy: {0:0.1f}%\n".format(accuracy_score(y_test,y_pred)*100))
# plot_confusion_matrix(y_test,y_pred)