-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
80 lines (64 loc) · 2.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import tensorflow as tf
from utils import pload
from model import create_model
# from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
import sys
import os
import datetime
import matplotlib.pyplot as plt
from pathlib import Path
def COMPILE(optimizer='adam'):
    """Build the network from model.py and compile it for training.

    The model is compiled with sparse categorical cross-entropy loss and
    the ``accuracy`` metric.

    Args:
        optimizer: Optimizer identifier (or optimizer instance) forwarded
            to ``model.compile``. Defaults to ``'adam'``; the previous
            hard-coded value was a template placeholder string that is not
            a valid optimizer name, so compilation always failed.

    Returns:
        The compiled model object produced by ``create_model()``.
    """
    # Import model from model.py
    model = create_model()
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Build the compiled network, then fetch the training/validation arrays.
# Both results are module-level names that SAVE() and EVALUATE() read.
model = COMPILE()
X_train, y_train, X_val, y_val = pload()

# Optional callbacks.
# Uncomment the lines below (together with the tensorflow.keras.callbacks
# import near the top of the file) if you want to use callbacks, then add
#     callbacks=[tensorboard_callback, cp_callback]
# to the model.fit() call after validation_data.
# NOTE(review): `period` looks deprecated in newer tf.keras releases in
# favour of `save_freq` -- confirm against the installed version.
# log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
# checkpoint_path = "training_1/cp.ckpt"
# cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
#                                                  save_weights_only=True,
#                                                  verbose=1, period=10)

# The epoch count is read interactively from stdin; int() raises
# ValueError if the answer is not a whole number.
Epochs = input("epochs: ")
history = model.fit(
    X_train,
    y_train,
    epochs=int(Epochs),
    validation_data=(X_val, y_val),
)
def SAVE():
    """Save the module-level ``model`` when a save flag is on the command line.

    ``-h5`` writes to ``save_model/my_model.h5``; otherwise ``-s`` writes to
    ``save_model/my_model``. ``-h5`` wins when both flags are given, and
    nothing happens when neither is present. After saving, the location
    (joined onto the current working directory) is printed.
    """
    # Pick the target path from the first matching flag; bail out if none.
    if '-h5' in sys.argv:
        target = 'save_model/my_model.h5'
    elif '-s' in sys.argv:
        target = 'save_model/my_model'
    else:
        return

    # Resolve the working directory before saving so the reported location
    # reflects where the process was when the save started.
    base = Path.cwd()
    model.save(target)
    print("The model has been saved")
    print(f"At {base / target}")
def EVALUATE():
    """Summarise and plot the training run when ``-e`` is on the command line.

    With the flag present this prints the model summary, plots the
    per-epoch ``accuracy`` and ``val_accuracy`` series stored in the
    module-level ``history``, evaluates the model on the validation arrays
    and then shows the figure. Without the flag it does nothing.
    """
    if '-e' not in sys.argv:
        return

    model.summary()

    # One curve per recorded metric; the history key doubles as the label.
    for metric in ('accuracy', 'val_accuracy'):
        plt.plot(history.history[metric], label=metric)

    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0.3, 1])
    plt.legend(loc='lower right')

    # Evaluation runs on the validation split before the figure is shown;
    # the returned values are not used further here.
    val_loss, val_acc = model.evaluate(X_val, y_val, verbose=2)
    plt.show()
if __name__ == '__main__':
    # The data is already loaded by the module-level pload() call, and no
    # LOAD() function is defined or imported in this file, so the former
    # LOAD() call here only raised NameError after training had finished.
    EVALUATE()
    SAVE()