-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer_grn.py
99 lines (78 loc) · 3.92 KB
/
infer_grn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
author: Akshata
timestamp: Thu August 24th 2023 11.40 AM
"""
import torch_geometric as pyg
from torch import nn
#import graph_transformer_pytorch as gt
import pandas as pd
import math
import os
import torch
from lightning import Trainer, seed_everything
import torch.nn as nn
import numpy as np
from pandas import read_csv
from torch import Tensor
from torch.utils.data import DataLoader,ConcatDataset
import wandb
from torch_geometric.utils import negative_sampling
import src.datamodules.grn_dataset_inference as dt
from argparse import ArgumentParser
from typing import Optional, Tuple
from torch_geometric.loader import DataListLoader
from src.models.grnformer.model import GRNFormerLitModule
#from dotenv import load_dotenv
#load_dotenv()
# Module-level configuration for the inference script.
AVAIL_GPUS = [0]  # presumably the GPU device indices to run on — TODO confirm
NUM_NODES = 1  # presumably the number of compute nodes — TODO confirm
BATCH_SIZE = 1  # batch size handed to the DataListLoader in the script body
DATALOADERS = 1  # NOTE(review): not referenced in the visible code; looks like a loader/worker count — confirm
ACCELERATOR = "gpu"  # accelerator name handed to the Lightning Trainer in the script body
DATASET_DIR = os.path.abspath("./")  # absolute path of the current working directory at import time
EPS = 1e-15  # NOTE(review): not referenced in the visible code; presumably a numerical-stability epsilon
# The bare string literal below holds disabled tensor-type/cuDNN set-up code.
# As a plain string expression it is evaluated and discarded; none of it runs.
"""
torch.set_default_tensor_type(torch.FloatTensor) # Ensure that the default tensor type is FloatTensor
3
if device.type == "cuda":
torch.backends.cudnn.benchmark = True # Enable cuDNN auto-tuner to find the best algorithm to use for hardware
torch.set_default_tensor_type(torch.cuda.FloatTensor) # Set the default tensor type to CUDA FloatTensor
torch.set_float32_matmul_precision('medium') # Set Tensor Core precision to medium
"""
# Select CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Choose the device you want to use
# Checkpoint loaded for prediction when --ckpt_path is not given.
DEFAULT_CKPT_PATH = 'Trainings/GRNFormer_epoch=26_valid_loss=0.645546.ckpt'


def parse_args(argv: Optional[list] = None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Argument strings to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``exp_file``, ``tf_file``,
        ``net_file``, ``output_file`` and ``ckpt_path``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing.
    """
    parser = ArgumentParser()
    # The path options are required: every one of them is resolved with
    # os.path.abspath() below, which cannot handle a missing value.
    parser.add_argument(
        '--exp_file', type=str, required=True,
        help="Sets the expression file of the data folder. "
             "Enter the relative path to the expression CSV, e.g. "
             "'Data/sc-RNA-seq/hESC/hESC_nonspecific_chipseq_500-ExpressionData.csv'",
    )
    parser.add_argument(
        '--tf_file', type=str, required=True,
        help="Sets the TF file of the data folder (single-column CSV file). "
             "Enter the relative path to the transcription factor file of the species",
    )
    # This option is read by the script and forwarded to the model, so it has
    # to be declared; without it `args.net_file` raises AttributeError.
    parser.add_argument(
        '--net_file', type=str, required=True,
        help="Sets the regulatory network file passed to the model. "
             "Enter the relative path to the network file",
    )
    parser.add_argument(
        '--output_file', type=str, required=True,
        help="Sets the output file passed to the model for its predictions. "
             "Enter the relative path of the file to write",
    )
    parser.add_argument(
        '--ckpt_path', type=str, default=DEFAULT_CKPT_PATH,
        help="Path to the model checkpoint loaded for prediction",
    )
    return parser.parse_args(argv)


def main(argv: Optional[list] = None) -> None:
    """Run GRNFormer prediction on one expression file.

    Seeds the RNGs, resolves the input paths, wraps the expression data in a
    dataset/loader, builds the Lightning module and calls
    ``Trainer.predict`` with the selected checkpoint.

    Args:
        argv: Optional argument strings; ``None`` uses ``sys.argv[1:]``.
    """
    seed_everything(123)
    args = parse_args(argv)

    exp_file = os.path.abspath(args.exp_file)
    tf_file = os.path.abspath(args.tf_file)
    net_file = os.path.abspath(args.net_file)
    output_file = os.path.abspath(args.output_file)

    # Row count of the expression CSV (presumably one row per gene — confirm).
    num_nodes = len(pd.read_csv(exp_file))
    # First column of the header-less TF file, as a plain list.
    tf_names = pd.read_csv(tf_file, header=None)[0].to_list()

    # The dataset root is the directory that contains the expression file.
    dataset = dt.GeneExpressionDataset(os.path.dirname(exp_file), exp_file, tf_names)
    test_datasets = ConcatDataset([dataset])
    test_loader = DataListLoader(
        dataset=test_datasets,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=1,
    )

    model = GRNFormerLitModule(
        totalnodes=num_nodes,
        tf_file=tf_file,
        exp_file=exp_file,
        net_file=net_file,
        output_file=output_file,
    )
    print("Model loaded")

    trainer = Trainer(
        devices=[0],
        num_nodes=1,
        accelerator=ACCELERATOR,
        detect_anomaly=True,
        enable_model_summary=True,
    )
    trainer.predict(
        model,
        dataloaders=test_loader,
        ckpt_path=os.path.abspath(args.ckpt_path),
    )


if __name__ == "__main__":
    main()