-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
359 lines (317 loc) · 15.5 KB
/
eval.py
File metadata and controls
359 lines (317 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import argparse
import warnings
import torch
import os
import sys
import yaml
import datetime
import logging
import numpy as np
import pandas as pd
import transformers
from tqdm import tqdm
from torch.utils.data import DataLoader
from typing import *
# from transformers import get_inverse_sqrt_schedule
from tqdm import tqdm
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor, as_completed
from accelerate.utils import set_seed
from accelerate import Accelerator
from torchmetrics.classification import Accuracy, Recall, Precision, MatthewsCorrCoef, AUROC
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryMatthewsCorrCoef
from src.models import ProtssnClassification, PLM_model, GNN_model
from src.utils.data_utils import BatchSampler
from src.utils.utils import param_num, total_param_num
from src.dataset.supervise_dataset import SuperviseDataset
from src.utils.dataset_utils import NormalizeProtein
# set path: make repo-local imports (src.*) resolvable when running from the repo root
current_dir = os.getcwd()
sys.path.append(current_dir)
# ignore warning information (transformers is chatty during model loading)
transformers.logging.set_verbosity_error()
warnings.filterwarnings("ignore")
# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def printlog(info):
    """Print *info* to stdout underneath a timestamped banner line."""
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    banner = "==========" * 3
    print("\n" + banner + "%s" % stamp + banner)
    print(str(info) + "\n")
class StepRunner:
    """Run a single forward pass of the classification model over one batch.

    Computes the loss, updates the torchmetrics objects, and returns the
    per-sample predicted labels plus the pooled SSN embeddings.

    NOTE: ``step`` reads the module-level globals ``plm_model`` and
    ``gnn_model`` (frozen backbones) that the ``__main__`` section must
    create before any step executes.
    """

    def __init__(self, args, model,
                 loss_fn, accelerator=None,
                 metrics_dict=None,
                 ):
        self.model = model
        self.metrics_dict = metrics_dict  # name -> torchmetrics metric
        self.accelerator = accelerator
        self.loss_fn = loss_fn
        self.args = args

    def step(self, batch):
        """Forward one batch; return (loss, model, metrics_dict, preds, embeds)."""
        # The trailing ``True`` asks the model to also return SSN embeddings.
        logits, ssn_embeds = self.model(plm_model, gnn_model, batch, True)
        # NOTE(fix): the original forced ``logits = logits.cuda()``, which
        # crashes on CPU-only hosts even though the script picks the device
        # conditionally. Labels are simply moved to wherever logits live.
        label = torch.cat([data.label for data in batch]).to(logits.device)
        pred_classes = torch.argmax(logits, 1)
        pred_labels = pred_classes.cpu().numpy()
        loss = self.loss_fn(logits, label)
        # compute metrics (argmax reused instead of recomputing per metric)
        for name, metric_fn in self.metrics_dict.items():
            metric_fn.update(pred_classes, label)
        return loss.item(), self.model, self.metrics_dict, pred_labels, ssn_embeds

    def train_step(self, batch):
        self.model.train()
        return self.step(batch)

    @torch.no_grad()
    def eval_step(self, batch):
        self.model.eval()
        return self.step(batch)

    def __call__(self, batch):
        return self.eval_step(batch)
class EpochRunner:
    """Drive a StepRunner over an entire dataloader and aggregate results.

    Returns the model, a dict of epoch-level metric values (including the
    mean loss), a per-sample result dict, and the concatenated embeddings.
    """

    def __init__(self, steprunner):
        self.steprunner = steprunner
        self.args = steprunner.args

    def __call__(self, dataloader):
        progress = tqdm(dataloader, total=len(dataloader), file=sys.stdout)
        running_loss = 0
        result_dict = {'name': [], 'aa_seq': [], 'label': [], 'pred_label': []}
        embed_chunks = []
        for batch in progress:
            step_loss, model, metrics_dict, pred_label, ssn_embed = self.steprunner(batch)
            result_dict["pred_label"].extend(pred_label)
            result_dict["name"].extend([data.name for data in batch])
            result_dict["aa_seq"].extend([data.aa_seq for data in batch])
            result_dict["label"].extend([data.label.item() for data in batch])
            embed_chunks.append(ssn_embed)
            progress.set_postfix(**{"eval/loss": round(step_loss, 3)})
            running_loss += step_loss
        # stack per-batch embeddings into one tensor
        ssn_embeds = torch.cat(embed_chunks, dim=0)
        # ``model``/``metrics_dict`` are bound by the last loop iteration;
        # an empty dataloader would raise NameError here (as in the original).
        epoch_metric_results = {}
        for name, metric_fn in metrics_dict.items():
            epoch_metric_results[f"eval/{name}"] = metric_fn.compute().item()
            metric_fn.reset()
        epoch_metric_results["eval/loss"] = running_loss / len(dataloader)
        return model, epoch_metric_results, result_dict, ssn_embeds
def eval_model(args, model, loss_fn,
               accelerator=None, metrics_dict=None,
               test_data=None
               ):
    """Evaluate a trained checkpoint on the test dataloader.

    Loads ``{args.model_dir}/{args.model_name}`` into *model*, runs one
    eval epoch over *test_data*, prints each metric, and — when
    ``args.test_result_dir`` is set — writes per-sample predictions,
    aggregate metrics and the SSN embeddings to disk.  Does nothing when
    *test_data* is falsy.
    """
    model_path = os.path.join(args.model_dir, args.model_name)
    if test_data:
        # NOTE(fix): ``map_location`` lets a GPU-saved checkpoint load on
        # CPU-only hosts; the model has already been moved to its device.
        model.load_state_dict(torch.load(model_path, map_location="cpu")["state_dict"])
        test_step_runner = StepRunner(
            args=args, model=model,
            loss_fn=loss_fn, accelerator=accelerator,
            metrics_dict=deepcopy(metrics_dict),
        )
        test_epoch_runner = EpochRunner(test_step_runner)
        with torch.no_grad():
            model, epoch_metric_results, result_dict, ssn_embeds = test_epoch_runner(test_data)
        for name, metric in epoch_metric_results.items():
            # wrap scalars in lists so the dict converts to a 1-row DataFrame
            epoch_metric_results[name] = [metric]
            print(f">>> {name}: {'%.3f'%metric}")
        if args.test_result_dir:
            os.makedirs(args.test_result_dir, exist_ok=True)
            pd.DataFrame(result_dict).to_csv(f"{args.test_result_dir}/test_result.csv", index=False)
            pd.DataFrame(epoch_metric_results).to_csv(f"{args.test_result_dir}/test_metrics.csv", index=False)
            torch.save(ssn_embeds, f"{args.test_result_dir}/ssn_embeds.pt")
def create_parser():
    """Define and parse all command-line options for evaluation.

    Returns the parsed ``argparse.Namespace`` (note: this parses
    ``sys.argv`` immediately, it does not return the parser itself).
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    # model config
    add("--gnn", type=str, default="egnn", help="gat, gcn or egnn")
    add("--gnn_config", type=str, default="src/config/egnn.yaml", help="gnn config")
    add("--gnn_hidden_dim", type=int, default=512, help="hidden size of gnn")
    add("--plm", type=str, default="facebook/esm2_t33_650M_UR50D", help="esm param number")
    add("--plm_hidden_size", type=int, default=1280, help="hidden size of plm")
    add("--pooling_method", type=str, default="attention1d", help="pooling method")
    add("--pooling_dropout", type=float, default=0.1, help="pooling dropout")
    # training strategy
    add("--seed", type=int, default=3407, help="random seed")
    add("--weight_decay", type=float, default=1e-2, help="weight_decay")
    add("--batch_token_num", type=int, default=4096, help="how many tokens in one batch")
    add("--max_graph_token_num", type=int, default=3000, help="max token num a graph has")
    add("--max_grad_norm", type=float, default=None, help="clip grad norm")
    # dataset
    add("--num_labels", type=int, default=2, help="number of labels")
    add("--problem_type", type=str, default="classification", help="classification or regression")
    add("--supv_dataset", type=str, help="supervise protein dataset")
    add("--test_file", type=str, help="test label file")
    add('--test_result_dir', type=str, default=None, help='test result directory')
    add("--feature_file", type=str, default=None, help="feature file")
    add("--feature_name", nargs="+", default=None, help="feature names")
    add("--feature_dim", type=int, default=0, help="feature dim")
    add("--feature_embed_dim", type=int, default=512, help="feature embed dim")
    add("--use_plddt_penalty", action="store_true", help="use plddt penalty")
    add("--c_alpha_max_neighbors", type=int, default=20, help="graph dataset K")
    add("--gnn_model_path", type=str, default="./model/protssn_k20_h512.pt", help="gnn model path")
    # load model
    add("--model_dir", type=str, default="./ckpt", help="model save dir")
    add("--model_name", type=str, default="feature512_norm_pp_attention1d_k20_h512_lr5e-4.pt", help="model name")
    return parser.parse_args()
# name -> flat feature vector; module-level so the DataLoader collate_fn
# (and its worker threads) can read it.
feature_dict = {}

if __name__ == "__main__":
    args = create_parser()
    # NOTE(fix): the original leaked the config file handle (yaml.load(open(...)))
    with open(args.gnn_config) as cfg_f:
        args.gnn_config = yaml.load(cfg_f, Loader=yaml.FullLoader)[args.gnn]
    args.gnn_config["hidden_channels"] = args.gnn_hidden_dim
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.feature_file:
        logger.info("***** Loading Feature *****")
        feature_df = pd.read_csv(args.feature_file)
        if not isinstance(args.feature_name, list):
            args.feature_name = [args.feature_name]
        # Feature-group -> CSV column names, in the fixed order in which the
        # groups are concatenated into each protein's feature vector.
        feature_groups = {
            "aa_composition": ["1-C", "1-D", "1-E", "1-R", "1-H", "Turn-forming residues fraction"],
            "gravy": ["GRAVY"],
            "ss_composition": ["ss8-G", "ss8-H", "ss8-I", "ss8-B", "ss8-E", "ss8-T",
                               "ss8-S", "ss8-P", "ss8-L", "ss3-H", "ss3-E", "ss3-C"],
            "hygrogen_bonds": ["Hydrogen bonds", "Hydrogen bonds per 100 residues"],
            "exposed_res_fraction": [f"Exposed residues fraction by {p}%" for p in range(5, 101, 5)],
            "pLDDT": ["pLDDT"],
        }
        selected_groups = [g for g in feature_groups if g in args.feature_name]
        for group in selected_groups:
            args.feature_dim += len(feature_groups[group])
        for i in tqdm(range(len(feature_df))):
            # "protein name" column holds a file name; strip the extension
            name = feature_df["protein name"][i].split(".")[0]
            vec = []
            for group in selected_groups:
                vec += list(feature_df[feature_groups[group]].iloc[i])
            feature_dict[name] = vec

    # load dataset
    logger.info("***** Loading Dataset *****")
    dataset_name = args.supv_dataset.split("/")[-1]
    if os.path.exists(f"{args.supv_dataset}/esmfold_pdb"):
        pdb_dir = f"{args.supv_dataset}/esmfold_pdb"
    elif os.path.exists(f"{args.supv_dataset}/pdb"):
        pdb_dir = f"{args.supv_dataset}/pdb"
    else:
        raise ValueError("No pdb or esmfold_pdb directory found in the dataset")
    graph_dir = f"{dataset_name}_k{args.c_alpha_max_neighbors}"
    supervise_dataset = SuperviseDataset(
        root=args.supv_dataset,
        raw_dir=pdb_dir,
        name=graph_dir,
        c_alpha_max_neighbors=args.c_alpha_max_neighbors,
        pre_transform=NormalizeProtein(
            filename=f'norm/cath_k{args.c_alpha_max_neighbors}_mean_attr.pt'
        ),
    )

    label_dict, seq_dict = {}, {}

    def get_dataset(df):
        """Collect names and sequence lengths; fill label_dict/seq_dict."""
        names, node_nums = [], []
        for name, label, seq in zip(df["name"], df["label"], df["aa_seq"]):
            names.append(name)
            label_dict[name] = label
            seq_dict[name] = seq
            node_nums.append(len(seq))
        return names, node_nums

    test_names, test_node_nums = get_dataset(pd.read_csv(args.test_file))

    # multi-thread loading shuffles the completion order of the data,
    # so name/label/sequence are attached to each graph object here.
    def process_data(name, fd):
        """Load one processed graph from disk and attach its metadata."""
        data = torch.load(f"{args.supv_dataset}/{graph_dir.capitalize()}/processed/{name}.pt")
        data.label = torch.tensor(label_dict[name]).view(1)
        data.aa_seq = seq_dict[name]
        data.name = name
        if args.feature_file:
            # NOTE(fix): O(1) dict lookup replaces the original full scan
            # over fd.items().
            fe = fd.get(str(name))
            if fe is None:
                # NOTE(fix): the original printed the name and then crashed
                # on torch.tensor(None); fail loudly with a clear message.
                raise KeyError(f"no feature found for protein '{name}'")
            data.feature = torch.tensor(fe).view(1, -1)
        # NOTE(fix): when no feature file is given, no feature attribute is
        # attached (the original unconditionally called torch.tensor(None),
        # which raises TypeError).
        return data

    def collect_fn(batch):
        """Collate a batch of protein names into loaded graph objects."""
        # NOTE(fix): feature_dict is only *read* by the worker threads, so
        # the original per-batch deepcopy of the whole dict was dropped.
        batch_data = []
        with ThreadPoolExecutor(max_workers=12) as executor:
            futures = [executor.submit(process_data, name, feature_dict) for name in batch]
            for future in as_completed(futures):
                batch_data.append(future.result())
        return batch_data

    test_dataloader = DataLoader(
        dataset=test_names, num_workers=4,
        collate_fn=collect_fn,
        batch_sampler=BatchSampler(
            node_num=test_node_nums,
            max_len=args.max_graph_token_num,
            batch_token_num=args.batch_token_num,
            shuffle=False
        )
    )

    logger.info("***** Load Model *****")
    # load model; StepRunner.step reads these two as module globals
    global plm_model
    plm_model = PLM_model(args).to(device)
    global gnn_model
    gnn_model = GNN_model(args).to(device)
    # NOTE(fix): map_location so a GPU-saved checkpoint loads on CPU hosts
    gnn_model.load_state_dict(torch.load(args.gnn_model_path, map_location=device))
    protssn_classification = ProtssnClassification(args)
    protssn_classification.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    # backbones are frozen: only the classification head was trained
    for param in plm_model.parameters():
        param.requires_grad = False
    for param in gnn_model.parameters():
        param.requires_grad = False
    logger.info(total_param_num(protssn_classification))
    logger.info(param_num(protssn_classification))

    accelerator = Accelerator()
    protssn_classification, test_dataloader = accelerator.prepare(
        protssn_classification, test_dataloader
    )
    # binary classification metrics, all kept on the compute device
    metrics_dict = {
        "acc": BinaryAccuracy().to(device),
        "recall": BinaryRecall().to(device),
        "precision": BinaryPrecision().to(device),
        "mcc": BinaryMatthewsCorrCoef().to(device),
        "auroc": BinaryAUROC().to(device),
        "f1": BinaryF1Score().to(device),
    }

    logger.info("***** Running eval *****")
    logger.info(" Num test examples = %d", len(test_names))
    logger.info(" Batch token num = %d", args.batch_token_num)
    eval_model(
        args=args, model=protssn_classification,
        loss_fn=loss_fn,
        accelerator=accelerator, metrics_dict=metrics_dict,
        test_data=test_dataloader
    )