1
+ from aam .models .sequence_regressor import SequenceRegressor
2
+ from aam .models .sequence_regressor_v2 import SequenceRegressorV2
3
+ from aam .callbacks import SaveModel
4
+ from keras .callbacks import EarlyStopping
5
+ from AAM .model import GeneratorDataset
6
+ from AAM .model import Classifier
7
+
8
+ import tensorflow as tf
9
+ import pandas as pd
10
+ import numpy as np
11
+ import seaborn as sns
12
+ from sklearn .metrics import roc_curve , auc , precision_recall_curve
13
+ import matplotlib .pyplot as plt
14
+ from sklearn .model_selection import train_test_split , StratifiedKFold
15
+ import biom
16
+ from biom import Table , load_table
17
+ import os
18
+ import sys
19
+ import warnings
20
+
21
+ gpus = tf .config .list_physical_devices ("GPU" )
22
+ if len (gpus ) > 0 :
23
+ tf .config .experimental .set_memory_growth (gpus [0 ], True )
24
+
25
+ warnings .filterwarnings ('ignore' )
26
+
27
+ K = tf .keras
28
+
29
+ def get_sample_type (file_path ):
30
+ filename = os .path .basename (file_path )
31
+ # Remove the 'training_metadata_' prefix and the file extension
32
+ if filename .startswith ('training_metadata_' ):
33
+ sample_type = filename [len ('training_metadata_' ):]
34
+ sample_type = os .path .splitext (sample_type )[0 ]
35
+ return sample_type
36
+ return "Unknown"
37
+
38
+ #function that creates training and valid split and trains each model
39
+ def train_model (train_fp , opt_type , hidden_dim , num_hidden_layers , dropout_rate , learning_rate , beta_1 = None , beta_2 = None , weight_decay = None , momentum = None , model_fp = None , large = True , use_cova = False ):
40
+ training_metadata = pd .read_csv (train_fp , sep = '\t ' , index_col = 0 )
41
+ X = training_metadata .drop (columns = ['study_sample_type' , 'has_covid' ], axis = 1 )
42
+ y = training_metadata [['study_sample_type' , 'has_covid' ]]
43
+ sample_type = get_sample_type (train_fp )
44
+ dir_path = f'trained_models_aam/{ sample_type } '
45
+ if not os .path .exists (dir_path ):
46
+ os .makedirs (dir_path )
47
+ if not large :
48
+ sequence_embedding_fp = 'data/input/asv_embeddings_aam.npy'
49
+ sequence_embedding_dim = 256
50
+ else :
51
+ sequence_embedding_fp = 'data/input/asv_embeddings_large.npy'
52
+ sequence_embedding_dim = 512
53
+
54
+ skf = StratifiedKFold (n_splits = 5 , shuffle = True , random_state = 42 )
55
+
56
+ curr_best_val_loss = np .inf
57
+ curr_best_model = None
58
+ for i , (train_index , valid_index ) in enumerate (skf .split (y , y ['has_covid' ])):
59
+ y_train = y .iloc [train_index ]
60
+ y_valid = y .iloc [valid_index ]
61
+
62
+ if sample_type == 'stool' :
63
+ rarefy_depth = 4000
64
+ else :
65
+ rarefy_depth = 1000
66
+ dataset_train = GeneratorDataset (
67
+ table = 'data/input/merged_biom_table.biom' ,
68
+ metadata = y_train ,
69
+ metadata_column = 'has_covid' ,
70
+ shuffle = True ,
71
+ is_categorical = False ,
72
+ shift = 0 ,
73
+ rarefy_depth = rarefy_depth ,
74
+ scale = 1 ,
75
+ epochs = 100000 ,
76
+ batch_size = 4 ,
77
+ gen_new_tables = True , #only in training dataset
78
+ sequence_embeddings = sequence_embedding_fp ,
79
+ sequence_labels = 'data/input/asv_embeddings_ids.npy' ,
80
+ upsample = False ,
81
+ drop_remainder = False
82
+ )
83
+
84
+ dataset_valid = GeneratorDataset (
85
+ table = 'data/input/merged_biom_table.biom' ,
86
+ metadata = y_valid ,
87
+ metadata_column = 'has_covid' ,
88
+ shuffle = False ,
89
+ is_categorical = False ,
90
+ shift = 0 ,
91
+ rarefy_depth = rarefy_depth ,
92
+ scale = 1 ,
93
+ epochs = 100000 ,
94
+ batch_size = 4 ,
95
+ sequence_embeddings = sequence_embedding_fp ,
96
+ sequence_labels = 'data/input/asv_embeddings_ids.npy' ,
97
+ upsample = False ,
98
+ drop_remainder = False ,
99
+ rarefy_seed = 42
100
+ )
101
+
102
+
103
+ if model_fp is None :
104
+ model = Classifier (hidden_dim = hidden_dim , num_hidden_layers = num_hidden_layers , dropout_rate = dropout_rate , use_cova = use_cova )
105
+ else :
106
+ model = tf .keras .models .load_model (model_fp , compile = False )
107
+ token_shape = tf .TensorShape ([None , sequence_embedding_dim ])
108
+ batch_indicies = tf .TensorShape ([None , 2 ])
109
+ indicies_shape = tf .TensorShape ([None ])
110
+ count_shape = tf .TensorShape ([None , 1 ])
111
+ model .build ([token_shape , batch_indicies , indicies_shape , count_shape ])
112
+ model .summary ()
113
+ if opt_type == 'adam' :
114
+ optimizer = tf .keras .optimizers .Adam (
115
+ learning_rate = tf .keras .optimizers .schedules .CosineDecay (
116
+ initial_learning_rate = 0.0 ,
117
+ warmup_target = learning_rate , # maybe change
118
+ warmup_steps = 0 ,
119
+ decay_steps = 250000 ,
120
+ ),
121
+ use_ema = True ,
122
+ beta_1 = beta_1 ,
123
+ beta_2 = beta_2 ,
124
+ weight_decay = weight_decay
125
+ )
126
+ early_stop = EarlyStopping (patience = 250 , start_from_epoch = 250 , restore_best_weights = False )
127
+ else :
128
+ optimizer = tf .keras .optimizers .legacy .SGD (
129
+ learning_rate = tf .keras .optimizers .schedules .CosineDecay (
130
+ initial_learning_rate = 0.0 ,
131
+ warmup_target = learning_rate , # maybe change
132
+ warmup_steps = 0 ,
133
+ decay_steps = 250000 ,
134
+ ),
135
+ momentum = momentum
136
+ )
137
+ early_stop = EarlyStopping (patience = 250 , start_from_epoch = 250 , restore_best_weights = True )
138
+
139
+ model .compile (optimizer = optimizer , run_eagerly = False )
140
+ #switch loss to val loss
141
+ #pass early stopping for callbacks
142
+ history = model .fit (dataset_train ,
143
+ validation_data = dataset_valid ,
144
+ validation_steps = dataset_valid .steps_per_epoch ,
145
+ epochs = 10000 ,
146
+ steps_per_epoch = dataset_train .steps_per_epoch ,
147
+ callbacks = [
148
+ early_stop
149
+ ])
150
+
151
+ if opt_type == 'adam' :
152
+ model .optimizer .finalize_variable_values (model .trainable_variables )
153
+
154
+ validation_loss = history .history ['val_loss' ]
155
+ train_loss = history .history ['loss' ]
156
+ epochs = np .array (range (len (validation_loss )))
157
+
158
+ min_val_loss = np .min (history .history ['val_loss' ])
159
+ if min_val_loss < curr_best_val_loss :
160
+ curr_best_model = model
161
+ curr_best_val_loss = min_val_loss
162
+
163
+ plt .plot (epochs , validation_loss , color = 'blue' )
164
+ plt .title (f'Validation Loss Per Epoch, Best: { curr_best_val_loss } Final: { min_val_loss } ' )
165
+ plt .plot (epochs , train_loss , color = 'red' )
166
+ plt .savefig (os .path .join (dir_path , f'{ sample_type } _{ i } _model_loss.png' ))
167
+ plt .close ()
168
+ model .save (os .path .join (dir_path , f'{ sample_type } _{ i } _model.keras' ), save_format = 'keras' )
169
+ curr_best_model .save (os .path .join (dir_path , f'{ sample_type } _best_model.keras' ), save_format = 'keras' )
170
+ print (f"\n AAM: Best model saved for { sample_type } samples { opt_type } ." )
0 commit comments