-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpreprocess.py
44 lines (36 loc) · 1.09 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import shutil
import subprocess
import sys
import tempfile
from argparse import ArgumentParser
from pathlib import Path

import pandas as pd
# ESM-2 checkpoint names, largest model first. Each name carries its layer
# count as the "t<layers>" token (e.g. "esm2_t33_650M_UR50D" -> 33), so the
# lookup key is derived from the name itself rather than typed twice.
_ESM2_CHECKPOINTS = (
    "esm2_t48_15B_UR50D",
    "esm2_t36_3B_UR50D",
    "esm2_t33_650M_UR50D",
    "esm2_t30_150M_UR50D",
    "esm2_t12_35M_UR50D",
    "esm2_t6_8M_UR50D",
)

# Maps number of model layers -> checkpoint name, preserving the order above.
model_names = {
    int(checkpoint.split("_")[1][1:]): checkpoint for checkpoint in _ESM2_CHECKPOINTS
}
def _parse_args():
    """Parse the command-line arguments for this preprocessing script."""
    parser = ArgumentParser(
        description="Write the 'primary' sequences of a CSV to FASTA, run extract.py "
        "on them and store the results next to a copy of the CSV."
    )
    parser.add_argument(
        "data_file",
        type=Path,
        help="CSV file whose 'primary' column holds the sequences",
    )
    parser.add_argument(
        "num_layers",
        type=int,
        choices=list(model_names),
        help="number of model layers",
    )
    parser.add_argument(
        "out_dir",
        type=Path,
        help="root directory; results go to <out_dir>/processed/<model>/<data_file stem>",
    )
    return parser.parse_args()


def _write_fasta(sequences, fasta_path):
    """Write *sequences* to *fasta_path* as records named prot_0, prot_1, ..."""
    with open(fasta_path, "w", encoding="utf-8") as handle:
        handle.writelines(
            f">prot_{i}\n{sequence}\n" for i, sequence in enumerate(sequences)
        )


def _build_extract_command(model_name, fasta_path, out_dir, num_layers):
    """Return the argv list that runs extract.py with per-token representations."""
    # NOTE(review): range(num_layers) requests layers 0..num_layers-1. If
    # extract.py is ESM's script (which accepts 0..num_layers), the final
    # layer is NOT extracted. Kept as-is to preserve the output format; use
    # range(num_layers + 1) if all layers are intended — confirm downstream.
    return [
        # Same interpreter as this script, rather than whatever "python" is on PATH.
        sys.executable,
        "extract.py",
        model_name,
        str(fasta_path),
        str(out_dir),
        "--include",
        "per_tok",
        "--repr_layers",
        *(str(layer) for layer in range(num_layers)),
    ]


def main():
    """Run extract.py on the CSV's sequences and copy the CSV into the output dir.

    Exits with a non-zero status (and without copying the CSV) if
    extract.py fails.
    """
    args = _parse_args()
    model_name = model_names[args.num_layers]
    out_dir = args.out_dir / "processed" / model_name / args.data_file.stem

    df = pd.read_csv(args.data_file)

    # A private temp dir avoids clobbering a shared ./tmp.fasta between
    # concurrent runs and is removed automatically afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        fasta_path = Path(tmp_dir) / "tmp.fasta"
        _write_fasta(df["primary"], fasta_path)
        command = _build_extract_command(
            model_name, fasta_path, out_dir, args.num_layers
        )
        # argv list, no shell: paths with spaces or shell metacharacters are
        # passed through verbatim instead of being re-parsed by a shell.
        completed = subprocess.run(command, check=False)

    if completed.returncode != 0:
        sys.exit(f"extract.py failed with exit code {completed.returncode}")

    out_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(args.data_file, out_dir / "df.csv")


if __name__ == "__main__":
    main()