-
Notifications
You must be signed in to change notification settings - Fork 4
/
cnn_train.py
140 lines (101 loc) · 4.65 KB
/
cnn_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from keras.models import Sequential
from keras.layers import Convolution2D, Dropout, Dense, Flatten, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator, load_img
import numpy as np
from numpy import array
from keras import regularizers
import cv2
from skimage import color, exposure
from keras.optimizers import SGD
from keras.utils import plot_model
# Input geometry every image is resized to before entering the network.
# These are read by preprocess_img, get_model, and init_dataset below.
img_rows=33
img_colms=50
img_channels=1 #1 for grayscale and 3 for RGB images
def preprocess_img(img):
    """Equalize an RGB image's brightness and resize it to the network input size.

    The image is moved to HSV space, its V (brightness) channel is
    histogram-equalized, and the result is converted back to RGB and
    resized to (img_colms x img_rows).
    """
    as_hsv = color.rgb2hsv(img)
    as_hsv[:, :, 2] = exposure.equalize_hist(as_hsv[:, :, 2])
    equalized = np.array(color.hsv2rgb(as_hsv))
    # cv2.resize expects (width, height), hence (img_colms, img_rows).
    return cv2.resize(equalized, (img_colms, img_rows),
                      interpolation=cv2.INTER_AREA)
def get_model():
    """Build and compile the CNN classifier.

    Architecture: two (conv + max-pool) stages, dropout, one 500-unit
    L2-regularized hidden dense layer, and a 7-class softmax output.

    Returns:
        A compiled keras Sequential model (Adam optimizer,
        categorical cross-entropy loss, accuracy metric).
    """
    model = Sequential()
    # Two conv/pool stages. Only the first layer needs input_shape; the
    # original passed it to the second conv too, where keras ignores it.
    model.add(Convolution2D(32, 3, 3,
                            input_shape=(img_rows, img_colms, img_channels),
                            activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(32, 3, 3, activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))  # to reduce overfitting
    model.add(Flatten())
    # One hidden dense layer with L2 weight regularization.
    model.add(Dense(output_dim=500, activation='relu',
                    kernel_regularizer=regularizers.l2(0.01)))
    model.add(Dropout(0.25))  # again for regularization
    # Output layer: one unit per class, softmax probabilities.
    model.add(Dense(output_dim=7, activation='softmax'))
    # NOTE(review): the original constructed an SGD optimizer (lr=0.0001,
    # decay, nesterov momentum) here but then compiled with 'adam', so the
    # SGD object was dead code; it has been removed. Compiled behavior is
    # unchanged.
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()  # print a layer-by-layer summary
    return model
def init_dataset():
    """Create the training and test image generators.

    Both generators read grayscale images from directory trees
    ("Dataset/training_set" and "Dataset/test_set") and rescale pixel
    values to [0, 1]; the training generator additionally applies a
    small shear augmentation.

    Returns:
        (training_set, test_set): directory iterators yielding
        (images, one-hot labels) batches of size 10.
    """
    target = (img_rows, img_colms)
    augmenting_gen = ImageDataGenerator(rescale=1. / 255,
                                        shear_range=0.2,
                                        zoom_range=0.,
                                        horizontal_flip=False)
    plain_gen = ImageDataGenerator(rescale=1. / 255,
                                   horizontal_flip=False)
    training_set = augmenting_gen.flow_from_directory(
        "Dataset/training_set",
        target_size=target,
        color_mode='grayscale',
        batch_size=10,
        class_mode='categorical')
    test_set = plain_gen.flow_from_directory(
        "Dataset/test_set",
        target_size=target,
        color_mode='grayscale',
        batch_size=10,
        class_mode='categorical')
    return training_set, test_set
def train_CNN():
    """Build the model, load the dataset generators, and run training.

    Returns:
        (history, model): the keras training History object and the
        trained model.
    """
    net = get_model()
    train_gen, val_gen = init_dataset()
    # Legacy keras-1 fit_generator API: samples_per_epoch / nb_val_samples
    # count individual samples, not batches.
    history = net.fit_generator(train_gen,
                                samples_per_epoch=18676,
                                nb_epoch=10,
                                validation_data=val_gen,
                                nb_val_samples=4652)
    return history, net
def _run_training_script():
    """Interactive entry point: train the CNN, optionally save it, then
    smoke-test it on one test-set image.

    Returns:
        The softmax prediction array for the sample image.
    """
    input("Press enter to start training the model. Make sure the dataset is ready, and all files and folders are in place.")
    history, model = train_CNN()
    # accuracies observed over 10 epochs (original author's run):
    # train acc: 96.4665%
    # test acc : 88.5039%
    dec = str(input("Save model and weights, y/n?"))
    if dec.lower() == 'y':
        # Weights and architecture are saved separately: HDF5 weights plus
        # a JSON description of the model graph.
        model.save_weights("weights4.hdf5", overwrite=True)
        model_json = model.to_json()
        with open("model4.json", "w") as model_file:
            model_file.write(model_json)
        print("Model has been saved.")
    # Save a diagram of the model architecture.
    plot_model(model, to_file='model.png', show_shapes=True)
    # Sanity-check the model on one known test-set image.
    # NOTE(review): Windows-style hard-coded path — breaks on POSIX; TODO
    # switch to os.path.join once the dataset layout is confirmed.
    img = load_img('Dataset\\test_set\\e\\2.jpg', target_size=(33, 50))
    x = array(img)
    # Collapse to grayscale and add batch and channel axes: (1, H, W, 1).
    gray = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)
    gray = gray.reshape((1,) + gray.shape + (1,))
    sample_gen = ImageDataGenerator(rescale=1. / 255).flow(gray, batch_size=1)
    y_pred = model.predict_generator(sample_gen, 1)
    return y_pred


if __name__ == "__main__":
    # Guard so importing this module (e.g. to reuse get_model or
    # preprocess_img elsewhere) does not immediately start training.
    _run_training_script()