Eval does not run properly #3

Open
zhangtianhong-1998 opened this issue Dec 24, 2024 · 18 comments

@zhangtianhong-1998

There seems to be a problem in the Dataloader.
These are the local configuration changes we made so that the model loads properly:
parser = argparse.ArgumentParser()
# model config
parser.add_argument("--gnn", type=str, default="egnn", help="gat, gcn or egnn")
parser.add_argument("--gnn_config", type=str, default="src/config/egnn.yaml", help="gnn config")
parser.add_argument("--gnn_hidden_dim", type=int, default=512, help="hidden size of gnn")
parser.add_argument("--plm", type=str, default="./model/facebook", help="esm param number")
parser.add_argument("--plm_hidden_size", type=int, default=1280, help="hidden size of plm")
parser.add_argument("--pooling_method", type=str, default="attention1d", help="pooling method")
parser.add_argument("--pooling_dropout", type=float, default=0.1, help="pooling dropout")

# training strategy
parser.add_argument("--seed", type=int, default=3407, help="random seed")
parser.add_argument("--weight_decay", type=float, default=1e-2, help="weight_decay")
parser.add_argument("--batch_token_num", type=int, default=4096, help="how many tokens in one batch")
parser.add_argument("--max_graph_token_num", type=int, default=3000, help="max token num a graph has")
parser.add_argument("--max_grad_norm", type=float, default=None, help="clip grad norm")

# dataset
parser.add_argument("--num_labels", type=int, default=2, help="number of labels")
parser.add_argument("--problem_type", type=str, default="classification", help="classification or regression")
parser.add_argument("--supv_dataset", type=str, help="supervise protein dataset")
parser.add_argument("--test_file", type=str, help="test label file")
parser.add_argument('--test_result_dir', type=str, default=None, help='test result directory')
parser.add_argument("--feature_file", type=str, default=None, help="feature file")
parser.add_argument("--feature_name", nargs="+", default=None, help="feature names")
parser.add_argument("--feature_dim", type=int, default=0, help="feature dim")
parser.add_argument("--feature_embed_dim", type=int, default=512, help="feature embed dim")
parser.add_argument("--use_plddt_penalty", action="store_true", help="use plddt penalty")
parser.add_argument("--c_alpha_max_neighbors", type=int, default=20, help="graph dataset K")
parser.add_argument("--gnn_model_path", type=str, default="./model/protssn_k20_h512.pt", help="gnn model path")

# load model
parser.add_argument("--model_dir", type=str, default="./ckpt", help="model save dir")
parser.add_argument("--model_name", type=str, default="feature512_norm_pp_attention1d_k20_h512_lr5e-4.pt", help="model name")

This is the command suggested in the README:
python eval.py \
    --supv_dataset data/ExternalTest/esmfold_pdb \
    --test_file data/ExternalTest/ExternalTest.csv \
    --test_result_dir result/protssn_k20_h512/experiment \
    --feature_file data/ExternalTest/ExternalTest_feature.csv \
    --feature_name "aa_composition" "gravy" "ss_composition" "hygrogen_bonds" "exposed_res_fraction" "pLDDT" \
    --use_plddt_penalty \
    --batch_token_num 3000

This is the error message:

12/24/2024 11:40:04 - INFO - main - ***** Loading Feature *****
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7843/7843 [00:00<00:00, 13289.17it/s]
12/24/2024 11:40:05 - INFO - main - ***** Loading Dataset *****
Processing...
0it [00:00, ?it/s]
Total proteins: []
Wrong proteins: []
0it [00:00, ?it/s]
Done!
12/24/2024 11:40:05 - INFO - main - ***** Load Model *****
12/24/2024 11:40:06 - INFO - main - Number of parameter: 3.24M
12/24/2024 11:40:06 - INFO - main - Number of trainable parameter: 3.24M
12/24/2024 11:40:06 - INFO - main - ***** Running eval *****
12/24/2024 11:40:06 - INFO - main - Num test examples = 7579
12/24/2024 11:40:06 - INFO - main - Batch token num = 3000
0%| | 0/635 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/gaoyuan/ProtSolM/eval.py", line 341, in
eval_model(
File "/home/gaoyuan/ProtSolM/eval.py", line 128, in eval_model
model, epoch_metric_results, result_dict, ssn_embeds = test_epoch_runner(test_data)
File "/home/gaoyuan/ProtSolM/eval.py", line 95, in call
for batch in loop:
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/tqdm/std.py", line 1181, in iter
for obj in iterable:
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/accelerate/data_loader.py", line 552, in iter
current_batch = next(dataloader_iter)
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 701, in next
data = self._next_data()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1465, in _next_data
return self._process_data(data)
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1491, in _process_data
data.reraise()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/_utils.py", line 715, in reraise
raise exception
FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
return self.collate_fn(data)
File "/home/gaoyuan/ProtSolM/eval.py", line 297, in
collate_fn=lambda x: collect_fn(x),
File "/home/gaoyuan/ProtSolM/eval.py", line 291, in collect_fn
graph = future.result()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/concurrent/futures/_base.py", line 451, in result
return self.__get_result()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/gaoyuan/ProtSolM/eval.py", line 278, in process_data
data = torch.load(f"{args.supv_dataset}/{graph_dir.capitalize()}/processed/{name}.pt")
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/serialization.py", line 1319, in load
with _open_file_like(f, "rb") as opened_file:
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/serialization.py", line 659, in _open_file_like
return _open_file(name_or_buffer, mode)
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/serialization.py", line 640, in init
super().__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: 'data/ExternalTest/esmfold_pdb/Esmfold_pdb_k20/processed/test_protein_12.pt'

@tyang816
Owner

Did you download the corresponding pdb files? For this work the pdb files first have to be processed into graphs before inference can run; for example, test_protein_12.pt here is one of those processed graph files.
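
Once the graphs have been generated they are plain .pt files, so a quick way to check that one exists and loads is something like the following (the path is only an assumption based on the corrected layout discussed below; test_protein_12 is the example name from the error above):

import torch

# load one processed graph file directly to verify it exists and is readable
graph_path = "data/ExternalTest/Externaltest_k20/processed/test_protein_12.pt"
data = torch.load(graph_path)
print(data)  # a torch_geometric Data object holding the protein graph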

@zhangtianhong-1998
Author

zhangtianhong-1998 commented Dec 24, 2024 via email

@tyang816
Owner

This shows that a file is missing. Let me try to reproduce it, sorry about that.

@zhangtianhong-1998
Author

[screenshot] The directory looks roughly like this, thank you.

@tyang816
Owner

Your directory looks a bit odd. It should look like this; after downloading, just extract it directly.
[screenshot of the expected directory layout]
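
For reference, the working layout is roughly the following (reconstructed from the corrected command and the file paths in this thread; the individual protein file names are only illustrative):

data/ExternalTest/
├── ExternalTest.csv
├── ExternalTest_feature.csv
├── esmfold_pdb/                  # one .pdb file per protein
│   ├── test_protein_1.pdb
│   └── ...
└── Externaltest_k20/             # graph .pt files generated on the first run
    └── processed/
        ├── test_protein_1.pt
        └── ...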

@tyang816
Owner

python eval.py \
    --supv_dataset data/ExternalTest \
    --test_file data/ExternalTest/ExternalTest.csv \
    --test_result_dir result/protssn_k20_h512/experiment \
    --feature_file data/ExternalTest/ExternalTest_feature.csv \
    --feature_name "aa_composition" "gravy" "ss_composition" "hygrogen_bonds" "exposed_res_fraction" "pLDDT" \
    --use_plddt_penalty \
    --batch_token_num 3000

(--supv_dataset should point at data/ExternalTest; the /esmfold_pdb suffix is not needed.)

@zhangtianhong-1998
Author

Indeed, it runs after removing that. I thought this step still needed to reference the pdb files, but apparently it doesn't?

@tyang816
Owner

The pdb files are used to build the graphs and to extract the physicochemical features; once both steps are done, they are no longer needed.

@zhangtianhong-1998
Author

Thanks a lot!

@tyang816
Owner

Good luck with your research, and may the papers keep coming!

@zhangtianhong-1998
Author

ExternalTest runs through fine, but as soon as I switch to my own dataset the problem below appears. I also found another issue: to use your own dataset, the pdb folder has to be named esmfold_pdb, otherwise other bugs show up.

12/24/2024 22:33:06 - INFO - main - ***** Load Model *****
12/24/2024 22:33:07 - INFO - main - Number of parameter: 3.24M
12/24/2024 22:33:07 - INFO - main - Number of trainable parameter: 3.24M
12/24/2024 22:33:07 - INFO - main - ***** Running eval *****
12/24/2024 22:33:07 - INFO - main - Num test examples = 193
12/24/2024 22:33:07 - INFO - main - Batch token num = 1000
0%| | 0/94 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/gaoyuan/ProtSolM/eval.py", line 341, in
eval_model(
File "/home/gaoyuan/ProtSolM/eval.py", line 128, in eval_model
model, epoch_metric_results, result_dict, ssn_embeds = test_epoch_runner(test_data)
File "/home/gaoyuan/ProtSolM/eval.py", line 95, in call
for batch in loop:
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/tqdm/std.py", line 1181, in iter
for obj in iterable:
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/accelerate/data_loader.py", line 552, in iter
current_batch = next(dataloader_iter)
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 701, in next
data = self._next_data()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1465, in _next_data
return self._process_data(data)
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1491, in _process_data
data.reraise()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/_utils.py", line 715, in reraise
raise exception
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
return self.collate_fn(data)
File "/home/gaoyuan/ProtSolM/eval.py", line 297, in
collate_fn=lambda x: collect_fn(x),
File "/home/gaoyuan/ProtSolM/eval.py", line 291, in collect_fn
graph = future.result()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/concurrent/futures/_base.py", line 451, in result
return self.__get_result()
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
raise self._exception
File "/home/gaoyuan/anaconda3/envs/protsolm/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/gaoyuan/ProtSolM/eval.py", line 283, in process_data
data.feature = torch.tensor(feature_dict[name]).view(1, -1)
KeyError: 1

@tyang816
Owner

Ah yes, at the moment the pdb folder does have to be named esmfold_pdb; I will fix that bug. Did you extract the physicochemical features beforehand?

@tyang816
Owner

I fixed this problem in the commit just now, adding a check for the name of the folder that stores the pdb files:

    if os.path.exists(f"{args.supv_dataset}/esmfold_pdb"):
        pdb_dir = f"{args.supv_dataset}/esmfold_pdb"
    elif os.path.exists(f"{args.supv_dataset}/pdb"):
        pdb_dir = f"{args.supv_dataset}/pdb"
    else:
        raise ValueError("No pdb or esmfold_pdb directory found in the dataset")

@zhangtianhong-1998
Author

I did not extract the physicochemical features beforehand, because from your overall pipeline figure the features looked like they were extracted directly from the sequence.

@tyang816
Owner

What the figure means is that they come from the protein structure.
[screenshot of the pipeline figure]

@zhangtianhong-1998
Author

Do you mean running get_feature.py? I did run that, but I still get the same error.

@zhangtianhong-1998
Author

[screenshot]

@zhangtianhong-1998
Author

zhangtianhong-1998 commented Dec 25, 2024

I finally tracked down the problem: with multi-threading, feature_dict may not act as a global variable and ends up empty.
In addition, my files are named with plain numbers, so in def process_data(name, fd): the name can be an integer rather than a string.
Below is my modified eval.py. Thanks again!
(The changes include some debug prints left over from debugging; it runs, so I did not bother cleaning it up.)
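
A minimal, self-contained illustration of the name-type mismatch (this is not code from the repo; the CSV content and feature values are made up):

import io
import pandas as pd

# pandas parses purely numeric protein names as integers, while the feature
# dict built from "<name>.pdb" file names has string keys, so indexing with
# the raw csv value raises KeyError: 1
test_df = pd.read_csv(io.StringIO("name,label,aa_seq\n1,0,MKV\n"))
feature_dict = {"1": [0.1, 0.2]}   # keys are strings

name = test_df["name"][0]          # numpy.int64(1), not "1"
# feature_dict[name]               # -> KeyError: 1
print(feature_dict[str(name)])     # casting to str fixes the lookup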

import argparse
import warnings
import torch
import os
import sys
import yaml
import wandb
import datetime
import logging
import numpy as np
import pandas as pd
import transformers
import json
from tqdm import tqdm
from torch.utils.data import DataLoader
from typing import *
# from transformers import get_inverse_sqrt_schedule
from tqdm import tqdm
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor, as_completed
from accelerate.utils import set_seed
from accelerate import Accelerator
from torchmetrics.classification import Accuracy, Recall, Precision, MatthewsCorrCoef, AUROC
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryMatthewsCorrCoef
from src.models import ProtssnClassification, PLM_model, GNN_model
from src.utils.data_utils import BatchSampler
from src.utils.utils import param_num, total_param_num
from src.dataset.supervise_dataset import SuperviseDataset
from src.utils.dataset_utils import NormalizeProtein


current_dir = os.getcwd()
sys.path.append(current_dir)
# ignore warning information
transformers.logging.set_verbosity_error()
warnings.filterwarnings("ignore")

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n" + "==========" * 3 + "%s" % nowtime + "==========" * 3)
    print(str(info) + "\n")

class StepRunner:
    def __init__(self, args, model, 
                 loss_fn, accelerator=None,
                 metrics_dict=None,
                 ):
        self.model = model
        self.metrics_dict = metrics_dict
        self.accelerator = accelerator
        self.loss_fn = loss_fn
        self.args = args

    def step(self, batch):        
        logits, ssn_emebds = self.model(plm_model, gnn_model, batch, True)
        logits = logits.cuda()
        label = torch.cat([data.label for data in batch]).to(logits.device)
        pred_labels = torch.argmax(logits, 1).cpu().numpy()
        loss = self.loss_fn(logits, label)
        # compute metrics
        for name, metric_fn in self.metrics_dict.items():
            metric_fn.update(torch.argmax(logits, 1), label)
        return loss.item(), self.model, self.metrics_dict, pred_labels, ssn_emebds

    def train_step(self, batch):
        self.model.train()
        return self.step(batch)

    @torch.no_grad()
    def eval_step(self, batch):
        self.model.eval()
        return self.step(batch)

    def __call__(self, batch):
        return self.eval_step(batch)


class EpochRunner:
    def __init__(self, steprunner):
        self.steprunner = steprunner
        self.args = steprunner.args

    def __call__(self, dataloader):
        loop = tqdm(dataloader, total=len(dataloader), file=sys.stdout)
        total_loss = 0
        result_dict = {'name':[], 'aa_seq':[], 'label':[], 'pred_label':[]}
        ssn_embeds = []
        for batch in loop:
            step_loss, model, metrics_dict, pred_label, ssn_embed = self.steprunner(batch)
            result_dict["pred_label"].extend(pred_label)
            result_dict["name"].extend([data.name for data in batch])
            result_dict["aa_seq"].extend([data.aa_seq for data in batch])
            result_dict["label"].extend([data.label.item() for data in batch])
            ssn_embeds.append(ssn_embed)
            step_log = dict({f"eval/loss": round(step_loss, 3)})
            loop.set_postfix(**step_log)
            total_loss += step_loss
        ssn_embeds = torch.cat(ssn_embeds, dim=0)
        epoch_metric_results = {}
        for name, metric_fn in metrics_dict.items():
            epoch_metric_results[f"eval/{name}"] = metric_fn.compute().item()
            metric_fn.reset()
        avg_loss = total_loss / len(dataloader)
        epoch_metric_results[f"eval/loss"] = avg_loss
        return model, epoch_metric_results, result_dict, ssn_embeds

def eval_model(args, model, loss_fn, 
                accelerator=None, metrics_dict=None, 
                test_data=None
                ):
    model_path = os.path.join(args.model_dir, args.model_name)        
    if test_data:
        model.load_state_dict(torch.load(model_path)["state_dict"])
        test_step_runner = StepRunner(
            args=args, model=model, 
            loss_fn=loss_fn, accelerator=accelerator,
            metrics_dict=deepcopy(metrics_dict), 
            )
        test_epoch_runner = EpochRunner(test_step_runner)
        with torch.no_grad():
            model, epoch_metric_results, result_dict, ssn_embeds = test_epoch_runner(test_data)
        for name, metric in epoch_metric_results.items():
            epoch_metric_results[name] = [metric]
            print(f">>> {name}: {'%.3f'%metric}")
    
    if args.test_result_dir:
        os.makedirs(args.test_result_dir, exist_ok=True)
        pd.DataFrame(result_dict).to_csv(f"{args.test_result_dir}/test_result.csv", index=False)
        pd.DataFrame(epoch_metric_results).to_csv(f"{args.test_result_dir}/test_metrics.csv", index=False)
        torch.save(ssn_embeds, f"{args.test_result_dir}/ssn_embeds.pt")

def create_parser():
    parser = argparse.ArgumentParser()
    # model config
    parser.add_argument("--gnn", type=str, default="egnn", help="gat, gcn or egnn")
    parser.add_argument("--gnn_config", type=str, default="src/config/egnn.yaml", help="gnn config")
    parser.add_argument("--gnn_hidden_dim", type=int, default=512, help="hidden size of gnn")
    parser.add_argument("--plm", type=str, default="./model/facebook", help="esm param number")
    parser.add_argument("--plm_hidden_size", type=int, default=1280, help="hidden size of plm")
    parser.add_argument("--pooling_method", type=str, default="attention1d", help="pooling method")
    parser.add_argument("--pooling_dropout", type=float, default=0.1, help="pooling dropout")
    
    # training strategy
    parser.add_argument("--seed", type=int, default=3407, help="random seed")
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="weight_decay")
    parser.add_argument("--batch_token_num", type=int, default=4096, help="how many tokens in one batch")
    parser.add_argument("--max_graph_token_num", type=int, default=3000, help="max token num a graph has")
    parser.add_argument("--max_grad_norm", type=float, default=None, help="clip grad norm")
    
    # dataset
    parser.add_argument("--num_labels", type=int, default=2, help="number of labels")
    parser.add_argument("--problem_type", type=str, default="classification", help="classification or regression")
    parser.add_argument("--supv_dataset", type=str, help="supervise protein dataset")
    parser.add_argument("--test_file", type=str, help="test label file")
    parser.add_argument('--test_result_dir', type=str, default=None, help='test result directory')
    parser.add_argument("--feature_file", type=str, default=None, help="feature file")
    parser.add_argument("--feature_name", nargs="+", default=None, help="feature names")
    parser.add_argument("--feature_dim", type=int, default=0, help="feature dim")
    parser.add_argument("--feature_embed_dim", type=int, default=512, help="feature embed dim")
    parser.add_argument("--use_plddt_penalty", action="store_true", help="use plddt penalty")
    parser.add_argument("--c_alpha_max_neighbors", type=int, default=20, help="graph dataset K")
    parser.add_argument("--gnn_model_path", type=str, default="./model/protssn_k20_h512.pt", help="gnn model path")
    
    # load model
    parser.add_argument("--model_dir", type=str, default="./ckpt", help="model save dir")
    parser.add_argument("--model_name", type=str, default="feature512_norm_pp_attention1d_k20_h512_lr5e-4.pt", help="model name")

    args = parser.parse_args()
    return args

feature_dict = {}

if __name__ == "__main__":
    args = create_parser()
    args.gnn_config = yaml.load(open(args.gnn_config), Loader=yaml.FullLoader)[args.gnn]
    args.gnn_config["hidden_channels"] = args.gnn_hidden_dim
    
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    
    
    if args.feature_file:
        logger.info("***** Loading Feature *****")
        feature_df = pd.read_csv(args.feature_file)

        if type(args.feature_name) != list:
            args.feature_name = [args.feature_name]
        

        feature_aa_composition = ["1-C", "1-D", "1-E", "1-R", "1-H", "Turn-forming residues fraction"]
        if "aa_composition" in args.feature_name:
            aa_composition_df = feature_df[feature_aa_composition]
            args.feature_dim += len(feature_aa_composition)
        
        feature_gravy = ["GRAVY"]
        if "gravy" in args.feature_name:
            gravy_df = feature_df[feature_gravy]
            args.feature_dim += len(feature_gravy)
        
        feature_ss_composition = ["ss8-G", "ss8-H", "ss8-I", "ss8-B", "ss8-E", "ss8-T", "ss8-S", "ss8-P", "ss8-L", "ss3-H", "ss3-E", "ss3-C"]
        if "ss_composition" in args.feature_name:
            ss_composition_df = feature_df[feature_ss_composition]
            args.feature_dim += len(feature_ss_composition)
        
        feature_hygrogen_bonds = ["Hydrogen bonds", "Hydrogen bonds per 100 residues"]
        if "hygrogen_bonds" in args.feature_name:
            hygrogen_bonds_df = feature_df[feature_hygrogen_bonds]
            args.feature_dim += len(feature_hygrogen_bonds)
        
        feature_exposed_res_fraction = [
            "Exposed residues fraction by 5%", "Exposed residues fraction by 10%", "Exposed residues fraction by 15%", 
            "Exposed residues fraction by 20%", "Exposed residues fraction by 25%", "Exposed residues fraction by 30%", 
            "Exposed residues fraction by 35%", "Exposed residues fraction by 40%", "Exposed residues fraction by 45%", 
            "Exposed residues fraction by 50%", "Exposed residues fraction by 55%", "Exposed residues fraction by 60%", 
            "Exposed residues fraction by 65%", "Exposed residues fraction by 70%", "Exposed residues fraction by 75%", 
            "Exposed residues fraction by 80%", "Exposed residues fraction by 85%", "Exposed residues fraction by 90%", 
            "Exposed residues fraction by 95%", "Exposed residues fraction by 100%"
            ]
        if "exposed_res_fraction" in args.feature_name:
            exposed_res_fraction_df = feature_df[feature_exposed_res_fraction]
            args.feature_dim += len(feature_exposed_res_fraction)
        
        feature_pLDDT = ["pLDDT"]
        if "pLDDT" in args.feature_name:
            plddt_df = feature_df[feature_pLDDT]
            args.feature_dim += len(feature_pLDDT)
        
        
        for i in tqdm(range(len(feature_df))):
            name = feature_df["protein name"][i].split(".")[0]
            feature_dict[name] = []
            if "aa_composition" in args.feature_name:
                feature_dict[name] += list(aa_composition_df.iloc[i])
            if "gravy" in args.feature_name:
                feature_dict[name] += list(gravy_df.iloc[i])
            if "ss_composition" in args.feature_name:
                feature_dict[name] += list(ss_composition_df.iloc[i])
            if "hygrogen_bonds" in args.feature_name:
                feature_dict[name] += list(hygrogen_bonds_df.iloc[i])
            if "exposed_res_fraction" in args.feature_name:
                feature_dict[name] += list(exposed_res_fraction_df.iloc[i])
            if "pLDDT" in args.feature_name:
                feature_dict[name] += list(plddt_df.iloc[i])

        # print(len(feature_dict))
        # import json

        
        # # file name to save the features to
        # file_name = 'feature_data.json'

        # # write the dict to a JSON file
        # with open(file_name, 'w') as json_file:
        #     json.dump(feature_dict, json_file, indent=4)
        # exit()
    # load dataset
    logger.info("***** Loading Dataset *****")
    datatset_name = args.supv_dataset.split("/")[-1]
    pdb_dir = f"{args.supv_dataset}/esmfold_pdb"
    graph_dir = f"{datatset_name}_k{args.c_alpha_max_neighbors}"
    supervise_dataset = SuperviseDataset(
        root=args.supv_dataset,
        raw_dir=pdb_dir,
        name=graph_dir,
        c_alpha_max_neighbors=args.c_alpha_max_neighbors,
        pre_transform=NormalizeProtein(
            filename=f'norm/cath_k{args.c_alpha_max_neighbors}_mean_attr.pt'
        ),
    )

    label_dict, seq_dict = {}, {}
    def get_dataset(df):
        names, node_nums = [], []
        for name, label, seq in zip(df["name"], df["label"], df["aa_seq"]):
            names.append(name)
            label_dict[name] = label
            seq_dict[name] = seq
            node_nums.append(len(seq))
        return names, node_nums
    test_names, test_node_nums = get_dataset(pd.read_csv(args.test_file))
    
    # multi-thread load data will shuffle the order of data
    # so we need to save the information
    def process_data(name, fd):
        data = torch.load(f"{args.supv_dataset}/{graph_dir.capitalize()}/processed/{name}.pt")
        data.label = torch.tensor(label_dict[name]).view(1)
        data.aa_seq = seq_dict[name]
        data.name = name
        fe = None
        if args.feature_file:
            # compare keys as strings: purely numeric protein names are read
            # from the csv as integers, while feature_dict keys are strings
            for key, value in fd.items():
                if key == str(name):
                    fe = value
        if fe is None:
            print(f'\'{name}\'')
        data.feature = torch.tensor(fe).view(1, -1)
        return data
    
    def collect_fn(batch):
        batch_data = []
        with ThreadPoolExecutor(max_workers=12) as executor:
            # pass a copy of the populated feature dict explicitly instead of
            # relying on the module-level global, which can be empty inside
            # the DataLoader worker processes
            feature_d = deepcopy(feature_dict)
            futures = [executor.submit(process_data, name, feature_d) for name in batch]
            for future in as_completed(futures):
                graph = future.result()
                batch_data.append(graph)
        return batch_data
    
    test_dataloader = DataLoader(
        dataset=test_names, num_workers=4, 
        collate_fn=lambda x: collect_fn(x),
        batch_sampler=BatchSampler(
            node_num=test_node_nums,
            max_len=args.max_graph_token_num,
            batch_token_num=args.batch_token_num,
            shuffle=False
            )
        )
    
    logger.info("***** Load Model *****")
    # load model
    global plm_model
    plm_model = PLM_model(args).to(device)
    global gnn_model
    gnn_model = GNN_model(args).to(device)
    gnn_model.load_state_dict(torch.load(args.gnn_model_path))
    protssn_classification = ProtssnClassification(args)
    protssn_classification.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    
    for param in plm_model.parameters():
        param.requires_grad = False
    for param in gnn_model.parameters():
        param.requires_grad = False
    logger.info(total_param_num(protssn_classification))
    logger.info(param_num(protssn_classification))

    accelerator = Accelerator()
    protssn_classification, test_dataloader = accelerator.prepare(
        protssn_classification, test_dataloader
    )
    metrics_dict = {
        "acc": BinaryAccuracy().to(device),
        "recall": BinaryRecall().to(device),
        "precision": BinaryPrecision().to(device),
        "mcc": BinaryMatthewsCorrCoef().to(device),
        "auroc": BinaryAUROC().to(device),
        "f1": BinaryF1Score().to(device),
    }
    
    logger.info("***** Running eval *****")
    logger.info("  Num test examples = %d", len(test_names))
    logger.info("  Batch token num = %d", args.batch_token_num)
    
    eval_model(
        args=args, model=protssn_classification, 
        loss_fn=loss_fn, 
        accelerator=accelerator, metrics_dict=metrics_dict, 
        test_data=test_dataloader
        )
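
For completeness, an invocation for a custom dataset would then look roughly like this (paths and the dataset name are placeholders; the dataset root still needs an esmfold_pdb/ folder, or a pdb/ folder after the commit mentioned above):

python eval.py \
    --supv_dataset data/MyDataset \
    --test_file data/MyDataset/MyDataset.csv \
    --test_result_dir result/protssn_k20_h512/my_dataset \
    --feature_file data/MyDataset/MyDataset_feature.csv \
    --feature_name "aa_composition" "gravy" "ss_composition" "hygrogen_bonds" "exposed_res_fraction" "pLDDT" \
    --use_plddt_penalty \
    --batch_token_num 1000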
