Skip to content

Commit 08b760e

Browse files
committed
fixed build script errors
1 parent 4ada575 commit 08b760e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2473
-27
lines changed

AAM/.ipynb_checkpoints/model-checkpoint.py

+374
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import seaborn as sns
2+
from sklearn.metrics import roc_curve, auc, precision_recall_curve
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
def _format_curve_axes():
    """Apply the tick, grid and legend styling shared by both panels.

    Operates on the current pyplot axes, so it must be called after the
    curves for that panel have been plotted (the legend is built from them).
    """
    major_ticks = np.arange(0.0, 1.1, 0.25)
    minor_ticks = np.arange(0.0, 1.1, 0.125)
    plt.xticks(major_ticks)
    plt.yticks(major_ticks)
    plt.xticks(minor_ticks, minor=True)
    plt.yticks(minor_ticks, minor=True)
    # Minor ticks exist only to carry grid lines; hide the tick marks.
    plt.tick_params(which="minor", length=0)
    plt.grid(True, linestyle="-", alpha=0.4)
    plt.grid(True, which="minor", linestyle="-", alpha=0.4)
    legend = plt.legend(
        title="Sample types",
        framealpha=1,
        facecolor="white",
        edgecolor="none",
        labelspacing=1.3,
        fontsize="medium",
    )
    # NOTE: relies on a private matplotlib attribute to left-align the title.
    legend._legend_box.align = "left"


def plot_auroc_auprc(nares_predictions, forehead_predictions, stool_predictions, inside_floor_predictions):
    """Plot ROC and precision-recall curves for four sample types.

    Draws a two-panel figure (ROC on the left, precision-recall on the right)
    with one curve per sample type and saves it to
    ``figures/auroc_auprc_aam.png``.

    Args:
        nares_predictions: Indexable whose element ``[0]`` is a
            ``(y_pred, y_true)`` pair for nares samples.
        forehead_predictions: Same structure, for forehead samples.
        stool_predictions: Same structure, for stool samples.
        inside_floor_predictions: Same structure, for inside-floor samples.
    """
    # Local import: os is not among the file-level imports visible for this module.
    import os

    sample_data = {
        "Nares": nares_predictions[0],
        "Forehead": forehead_predictions[0],
        "Stool": stool_predictions[0],
        "Inside floor": inside_floor_predictions[0],
    }

    # set up
    palette = ["#dc9766", "#d32f88", "#914f1f", "#bf64d7"]
    colors = sns.color_palette(palette)
    plt.figure(figsize=(10, 5))

    # AUROC
    plt.subplot(1, 2, 1)
    for (sample, (y_pred, y_true)), color in zip(sample_data.items(), colors):
        fpr, tpr, _ = roc_curve(y_true, y_pred)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color=color, label=f"{sample}: AUROC={roc_auc:.2f}")
    plt.xlabel("1 - Specificity")
    plt.ylabel("Sensitivity")
    _format_curve_axes()

    # AUPRC
    plt.subplot(1, 2, 2)
    for (sample, (y_pred, y_true)), color in zip(sample_data.items(), colors):
        precision, recall, _ = precision_recall_curve(y_true, y_pred)
        pr_auc = auc(recall, precision)
        plt.plot(recall, precision, color=color, label=f"{sample}: AUPRC={pr_auc:.2f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    _format_curve_axes()

    # adjust layout
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.3)

    # savefig does not create missing directories, so make sure it exists.
    output_path = 'figures/auroc_auprc_aam.png'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path)
+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from aam.models.sequence_regressor import SequenceRegressor
2+
from aam.models.sequence_regressor_v2 import SequenceRegressorV2
3+
from aam.callbacks import SaveModel
4+
from keras.callbacks import EarlyStopping
5+
from AAM.model import GeneratorDataset
6+
from AAM.model import Classifier
7+
8+
import tensorflow as tf
9+
import pandas as pd
10+
import numpy as np
11+
import seaborn as sns
12+
from sklearn.metrics import roc_curve, auc, precision_recall_curve
13+
import matplotlib.pyplot as plt
14+
from sklearn.model_selection import train_test_split, StratifiedKFold
15+
import biom
16+
from biom import Table, load_table
17+
import os
18+
import sys
19+
import warnings
20+
21+
# Let TensorFlow allocate GPU memory on demand instead of reserving it all.
# The memory-growth setting has to be identical on every visible GPU, so it is
# applied to each device rather than only the first one.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Suppress all Python warnings for the rest of the process.
warnings.filterwarnings('ignore')

# Short alias for the Keras namespace.
K = tf.keras
28+
29+
def get_sample_type(file_path):
    """Extract the sample type from a test-metadata file path.

    The base name must start with ``test_metadata_``; the sample type is what
    remains after dropping that prefix and the final file extension.

    Args:
        file_path: Path to a metadata file.

    Returns:
        The sample type string, or ``"Unknown"`` when the base name does not
        start with the expected prefix.
    """
    prefix = 'test_metadata_'
    base_name = os.path.basename(file_path)
    if not base_name.startswith(prefix):
        return "Unknown"
    stem, _extension = os.path.splitext(base_name[len(prefix):])
    return stem
37+
38+
def _collect_predictions(model, datasets):
    """Run ``model`` over every dataset and concatenate the results.

    Each ``model.predict`` result is unpacked as a 3-tuple; the first two
    items are kept as predictions and true labels, the third is discarded.

    Returns:
        A ``(y_pred, y_true)`` tuple of arrays produced by ``np.hstack``.
    """
    y_pred, y_true = [], []
    for ds in datasets:
        y_p, y_t, _ = model.predict(ds, steps=ds.steps_per_epoch)
        y_pred.append(y_p)
        y_true.append(y_t)
    return np.hstack(y_pred), np.hstack(y_true)


def test_model(test_fp, model_fp, ensemble=False):
    """Evaluate one saved model, or a five-model ensemble, on a test set.

    The test metadata is read from ``test_fp`` and wrapped in 69
    ``GeneratorDataset`` instances that differ only in ``rarefy_seed``
    (42 through 110). Predictions from all of them are concatenated.

    Args:
        test_fp: Tab-separated test metadata file. Its base name determines
            the sample type (see ``get_sample_type``), which selects the
            rarefaction depth (4000 for ``stool``, 1000 otherwise).
        model_fp: If it contains ``'.keras'`` it is loaded as a single model.
            Otherwise it is treated as a directory containing
            ``{sample_type}_{i}_model.keras`` for ``i`` in 0..4, whose
            outputs are averaged. If it contains ``'large'`` the large ASV
            embeddings are used.
        ensemble: Currently unused; the mode is chosen from ``model_fp``.

    Returns:
        ``((y_pred, y_true), auc_score)``. ``auc_score`` is always the
        placeholder value 0.
    """
    sample_type = get_sample_type(test_fp)
    test_metadata = pd.read_csv(test_fp, sep='\t', index_col=0)
    y_test = test_metadata[['study_sample_type', 'has_covid']]

    rarefy_depth = 4000 if sample_type == 'stool' else 1000

    if 'large' in model_fp:
        sequence_embeddings = 'data/input/asv_embeddings_large.npy'
    else:
        sequence_embeddings = 'data/input/asv_embeddings_aam.npy'

    # One dataset per rarefaction seed so each sample is predicted repeatedly.
    gd_test = [
        GeneratorDataset(
            table='data/input/merged_biom_table.biom',
            metadata=y_test,
            metadata_column='has_covid',
            shuffle=False,
            is_categorical=False,
            shift=0,
            rarefy_depth=rarefy_depth,
            scale=1,
            batch_size=32,
            epochs=1,
            sequence_embeddings=sequence_embeddings,
            sequence_labels='data/input/asv_embeddings_ids.npy',
            upsample=False,
            drop_remainder=False,
            gen_new_table_frequency=1,
            rarefy_seed=42 + i,
        )
        for i in range(69)
    ]

    # NOTE(review): placeholder; no AUC is computed here.
    auc_score = 0

    if '.keras' in model_fp:  # Test on One Model
        model = tf.keras.models.load_model(model_fp, compile=False)
        return _collect_predictions(model, gd_test), auc_score

    # Ensemble Method: load every fold model first, then average their outputs.
    models = [
        tf.keras.models.load_model(f'{model_fp}/{sample_type}_{i}_model.keras', compile=False)
        for i in range(5)
    ]
    per_model = [_collect_predictions(model, gd_test) for model in models]
    ensemble_y_pred = np.vstack([y_pred for y_pred, _ in per_model]).mean(axis=0)
    ensemble_y_true = np.vstack([y_true for _, y_true in per_model]).mean(axis=0)
    return (ensemble_y_pred, ensemble_y_true), auc_score
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from aam.models.sequence_regressor import SequenceRegressor
2+
from aam.models.sequence_regressor_v2 import SequenceRegressorV2
3+
from aam.callbacks import SaveModel
4+
from keras.callbacks import EarlyStopping
5+
from AAM.model import GeneratorDataset
6+
from AAM.model import Classifier
7+
8+
import tensorflow as tf
9+
import pandas as pd
10+
import numpy as np
11+
import seaborn as sns
12+
from sklearn.metrics import roc_curve, auc, precision_recall_curve
13+
import matplotlib.pyplot as plt
14+
from sklearn.model_selection import train_test_split, StratifiedKFold
15+
import biom
16+
from biom import Table, load_table
17+
import os
18+
import sys
19+
import warnings
20+
21+
# Let TensorFlow allocate GPU memory on demand instead of reserving it all.
# The memory-growth setting has to be identical on every visible GPU, so it is
# applied to each device rather than only the first one.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Suppress all Python warnings for the rest of the process.
warnings.filterwarnings('ignore')

# Short alias for the Keras namespace.
K = tf.keras
28+
29+
def get_sample_type(file_path):
    """Extract the sample type from a training-metadata file path.

    The base name must start with ``training_metadata_``; the sample type is
    what remains after dropping that prefix and the final file extension.

    Args:
        file_path: Path to a metadata file.

    Returns:
        The sample type string, or ``"Unknown"`` when the base name does not
        start with the expected prefix.
    """
    prefix = 'training_metadata_'
    base_name = os.path.basename(file_path)
    if not base_name.startswith(prefix):
        return "Unknown"
    stem, _extension = os.path.splitext(base_name[len(prefix):])
    return stem
37+
38+
#function that creates training and valid split and trains each model
39+
def train_model(train_fp, opt_type, hidden_dim, num_hidden_layers, dropout_rate, learning_rate, beta_1=None, beta_2=None, weight_decay=None, momentum=None, model_fp=None, large=True, use_cova=False):
    """Train one classifier per fold of a 5-fold stratified split.

    For every fold a model is trained with early stopping, its loss curves
    are plotted, and the model is saved under
    ``trained_models_aam/{sample_type}/``. The model from the fold with the
    lowest minimum validation loss is additionally saved as
    ``{sample_type}_best_model.keras``.

    Args:
        train_fp: Tab-separated training metadata file. Its base name
            determines the sample type (see ``get_sample_type``), which
            selects the rarefaction depth (4000 for ``stool``, 1000 otherwise).
        opt_type: ``'adam'`` selects Adam with EMA; any other value selects
            legacy SGD.
        hidden_dim: Passed to ``Classifier``.
        num_hidden_layers: Passed to ``Classifier``.
        dropout_rate: Passed to ``Classifier``.
        learning_rate: Used as ``warmup_target`` of the cosine-decay schedule.
        beta_1: Adam ``beta_1`` (only used when ``opt_type == 'adam'``).
        beta_2: Adam ``beta_2`` (only used when ``opt_type == 'adam'``).
        weight_decay: Adam ``weight_decay`` (only used when
            ``opt_type == 'adam'``).
        momentum: SGD ``momentum`` (only used for the non-adam branch).
        model_fp: Optional path of a saved model to continue from. When
            ``None`` a fresh ``Classifier`` is created for each fold.
        large: Use the large ASV embeddings (dimension 512) instead of the
            default ones (dimension 256).
        use_cova: Passed to ``Classifier``.
    """
    training_metadata = pd.read_csv(train_fp, sep='\t', index_col=0)
    y = training_metadata[['study_sample_type', 'has_covid']]
    sample_type = get_sample_type(train_fp)

    dir_path = f'trained_models_aam/{sample_type}'
    os.makedirs(dir_path, exist_ok=True)

    if large:
        sequence_embedding_fp = 'data/input/asv_embeddings_large.npy'
        sequence_embedding_dim = 512
    else:
        sequence_embedding_fp = 'data/input/asv_embeddings_aam.npy'
        sequence_embedding_dim = 256

    # Depends only on the sample type, so it is computed once for all folds.
    rarefy_depth = 4000 if sample_type == 'stool' else 1000

    # Arguments shared by the training and validation datasets.
    shared_dataset_kwargs = {
        'table': 'data/input/merged_biom_table.biom',
        'metadata_column': 'has_covid',
        'is_categorical': False,
        'shift': 0,
        'rarefy_depth': rarefy_depth,
        'scale': 1,
        'epochs': 100000,
        'batch_size': 4,
        'sequence_embeddings': sequence_embedding_fp,
        'sequence_labels': 'data/input/asv_embeddings_ids.npy',
        'upsample': False,
        'drop_remainder': False,
    }

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    curr_best_val_loss = np.inf
    curr_best_model = None
    for i, (train_index, valid_index) in enumerate(skf.split(y, y['has_covid'])):
        dataset_train = GeneratorDataset(
            metadata=y.iloc[train_index],
            shuffle=True,
            gen_new_tables=True,  # only in training dataset
            **shared_dataset_kwargs,
        )
        dataset_valid = GeneratorDataset(
            metadata=y.iloc[valid_index],
            shuffle=False,
            rarefy_seed=42,
            **shared_dataset_kwargs,
        )

        if model_fp is None:
            model = Classifier(hidden_dim=hidden_dim, num_hidden_layers=num_hidden_layers, dropout_rate=dropout_rate, use_cova=use_cova)
        else:
            model = tf.keras.models.load_model(model_fp, compile=False)
        # NOTE(review): the rendered diff lost indentation; build/summary are
        # assumed to run for both branches above - confirm against the repo.
        token_shape = tf.TensorShape([None, sequence_embedding_dim])
        batch_indicies = tf.TensorShape([None, 2])
        indicies_shape = tf.TensorShape([None])
        count_shape = tf.TensorShape([None, 1])
        model.build([token_shape, batch_indicies, indicies_shape, count_shape])
        model.summary()

        # A fresh schedule per fold; both optimizers use the same settings.
        lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=0.0,
            warmup_target=learning_rate,  # maybe change
            warmup_steps=0,
            decay_steps=250000,
        )
        if opt_type == 'adam':
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=lr_schedule,
                use_ema=True,
                beta_1=beta_1,
                beta_2=beta_2,
                weight_decay=weight_decay,
            )
            early_stop = EarlyStopping(patience=250, start_from_epoch=250, restore_best_weights=False)
        else:
            optimizer = tf.keras.optimizers.legacy.SGD(
                learning_rate=lr_schedule,
                momentum=momentum,
            )
            early_stop = EarlyStopping(patience=250, start_from_epoch=250, restore_best_weights=True)

        model.compile(optimizer=optimizer, run_eagerly=False)
        history = model.fit(
            dataset_train,
            validation_data=dataset_valid,
            validation_steps=dataset_valid.steps_per_epoch,
            epochs=10000,
            steps_per_epoch=dataset_train.steps_per_epoch,
            callbacks=[early_stop],
        )

        if opt_type == 'adam':
            # With use_ema=True this lets the optimizer write its final
            # (averaged) values into the model's trainable variables.
            model.optimizer.finalize_variable_values(model.trainable_variables)

        validation_loss = history.history['val_loss']
        train_loss = history.history['loss']
        epochs = np.arange(len(validation_loss))

        min_val_loss = np.min(validation_loss)
        if min_val_loss < curr_best_val_loss:
            curr_best_model = model
            curr_best_val_loss = min_val_loss

        plt.plot(epochs, validation_loss, color='blue')
        plt.title(f'Validation Loss Per Epoch, Best: {curr_best_val_loss} Final: {min_val_loss}')
        plt.plot(epochs, train_loss, color='red')
        plt.savefig(os.path.join(dir_path, f'{sample_type}_{i}_model_loss.png'))
        plt.close()
        model.save(os.path.join(dir_path, f'{sample_type}_{i}_model.keras'), save_format='keras')

    # NOTE(review): assumed to run once after all folds - confirm nesting.
    curr_best_model.save(os.path.join(dir_path, f'{sample_type}_best_model.keras'), save_format='keras')
    print(f"\nAAM: Best model saved for {sample_type} samples {opt_type}.")

AAM/__pycache__/model.cpython-39.pyc

13 KB
Binary file not shown.
2.03 KB
Binary file not shown.

AAM/__pycache__/test.cpython-39.pyc

3.39 KB
Binary file not shown.
4.44 KB
Binary file not shown.

AAM/training.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
K = tf.keras
2828

2929
def get_sample_type(file_path):
30-
filename = os.path.basename(file_path)
31-
# Remove the 'training_metadata_' prefix and the file extension
32-
if filename.startswith('training_metadata_'):
33-
sample_type = filename[len('training_metadata_'):]
34-
sample_type = os.path.splitext(sample_type)[0]
35-
return sample_type
36-
return "Unknown"
30+
filename = os.path.basename(file_path)
31+
# Remove the 'training_metadata_' prefix and the file extension
32+
if filename.startswith('training_metadata_'):
33+
sample_type = filename[len('training_metadata_'):]
34+
sample_type = os.path.splitext(sample_type)[0]
35+
return sample_type
36+
return "Unknown"
3737

3838
#function that creates training and valid split and trains each model
3939
def train_model(train_fp, opt_type, hidden_dim, num_hidden_layers, dropout_rate, learning_rate, beta_1=None, beta_2=None, weight_decay=None, momentum=None, model_fp=None, large=True, use_cova=False):

0 commit comments

Comments
 (0)