-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemonstrate.py
90 lines (63 loc) · 2.88 KB
/
demonstrate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
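demonstrate.py
Run a sequence-to-sentence model over a CSV of sequences and print each predicted sentence next to the true one.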
"""
import argparse
from os.path import join
from pathlib import Path
import re
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import EsmTokenizer, BartTokenizer, BartModel, BartForConditionalGeneration
from crystoper import config
from crystoper.processor import filter_by_pdbx_details_length, filter_for_single_entities
from crystoper.utils.general import vprint, make_parent_dirs
from crystoper.esmc_models import ESMCcomplex
from crystoper.trainer import ESMCTrainer, seq2sent
def parse_args(argv=None):
    """Parse the command-line options of the demonstration script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps argparse's usual
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace with the attributes ``model``, ``checkpoint``,
        ``data_path``, ``x_column``, ``y_column`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description="Run a sequence-to-sentence model over a CSV of sequences "
                    "and print each predicted sentence next to the true one."
    )
    parser.add_argument('-m', '--model', type=str, default='esmc-complex',
                        help='model to use (if checkpoint is passed - it will be loaded instead)')
    parser.add_argument('-c', '--checkpoint', type=str, default=None,
                        help='Checkpoint for loading a pre-trained model')
    # The default input is the project's toy dataset (path taken from crystoper.config).
    parser.add_argument('-d', '--data_path', type=str, default=config.toy_path,
                        help='csv to use for data input. default is using the toy data')
    parser.add_argument('-x', '--x-column', type=str, default='sequence',
                        help='Column to use for input (sequences)')
    parser.add_argument('-y', '--y-column', type=str, default='pdbx_details',
                        help='Column to use as true labels (used for display only)')
    parser.add_argument('--device', type=str, default='cpu',
                        help='device to use')
    return parser.parse_args(argv)
def main():
    """Generate a sentence for every input row and print it next to the true label.

    Steps:
    1. Read the options via ``parse_args``.
    2. Build the ESMC model on ``--device`` and optionally restore its weights
       from ``--checkpoint``.
    3. Load the ESM tokenizer and the BART tokenizer/model by their
       Hugging Face names.
    4. For each pair of values in the ``--x-column`` / ``--y-column`` columns
       of the CSV, print the true and the predicted sentence.

    Raises:
        ValueError: if ``--model`` is not a supported model name.
    """
    args = parse_args()
    device = args.device

    if args.model != 'esmc-complex':
        raise ValueError('Model cannot be resolved')
    esm_model = ESMCcomplex()
    esm_model.to(device)
    vprint(f"A fresh {args.model} has been created!")

    if args.checkpoint:
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
        # machine (the default --device is 'cpu'); without it torch.load fails.
        checkpoint = torch.load(args.checkpoint, map_location=device)
        incompatible = esm_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        vprint(f"loaded previous model from checkpoint {args.checkpoint}")
        # strict=False silently skips mismatched keys; surface them so a
        # checkpoint that loaded nothing useful does not go unnoticed.
        # getattr keeps this safe if load_state_dict is overridden to return None.
        missing = getattr(incompatible, 'missing_keys', [])
        unexpected = getattr(incompatible, 'unexpected_keys', [])
        if missing or unexpected:
            vprint(f"warning: {len(missing)} missing and {len(unexpected)} unexpected "
                   f"keys were ignored while loading the checkpoint")
        del checkpoint  # free the raw state dict before loading the other models

    esm_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
    bart_model.to(device)

    # Inference only: disable dropout so predictions are deterministic.
    # NOTE(review): assumes ESMCcomplex is a torch.nn.Module (it exposes
    # .to/.load_state_dict) - confirm.
    esm_model.eval()
    bart_model.eval()

    data = pd.read_csv(args.data_path)
    X = data[args.x_column]
    Y = data[args.y_column]

    # No gradients are needed for a demonstration; avoids building autograd graphs.
    with torch.no_grad():
        for x, y_true in zip(X, Y):
            pred = seq2sent(x, esm_model, esm_tokenizer, bart_model, bart_tokenizer, ac=True)
            print(f'True sentence: {y_true}')
            print(f'Pred sentence: {pred}')
            print('\n\n')
# Entry point: run the demonstration only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()