-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathJIP.py
executable file
·335 lines (316 loc) · 18.9 KB
/
JIP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import os
import time
import argparse
from mp.paths import JIP_dir, telegram_login
from mp.utils.update_bots.telegram_bot import TelegramBot
from train_restore_use_models.preprocess_data import preprocess_data
from train_restore_use_models.CNN_train_restore_use import train_model
from train_restore_use_models.CNN_train_restore_use import restore_train_model
from train_restore_use_models.CNN_train_restore_use import retrain_model
from train_restore_use_models.CNN_train_restore_use import test_model
from train_restore_use_models.CNN_train_restore_use import do_inference
# Structure of JIP_dir/data_dirs:
# /
# |---WORKFLOW_DIR --> JIP_dir/data_dirs for inference or JIP_dir/train_dirs for training or JIP_dir/preprocessed_dirs for preprocessed data
#     |---OPERATOR_IN_DIR
#     |   |---0001
#     |   |   |---img
#     |   |   |   |---img.nii
#     |   |   |---seg
#     |   |       |---001.nii
#     |   |       |---002.nii
#     |   |       |---...
#     |   |---0002
#     |   |   |---img
#     |   |   |   |---img.nii
#     |   |   |---seg
#     |   |       |---001.nii
#     |   |       |---002.nii
#     |   |       |---...
#     |   |---...
#     |---OPERATOR_OUT_DIR
#     |---OPERATOR_TEMP_DIR
if __name__ == "__main__":
    # -------------------------
    # Parse command line args
    # -------------------------
    parser = argparse.ArgumentParser(description='Train, retrain or use a specified model to predict the quality of CT scans.')
    parser.add_argument('--noise_type', choices=['blur', 'resolution', 'ghosting', 'noise',
                                                 'motion', 'spike'], required=False,
                        help='Specify the CT artifact on which the model will be trained. '+
                             'Default model type: blur.')
    parser.add_argument('--mode', choices=['preprocess', 'train', 'retrain', 'testID', 'testOOD', 'testIOOD', 'inference'], required=True,
                        help='Specify in which mode to use the model. Either train a model or use'+
                             ' it for predictions. This can also be used to preprocess data (be)for(e) training.')
    parser.add_argument('--datatype', choices=['all', 'train', 'test', 'inference'], required=False,
                        help='Only necessary for mode preprocessing. Indicates which data should be preprocessed.'+
                             ' If not specified, \'all\' will be used for preprocessing.')
    parser.add_argument('--device', action='store', type=int, nargs=1, default=4,
                        help='Try to train the model on the GPU device with <DEVICE> ID.'+
                             ' Valid IDs: 0, 1, ..., 7'+
                             ' Default: GPU device with ID 4 will be used.')
    parser.add_argument('--restore', action='store_const', const=True, default=False,
                        help='Restore last saved model state and continue training from there.'+
                             ' Restore is also used in the augmentation step of the preprocessing of the train_ds only,'+
                             ' if the augmentation fails at some point. --> This flag is ambiguous!'
                             ' Default value: False.')
    parser.add_argument('--use_telegram_bot', action='store_const', const=True, default=False,
                        help='Send message during training through a Telegram Bot'+
                             ' (Token and Chat-ID need to be provided, otherwise an error occurs!).'+
                             ' Default: No Telegram Bot will be used to send messages.')
    parser.add_argument('--store_data', action='store_const', const=True, default=False,
                        help='Store the actual datapoints and save them as .npy after training.'+
                             ' Takes a lot of time and CPU power (when using GPU).'+
                             ' Default: Only necessary data will be stored --> False.')
    parser.add_argument('--try_catch_repeat', action='store', type=int, nargs=1, default=0,
                        help='Try to train the model with a restored state, if an error occurs.'+
                             ' Repeat only <TRY_CATCH_REPEAT> number of times. It can also be used for'+
                             ' preprocessing step, since the augmentation might fail at some point, for'+
                             ' training on a model and retraining a model.'+
                             ' Default value: 0.')
    parser.add_argument('--idle_time', action='store', type=int, nargs=1, default=0,
                        help='The number represents the idle time (waiting time) in seconds after an error'+
                             ' occurs before starting the process again. --> Works only in combination'+
                             ' with --try_catch_repeat.'+
                             ' Default: Idle time is 0 seconds.')

    # Extract the parsed values into plain locals used by the rest of the script.
    args = parser.parse_args()
    noise = args.noise_type
    mode = args.mode
    data_type = args.datatype
    cuda = args.device
    restore = args.restore
    store_data = args.store_data
    msg_bot = args.use_telegram_bot
    try_catch = args.try_catch_repeat
    idle_time = args.idle_time
    # Options declared with nargs=1 come back as one-element lists when passed
    # on the command line (but as the plain default otherwise) -- unwrap them.
    if isinstance(cuda, list):
        cuda = cuda[0]
    if isinstance(try_catch, list):
        try_catch = try_catch[0]
    if isinstance(idle_time, list):
        idle_time = idle_time[0]
    # Mode dependent defaults.
    if mode == 'preprocess' and data_type is None:
        data_type = 'all'
    if mode != 'preprocess' and noise is None:
        noise = 'blur'

    # Define Telegram Bot (only when requested, so no credentials are needed otherwise).
    if msg_bot:
        bot = TelegramBot(telegram_login)

    # Validate the GPU ID before building the device string.
    if cuda < 0 or cuda > 7:
        # BUGFIX: only notify via the bot when it was actually created above;
        # previously this raised a NameError when --use_telegram_bot was not
        # set. Raise ValueError instead of `assert False` so the validation
        # also survives `python -O`.
        if msg_bot:
            bot.send_msg('GPU device ID out of range (0, ..., 7).')
        raise ValueError('GPU device ID out of range (0, ..., 7).')
    cuda = 'cuda:' + str(cuda)
# -------------------------
# Build environmental vars
# -------------------------
print('Building environmental variables..')
# The environmental vars will later be automatically set by the workflow that triggers the docker container
# data_dirs (for inference)
os.environ["WORKFLOW_DIR"] = os.path.join(JIP_dir, 'data_dirs')
os.environ["OPERATOR_IN_DIR"] = "input"
os.environ["OPERATOR_OUT_DIR"] = "output"
os.environ["OPERATOR_TEMP_DIR"] = "temp"
os.environ["OPERATOR_PERSISTENT_DIR"] = os.path.join(JIP_dir, 'data_dirs', 'persistent') # pre-trained models
# preprocessed_dirs (for preprocessed data (output of this workflow = input for main workflow)
os.environ["PREPROCESSED_WORKFLOW_DIR"] = os.path.join(JIP_dir, 'preprocessed_dirs')
os.environ["PREPROCESSED_OPERATOR_OUT_TRAIN_DIR"] = "output_train"
os.environ["PREPROCESSED_OPERATOR_OUT_TEST_DIR"] = "output_test"
os.environ["PREPROCESSED_OPERATOR_OUT_DATA_DIR"] = "output_data"
# train_dirs (for training data)
os.environ["TRAIN_WORKFLOW_DIR"] = os.path.join(JIP_dir, 'train_dirs')
# test_dirs (for test data)
os.environ["TEST_WORKFLOW_DIR"] = os.path.join(JIP_dir, 'test_dirs')
# Make directories to avoid any errors
paths = [os.path.join(os.environ["WORKFLOW_DIR"], os.environ["OPERATOR_IN_DIR"]),
os.path.join(os.environ["WORKFLOW_DIR"], os.environ["OPERATOR_OUT_DIR"]),
os.path.join(os.environ["WORKFLOW_DIR"], os.environ["OPERATOR_TEMP_DIR"]),
os.path.join(os.environ["WORKFLOW_DIR"], os.environ["OPERATOR_PERSISTENT_DIR"]),
os.path.join(os.environ["PREPROCESSED_WORKFLOW_DIR"], os.environ["PREPROCESSED_OPERATOR_OUT_TRAIN_DIR"]),
os.path.join(os.environ["PREPROCESSED_WORKFLOW_DIR"], os.environ["PREPROCESSED_OPERATOR_OUT_TEST_DIR"]),
os.path.join(os.environ["PREPROCESSED_WORKFLOW_DIR"], os.environ["PREPROCESSED_OPERATOR_OUT_DATA_DIR"]),
os.path.join(os.environ["TRAIN_WORKFLOW_DIR"], os.environ["OPERATOR_IN_DIR"]),
os.path.join(os.environ["TRAIN_WORKFLOW_DIR"], os.environ["OPERATOR_OUT_DIR"]),
os.path.join(os.environ["TEST_WORKFLOW_DIR"], os.environ["OPERATOR_IN_DIR"]),
os.path.join(os.environ["TEST_WORKFLOW_DIR"], os.environ["OPERATOR_OUT_DIR"])]
for path in paths:
if not os.path.exists(path):
os.makedirs(path)
# ------------------------
# Build config dictionary
# ------------------------
# NOTE: Decathlon Dataset will be nr_images x 5 scans big.
# Grand Challenge Dataset will be nr_images x 5 scans big (or less if less scans are available).
# UK Frankfurt Dataset will be nr_images x 5 scans big (or less if less scans are available).
# Mixed Dataset will be nr_images x 15 scans big (or less if less scans are available).
# NOTE: num_intensities embodies the number of quality values 1 to 5, where 1 is a bad quality
# and 5 is a the best quality. This will be transformed into values between 0 and 1 in
# the inference step, whereas 0s is bad quality and 1 is the best quality.
# NOTE: train_on indicates on which dataset to train: {'Decathlon', 'GC', 'mixed'}.
# For institutes that use this method to retrain a pre-trained model with their own dataset can write
# every name they want, since the whole data from preprocessed_dirs/output_train
# will be loaded and used. Thus, this variable will not be considered in this process.
# Its important to note, that the labels should be also provided at preprocessed_dirs/output_train/labels.json
# and that the data in preprocessed_dirs/output_train needs to be preprocessed while considering the defined data structure.
# NOTE: For learning rate decay, set the variables 'lr_decay', 'decay_type' and 'decay_rate'. decay_type can be
# {'exp_decay', 'step_decay', 'mstep_decay', 'plat_decay'}, whereas decay_rate represents gamma in all of these decays except
# plat_decay. --> lr will normally always updated with lr = lr * decay_rate, so consider this when setting decay_rate.
config = {'device': cuda, 'input_shape': (1, 60, 299, 299), 'augmentation': True, 'mode': mode,
'data_type': data_type, 'lr': 1e-3, 'batch_size': 64, 'num_intensities': 5, 'nr_epochs': 250, 'decay_type': 'plat_decay',
'noise': noise, 'weight_decay': 7e-3, 'save_interval': 100, 'msg_bot': msg_bot, 'lr_decay': True, 'decay_rate': 0.9,
'bot_msg_interval': 10, 'nr_images': 25, 'val_ratio': 0.2, 'test_ratio': 0.2, 'augment_strat': 'none',
'train_on': 'mixed', 'data_augmented': True, 'restore': restore, 'store_data': store_data}
# -------------------------
# Preprocess
# -------------------------
if mode == 'preprocess':
dir_name = os.path.join(os.environ["PREPROCESSED_WORKFLOW_DIR"], os.environ["PREPROCESSED_OPERATOR_OUT_TRAIN_DIR"])
augmented_on = os.path.join(dir_name, 'augmented_on.pkl')
if try_catch == 0:
try_catch = 1
for i in range(try_catch):
if i != 0:
time.sleep(idle_time) # Delays for idle_time seconds.
if not restore:
if msg_bot:
bot.send_msg('Start to preprocess data..')
preprocessed, error = preprocess_data(config)
if preprocessed and msg_bot:
bot.send_msg('Finished preprocessing..')
if not preprocessed:
print('Data could not be preprocessed. The following error occured: {}.'.format(error))
if msg_bot:
bot.send_msg('Data could not be preprocessed. The following error occured: {}.'.format(error))
# Check if the directory is empty and whether the augmented_on pickle file does not exist. If so, restore = False, otherwise True.
if (os.path.exists(dir_name) and os.path.isdir(dir_name)):
if len(os.listdir(dir_name)) <= 1 or not os.path.exists(augmented_on):
# Directory contains nothing
restore = False
config['restore'] = False
else:
# Directory is not empty
restore = True
config['restore'] = True
else:
if msg_bot:
bot.send_msg('Start to continue with preprocessing..')
preprocessed, error = preprocess_data(config)
if preprocessed and msg_bot:
bot.send_msg('Finished preprocessing..')
break
if not preprocessed:
print('Data could not be preprocessed. The following error occured: {}.'.format(error))
if msg_bot:
bot.send_msg('Data could not be preprocessed. The following error occured: {}.'.format(error))
# -------------------------
# Train
# -------------------------
if mode == 'train':
dir_name = os.path.join(os.environ["TRAIN_WORKFLOW_DIR"], os.environ["OPERATOR_OUT_DIR"], noise, 'states')
if try_catch == 0:
try_catch = 1
for i in range(try_catch):
if i != 0:
time.sleep(idle_time) # Delays for idle_time seconds.
if not restore:
if msg_bot:
bot.send_msg('Start to initialize and train the model for noise type {}..'.format(noise))
trained, error = train_model(config)
if trained and msg_bot:
bot.send_msg('Finished training for noise type {}..'.format(noise))
break
if not trained:
print('Model for noise type {} could not be initialized/trained. The following error occured: {}.'.format(noise, error))
if msg_bot:
bot.send_msg('Model for noise type {} could not be trained. The following error occured: {}.'.format(noise, error))
# Only restore, if a model state has already been saved, otherwise Index Error
# occurs while trying to extract the highest saved state for restoring a state.
# Check if the directory is empty. If so, restore = False, otherwise True.
if os.path.exists(dir_name) and os.path.isdir(dir_name):
if len(os.listdir(dir_name)) <= 1:
# Directory only contains json splitting file but no model state!
restore = False
config['restore'] = False
else:
# Directory is not empty
restore = True
config['restore'] = True
else:
if msg_bot:
bot.send_msg('Start to restore the model for noise type {} and continue training..'.format(noise))
trained, error = restore_train_model(config)
if trained and msg_bot:
bot.send_msg('Finished training for noise type {}..'.format(noise))
break
if not trained:
print('Model for noise type {} could not be restored/trained. The following error occured: {}.'.format(noise, error))
if msg_bot:
bot.send_msg('Model for noise type {} could not be restored/trained. The following error occured: {}.'.format(noise, error))
# -------------------------
# Retrain
# -------------------------
if mode == 'retrain':
dir_name = os.path.join(os.environ["TRAIN_WORKFLOW_DIR"], os.environ["OPERATOR_OUT_DIR"], noise, 'states')
if try_catch == 0:
try_catch = 1
for i in range(try_catch):
if i != 0:
time.sleep(idle_time) # Delays for idle_time seconds.
if not restore:
if msg_bot:
bot.send_msg('Start to retrain the model for noise type {}..'.format(noise))
trained, error = retrain_model(config)
if trained and msg_bot:
bot.send_msg('Finished retraining for noise type {}..'.format(noise))
break
if not trained:
print('Model for noise type {} could not be retrained. The following error occured: {}.'.format(noise, error))
if msg_bot:
bot.send_msg('Model for noise type {} could not be retrained. The following error occured: {}.'.format(noise, error))
# Only restore, if a model state has already been saved, otherwise Index Error
# occurs while trying to extract the highest saved state for restoring a state.
# Check if the directory is empty. If so, restore = False, otherwise True.
if os.path.exists(dir_name) and os.path.isdir(dir_name):
if len(os.listdir(dir_name)) <= 1:
# Directory only contains json splitting file but no model state!
restore = False
config['restore'] = False
else:
# Directory is not empty
restore = True
config['restore'] = True
else:
if msg_bot:
bot.send_msg('Start to restore the model for noise type {} and continue training..'.format(noise))
trained, error = restore_train_model(config)
if trained and msg_bot:
bot.send_msg('Finished retraining for noise type {}..'.format(noise))
break
if not trained:
print('Model for noise type {} could not be restored/retrained. The following error occured: {}.'.format(noise, error))
if msg_bot:
bot.send_msg('Model for noise type {} could not be restored/retrained. The following error occured: {}.'.format(noise, error))
# -------------------------
# Test
# -------------------------
if 'test' in mode:
if msg_bot:
bot.send_msg('Start testing the pre-trained model..')
tested, error = test_model(config)
if tested and msg_bot:
bot.send_msg('Finished testing..')
if not tested:
print('Testing could not be performed. The following error occured: {}.'.format(error))
if msg_bot:
bot.send_msg('Testing could not be performed. The following error occured: {}.'.format(error))
# -------------------------
# Inference
# -------------------------
if mode == 'inference':
if msg_bot:
bot.send_msg('Start the inference..')
inferred, error = do_inference(config)
if inferred and msg_bot:
bot.send_msg('Finished inference..')
if not inferred:
print('Inference could not be performed. The following error occured: {}.'.format(error))
if msg_bot:
bot.send_msg('Inference could not be performed. The following error occured: {}.'.format(error))