-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy patheval_plus_secstr.py
155 lines (137 loc) · 7.93 KB
/
eval_plus_secstr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Written by Seonwoo Min, Seoul National University (mswzeus@gmail.com)
# PLUS
import os
import sys
import argparse
import torch
import plus.config as config
from plus.data.alphabets import Protein
import plus.data.secstr as secstr
import plus.data.dataset as dataset
import plus.model.plus_rnn as plus_rnn
import plus.model.plus_tfm as plus_tfm
import plus.model.p_elmo as p_elmo
import plus.model.mlp as mlp
from plus.train import Trainer
from plus.utils import Print, set_seeds, set_output, load_models
parser = argparse.ArgumentParser('Evaluate a Model on SecStr Datasets')
parser.add_argument('--data-config', help='path for data configuration file')
parser.add_argument('--model-config', help='path for model configuration file')
parser.add_argument('--lm-model-config', help='path for lm-model configuration file (for P-ELMo)')
parser.add_argument('--pr-model-config', help='path for pr-model configuration file (for P-ELMo and PLUS-RNN)')
parser.add_argument('--run-config', help='path for run configuration file')
parser.add_argument('--pretrained-model', help='path for pretrained model file')
parser.add_argument('--pretrained-lm-model', help='path for pretrained lm-model file (for P-ELMo)')
parser.add_argument('--pretrained-pr-model', help='path for pretrained pr-model file (for P-ELMo and PLUS-RNN)')
parser.add_argument('--device', help='device to use; multi-GPU if given multiple GPUs sperated by comma (default: cpu)')
parser.add_argument('--output-path', help='path for outputs (default: stdout and without saving)')
parser.add_argument('--output-index', help='prefix for outputs')
parser.add_argument('--sanity-check', default=False, action='store_true', help='sanity check flag')
def main():
set_seeds(2020)
args = vars(parser.parse_args())
alphabet = Protein()
cfgs = []
data_cfg = config.DataConfig(args["data_config"]); cfgs.append(data_cfg)
if args["lm_model_config"] is None:
model_cfg = config.ModelConfig(args["model_config"], input_dim=len(alphabet), num_classes=8)
cfgs += [model_cfg]
else:
lm_model_cfg = config.ModelConfig(args["lm_model_config"], idx="lm_model_config", input_dim=len(alphabet))
model_cfg = config.ModelConfig(args["model_config"], input_dim=len(alphabet),
lm_dim=lm_model_cfg.num_layers * lm_model_cfg.hidden_dim * 2, num_classes=8)
cfgs += [model_cfg, lm_model_cfg]
if model_cfg.model_type == "RNN":
pr_model_cfg = config.ModelConfig(args["pr_model_config"], idx="pr_model_config",
model_type="MLP", num_classes=8)
if pr_model_cfg.projection: pr_model_cfg.set_input_dim(model_cfg.embedding_dim)
else: pr_model_cfg.set_input_dim(model_cfg.hidden_dim * 2)
cfgs.append(pr_model_cfg)
run_cfg = config.RunConfig(args["run_config"], sanity_check=args["sanity_check"]); cfgs.append(run_cfg)
output, save_prefix = set_output(args, "eval_secstr_log", test=True)
os.environ['CUDA_VISIBLE_DEVICES'] = args["device"] if args["device"] is not None else ""
device, data_parallel = torch.device("cuda" if torch.cuda.is_available() else "cpu"), torch.cuda.device_count() > 1
config.print_configs(args, cfgs, device, output)
flag_rnn = (model_cfg.model_type == "RNN")
flag_lm_model = (args["lm_model_config"] is not None)
## load test datasets
idxs_test, datasets_test, iterators_test = [key for key in data_cfg.path.keys() if "test" in key], [], []
start = Print(" ".join(['start loading test datasets'] + idxs_test), output)
collate_fn = dataset.collate_sequences if flag_rnn else None
for idx_test in idxs_test:
dataset_test = secstr.load_secstr(data_cfg, idx_test, alphabet, args["sanity_check"])
dataset_test = dataset.Seq_dataset(*dataset_test, alphabet, run_cfg, flag_rnn, model_cfg.max_len, truncate=False)
iterator_test = torch.utils.data.DataLoader(dataset_test, run_cfg.batch_size_eval, collate_fn=collate_fn)
datasets_test.append(dataset_test); iterators_test.append(iterator_test)
end = Print(" ".join(['loaded', str(len(dataset_test)), 'sequences']), output)
Print(" ".join(['elapsed time:', str(end - start)]), output, newline=True)
## initialize a model
start = Print('start initializing a model', output)
models_list = [] # list of lists [model, idx, flag_frz, flag_clip_grad, flag_clip_weight]
### model
if not flag_rnn: model = plus_tfm.PLUS_TFM(model_cfg)
elif not flag_lm_model: model = plus_rnn.PLUS_RNN(model_cfg)
else: model = p_elmo.P_ELMo(model_cfg)
models_list.append([model, "", flag_lm_model, flag_rnn, False])
### lm_model
if flag_lm_model:
lm_model = p_elmo.P_ELMo_lm(lm_model_cfg)
models_list.append([lm_model, "lm", True, False, False])
### pr_model
if flag_rnn:
pr_model = mlp.MLP(pr_model_cfg)
models_list.append([pr_model, "pr", False, False, False])
params, pr_params = [], []
for model, idx, frz, _, _ in models_list:
if frz: continue
elif idx != "pr": params += [p for p in model.parameters() if p.requires_grad]
else: pr_params += [p for p in model.parameters() if p.requires_grad]
load_models(args, models_list, device, data_parallel, output, tfm_cls=flag_rnn)
get_loss = plus_rnn.get_loss if flag_rnn else plus_tfm.get_loss
end = Print('end initializing a model', output)
Print("".join(['elapsed time:', str(end - start)]), output, newline=True)
## setup trainer configurations
start = Print('start setting trainer configurations', output)
tasks_list = [] # list of lists [idx, metrics_train, metrics_eval]
tasks_list.append(["cls", [], ["acc8", "acc3"]])
if not flag_lm_model: tasks_list.append(["lm", [], ["acc"]])
trainer = Trainer(models_list, get_loss, run_cfg, tasks_list)
trainer_args = {}
trainer_args["data_parallel"] = data_parallel
trainer_args["paired"] = False
if flag_rnn: trainer_args["projection"] = pr_model_cfg.projection
if flag_rnn: trainer_args["evaluate_cls"] = plus_rnn.evaluate_cls_amino
else: trainer_args["evaluate_cls"] = plus_tfm.evaluate_cls_amino
trainer_args["evaluate"] = ["cls", secstr.evaluate_secstr]
end = Print('end setting trainer configurations', output)
Print("".join(['elapsed time:', str(end - start)]), output, newline=True)
## evaluate a model
start = Print('start evaluating a model', output)
Print(trainer.get_headline(test=True), output)
for idx_test, dataset_test, iterator_test in zip(idxs_test, datasets_test, iterators_test):
### evaluate cls
dataset_test.set_augment(False)
trainer.set_exec_flags(["cls", 'lm'], [True, False])
for b, batch in enumerate(iterator_test):
batch = [t.to(device) if type(t) is torch.Tensor else t for t in batch]
trainer.evaluate(batch, trainer_args)
if b % 10 == 0: print('# cls {:.1%} loss={:.4f}'.format(
b / len(iterator_test), trainer.loss_eval), end='\r', file=sys.stderr)
print(' ' * 150, end='\r', file=sys.stderr)
### evaluate lm
if not flag_lm_model:
dataset_test.set_augment(True)
trainer.set_exec_flags(["cls", 'lm'], [False, True])
for b, batch in enumerate(iterator_test):
batch = [t.to(device) if type(t) is torch.Tensor else t for t in batch]
trainer.evaluate(batch, trainer_args)
if b % 10 == 0: print('# lm {:.1%} loss={:.4f}'.format(
b / len(iterator_test), trainer.loss_eval), end='\r', file=sys.stderr)
print(' ' * 150, end='\r', file=sys.stderr)
Print(trainer.get_log(test_idx=idx_test, args=trainer_args), output)
trainer.reset()
end = Print('end evaluating a model', output)
Print("".join(['elapsed time:', str(end - start)]), output, newline=True)
output.close()
if __name__ == '__main__':
main()