# train.py
# Forked from foamliu/Car-Recognition.
import keras
from resnet_152 import resnet152_model
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
from keras.callbacks import ReduceLROnPlateau
# Input geometry handed to resnet152_model and used as the generator target size.
img_width = 224
img_height = 224
num_channels = 3

# Directories read by the training / validation flow_from_directory generators.
train_data = 'data/train'
valid_data = 'data/valid'

# Class count passed to resnet152_model.
num_classes = 196

# Sample counts used to derive the number of steps per epoch.
num_train_samples = 6549
num_valid_samples = 1595

# Training-loop settings.
verbose = 1
batch_size = 8
num_epochs = 100
# Early-stopping patience in epochs; the LR scheduler uses a quarter of it.
patience = 50
def compute_steps(num_samples, batch_size, subsample=1):
    """Return the whole number of batches to run for one (sub-sampled) epoch.

    Args:
        num_samples: Total number of samples available.
        batch_size: Number of samples per batch.
        subsample: Divisor applied to the epoch length; ``10`` means each
            epoch covers roughly one tenth of the data.

    Returns:
        ``num_samples / batch_size / subsample`` rounded up to an ``int``
        (Keras documents ``steps_per_epoch`` / ``validation_steps`` as
        integers, so a float must not be passed).
    """
    batches_divisor = batch_size * subsample
    # Integer ceiling division: never returns 0 for a non-empty dataset.
    return (num_samples + batches_divisor - 1) // batches_divisor


if __name__ == '__main__':
    import os

    # Each epoch only visits 1/epoch_subsample of the data, so validation,
    # checkpointing and LR scheduling happen more frequently.
    epoch_subsample = 10

    # build a classifier model
    model = resnet152_model(img_height, img_width, num_channels, num_classes)

    # prepare data augmentation configuration
    train_data_gen = ImageDataGenerator(
        rotation_range=20.,
        width_shift_range=0.1,
        height_shift_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
    )
    # Validation images are fed through without augmentation.
    valid_data_gen = ImageDataGenerator()

    # callbacks
    log_file_path = 'logs/training.log'
    trained_models_path = 'models/model'
    # Make sure the output directories exist before any callback writes to them.
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    os.makedirs(os.path.dirname(trained_models_path), exist_ok=True)

    tensor_board = keras.callbacks.TensorBoard(
        log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
    csv_logger = CSVLogger(log_file_path, append=False)
    early_stop = EarlyStopping('val_accuracy', patience=patience)
    reduce_lr = ReduceLROnPlateau(
        'val_accuracy', factor=0.1, patience=int(patience / 4), verbose=1)
    model_names = trained_models_path + '.{epoch:02d}-{val_accuracy:.2f}.hdf5'
    model_checkpoint = ModelCheckpoint(
        model_names, monitor='val_accuracy', verbose=1, save_best_only=True)
    callbacks = [tensor_board, model_checkpoint, csv_logger, early_stop, reduce_lr]

    # generators
    # target_size is (height, width) per the Keras API.
    target_size = (img_height, img_width)
    train_generator = train_data_gen.flow_from_directory(
        train_data, target_size, batch_size=batch_size, class_mode='categorical')
    valid_generator = valid_data_gen.flow_from_directory(
        valid_data, target_size, batch_size=batch_size, class_mode='categorical')

    # fine tune the model
    model.fit(
        train_generator,
        steps_per_epoch=compute_steps(num_train_samples, batch_size, epoch_subsample),
        validation_data=valid_generator,
        validation_steps=compute_steps(num_valid_samples, batch_size, epoch_subsample),
        epochs=num_epochs,
        callbacks=callbacks,
        verbose=verbose)