-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
118 lines (93 loc) · 3.89 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Trains the U-net based on hdf_file data and config parameters
"""
# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import argparse
import time
from keras.optimizers import Adam
from utils.configfiles import read_json_config
from utils.data_processing import get_normalized_data, get_time_info
from utils.model import get_model
# -----------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# -----------------------------------------------------------------------------
def get_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='Trains and saves the U-net.')
parser.add_argument('--training-config',
default='training.json',
type=str,
help='Name the JSON file the program uses to '
'compile the model, Default: training.json')
parser.add_argument('--save-model',
action='store_true',
default=True,
help='Choose to save the model after fitting, Default: True')
arguments = parser.parse_args()
return arguments
if __name__ == '__main__':
# -------------------------------------------------------------------------
# PRELIMINARIES
# -------------------------------------------------------------------------
print('')
print('TRAINING AND SAVING MODEL')
print('')
# Start stopwatch
script_start = time.time()
# Get command line arguments
args = get_arguments()
# Get JSON config file
config_path = f'config_files/{args.training_config}'
config = read_json_config(config_path)
# Set seeds
seed = config['random_seed']
np.random.seed(seed)
tf.random.set_seed(seed)
# -------------------------------------------------------------------------
# ACQUIRING THE TRAINING DATA
# -------------------------------------------------------------------------
hdf_file_name = config['training_hdf_file_name']
X_train, y_train = get_normalized_data(hdf_file_name)
# -------------------------------------------------------------------------
# ACQUIRING, COMPILING, AND TRAINING THE MODEL
# -------------------------------------------------------------------------
# Get time data
time_info = get_time_info(hdf_file_name, 0)
sample_length = time_info['sample_length']
target_sampling_rate = time_info['target_sampling_rate']
img_size = int(sample_length * target_sampling_rate)
# Load model
print('Loading model... ', end='', flush=True)
model = get_model(img_size=img_size)
print('Done!', flush=True)
# Compile model
learning_rate = float(config['learning_rate'])
loss = str(config['loss'])
accuracy_metrics = list(config['accuracy_metrics'])
print('Compiling model... ', end='', flush=True)
print(accuracy_metrics)
model.compile(optimizer=Adam(learning_rate=learning_rate), loss=loss, metrics=accuracy_metrics)
print('Done!', flush=True)
# Fit model
print('')
print('FITTING MODEL: ', flush=True)
print('')
validation_split = float(config['validation_split'])
epochs = int(config['epochs'])
batch_size = int(config['batch_size'])
model.fit(X_train, y_train, validation_split=validation_split, epochs=epochs, batch_size=batch_size)
print('')
# Save model
if args.save_model:
print('Saving model... ', end='', flush=True)
model_name = str(config['model_name'])
model_file_path = f"outputs/models/{model_name}"
model.save(model_file_path)
print('Done!', flush=True)
# Print total runtime
print('')
print('Total runtime: {:.1f}s'.format(time.time()-script_start))
print('')