Update experiments/albation_study_subset/ & scripts in train/src/
Kuanhao-Chao committed Apr 4, 2024
1 parent 4f8640e commit 823cb62
Showing 13 changed files with 383 additions and 825 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -22,6 +22,7 @@ build/
private/

# train
train/src/MODEL*
train/src/MODEL/
train/src/MODEL_TEST/
train/src/INPUTS/Intersection/
244 changes: 74 additions & 170 deletions experiments/albation_study_subset/splam_albation_train.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions experiments/albation_study_subset/splam_albation_train_rsg.sh
@@ -5,11 +5,11 @@
# python splam_albation_train.py --rsg $rsg --rsb $rsb
# done

for rsg in 1 2 3 4
for rsg in 1 2 3 4 5
do
rsb=4
echo python splam_albation_train.py --rsg $rsg --rsb $rsb
python splam_albation_train.py --rsg $rsg --rsb $rsb
# python splam_albation_train.py --rsg $rsg --rsb $rsb
done

# rsg=5
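A minimal Python equivalent of this sweep, assuming splam_albation_train.py accepts --rsg and --rsb exactly as invoked above; the grid values mirror the updated loop, everything else is illustrative:

```python
# Sketch: the same rsg sweep driven from Python via subprocess.
# Assumes splam_albation_train.py takes --rsg/--rsb as in the shell loop above.
import subprocess

rsb = 4
for rsg in range(1, 6):  # rsg = 1..5, matching the updated loop
    cmd = ["python", "splam_albation_train.py", "--rsg", str(rsg), "--rsb", str(rsb)]
    print(" ".join(cmd))
    subprocess.run(cmd, check=True)
```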
48 changes: 22 additions & 26 deletions experiments/albation_study_subset/splam_dataset_Chromsome.py
@@ -8,17 +8,18 @@

from splam_utils import *
# MODEL_VERSION
SEQ_LEN = "800"

project_root = '/ccb/cybertron/khchao/splam-analysis-results/'
def split_seq_name(seq):
return seq[1:]


class myDatasetEval(Dataset):
def __init__(self, type, segment_len=800, shuffle=True, eval_select=None, test_f=""):
print("!!shuffle: ", shuffle, eval_select)
self.segment_len = segment_len
self.data = []
CONSTANT_SIZE_NEG = 10000
CONSTANT_SIZE_NEG = 20000
if type == "eval":
#################################
## Processing 'NEGATIVE_1' samples
@@ -34,15 +35,15 @@ def __init__(self, type, segment_len=800, shuffle=True, eval_select=None, test_f
seq_name = split_seq_name(line)
elif nidx % 2 == 1:
seq = line
X, Y = create_datapoints(seq, '-')
X, Y = create_datapoints(seq, '-', segment_len)
X = torch.Tensor(np.array(X))
Y = torch.Tensor(np.array(Y)[0])
if X.size()[0] != 800:
if X.size()[0] != int(segment_len):
print(X.size())
print(Y.size())
self.data.append([X, Y, seq_name])
nidx += 1
if nidx %10000 == 0:
if nidx %20000 == 0:
print("nidx: ", nidx)
if nidx >= CONSTANT_SIZE_NEG:
break
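The eval loader above walks a two-line FASTA file, alternating between '>name' headers and raw sequences. A stripped-down sketch of that pattern, using create_datapoints from splam_utils; the file path and cap are illustrative, not the commit's settings:

```python
# Sketch of the alternating header/sequence FASTA walk used by myDatasetEval.
import numpy as np
import torch
from splam_utils import create_datapoints  # defined in splam_utils.py below

segment_len = 800
data, cap = [], 20000                       # cap mirrors CONSTANT_SIZE_NEG
with open("test_neg_1.shuffle.fa") as fr:   # hypothetical input file
    for nidx, line in enumerate(fr):
        line = line.strip()
        if nidx % 2 == 0:                   # '>seq_name' header line
            seq_name = line[1:]             # split_seq_name() drops the '>'
        else:                               # sequence line
            X, Y = create_datapoints(line, '-', segment_len)
            data.append([torch.Tensor(np.array(X)),
                         torch.Tensor(np.array(Y)[0]),
                         seq_name])
        if nidx >= cap:
            break
```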
@@ -68,20 +69,19 @@ class myDatasetTrain(Dataset):
def __init__(self, process_type, segment_len=800, shuffle=True, eval_select=None, idx=0):
self.segment_len = segment_len
self.data = []
CONSTANT_SIZE = 0
CONSTANT_SIZE_NEG = 0
if process_type == "train":
pos_MANE_f = f'{project_root}/train/results/train_test_dataset/input_pos_mane/{segment_len}bp/train_pos_mane.shuffle.fa'
pos_ALTS_f = f'{project_root}/train/results/train_test_dataset/input_pos_alts/{segment_len}bp/train_pos_alts.shuffle.fa'
neg_1_f = f'{project_root}/train/results/train_test_dataset/input_neg_1/{segment_len}bp/train_neg_1.shuffle.fa'
neg_random_f = f'{project_root}/train/results/train_test_dataset/input_neg_random/{segment_len}bp/train_neg_random.shuffle.fa'
CONSTANT_SIZE = 30000
CONSTANT_SIZE = 20000
elif process_type == "test":
pos_MANE_f = f'{project_root}/train/results/train_test_dataset/input_pos_mane/{segment_len}bp/test_pos_mane.shuffle.fa'
pos_ALTS_f = f'{project_root}/train/results/train_test_dataset/input_pos_alts/{segment_len}bp/test_pos_alts.shuffle.fa'
neg_1_f = f'{project_root}/train/results/train_test_dataset/input_neg_1/{segment_len}bp/test_neg_1.shuffle.fa'
neg_random_f = f'{project_root}/train/results/train_test_dataset/input_neg_random/{segment_len}bp/test_neg_random.shuffle.fa'
CONSTANT_SIZE = 10000
CONSTANT_SIZE = 20000
CONSTANT_SIZE_NEG = 0
#################################
## Processing 'Positive_MANE' samples
#################################
@@ -96,10 +96,10 @@ def __init__(self, process_type, segment_len=800, shuffle=True, eval_select=None
seq_name = split_seq_name(line)
elif pp_MANE_idx % 2 == 1:
seq = line
X, Y = create_datapoints(seq, '+')
X, Y = create_datapoints(seq, '+', segment_len)
X = torch.Tensor(np.array(X))
Y = torch.Tensor(np.array(Y)[0])
if X.size()[0] != 800:
if X.size()[0] != int(segment_len):
print("seq_name: ", seq_name)
print(X.size())
print(Y.size())
@@ -124,10 +124,10 @@ def __init__(self, process_type, segment_len=800, shuffle=True, eval_select=None
seq_name = split_seq_name(line)
elif pp_alts_idx % 2 == 1:
seq = line
X, Y = create_datapoints(seq, '+')
X, Y = create_datapoints(seq, '+', segment_len)
X = torch.Tensor(np.array(X))
Y = torch.Tensor(np.array(Y)[0])
if X.size()[0] != 800:
if X.size()[0] != int(segment_len):
print("seq_name: ", seq_name)
print(X.size())
print(Y.size())
@@ -155,16 +155,16 @@ def __init__(self, process_type, segment_len=800, shuffle=True, eval_select=None
seq_name = split_seq_name(line)
elif n1idx % 2 == 1:
seq = line
X, Y = create_datapoints(seq, '-')
X, Y = create_datapoints(seq, '-', segment_len)
X = torch.Tensor(np.array(X))
Y = torch.Tensor(np.array(Y)[0])
if X.size()[0] != 800:
if X.size()[0] != int(segment_len):
print("seq_name: ", seq_name)
print(X.size())
print(Y.size())
self.data.append([X, Y, seq_name])
n1idx += 1
if n1idx %10000 == 0:
if n1idx %20000 == 0:
print("\tn1idx: ", n1idx)
if n1idx >= CONSTANT_SIZE_NEG:
break
@@ -183,16 +183,16 @@ def __init__(self, process_type, segment_len=800, shuffle=True, eval_select=None
seq_name = split_seq_name(line)
elif nridx % 2 == 1:
seq = line
X, Y = create_datapoints(seq, '-')
X, Y = create_datapoints(seq, '-', segment_len)
X = torch.Tensor(np.array(X))
Y = torch.Tensor(np.array(Y)[0])
if X.size()[0] != 800:
if X.size()[0] != int(segment_len):
print("seq_name: ", seq_name)
print(X.size())
print(Y.size())
self.data.append([X, Y, seq_name])
nridx += 1
if nridx %10000 == 0:
if nridx %20000 == 0:
print("\tnridx: ", nridx)
if nridx >= CONSTANT_SIZE_NEG:
break
@@ -214,11 +214,10 @@ def __getitem__(self, index):
return feature, label, seq_name


def get_train_dataloader(batch_size, TARGET, n_workers):
def get_train_dataloader(batch_size, TARGET, SEQ_LEN, n_workers):
"""Generate dataloader"""
trainset_origin = myDatasetTrain("train", int(SEQ_LEN))
trainset, valset = torch.utils.data.random_split(trainset_origin, [0.9, 0.1])
# trainset = myDatasetTrain("train", int(SEQ_LEN))
testset = myDatasetTrain("test", int(SEQ_LEN))
train_loader = DataLoader(
trainset,
@@ -244,9 +243,6 @@ def get_train_dataloader(batch_size, TARGET, n_workers):
# predicting splice / non-splice
#######################################
return train_loader, val_loader, test_loader
# torch.save(train_loader, "./INPUTS/"+SEQ_LEN+"bp/"+TARGET+"/train.pt")
# torch.save(val_loader, "./INPUTS/"+SEQ_LEN+"bp/"+TARGET+"/val.pt")
# torch.save(test_loader, "./INPUTS/"+SEQ_LEN+"bp/"+TARGET+"/test.pt")


def get_test_dataloader(batch_size, TARGET, n_workers, shuffle):
@@ -265,7 +261,6 @@ def get_test_dataloader(batch_size, TARGET, n_workers, shuffle):
# test_loader = torch.load("./INPUTS/"+SEQ_LEN+"bp/"+TARGET+"/test.pt")
return test_loader


def get_eval_dataloader(batch_size, TARGET, n_workers, shuffle, eval_select, test_f):
#######################################
# predicting splice / non-splice
@@ -281,4 +276,5 @@ def get_eval_dataloader(batch_size, TARGET, n_workers, shuffle, eval_select, tes
dataLoader = os.path.join(os.path.dirname(test_f), "splam_dataloader.pt")
print(f'[INFO] Loading dataset (shuffle: {test_f}')
torch.save(test_loader, f'{dataLoader}')
return test_loader
return test_loader
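With SEQ_LEN now a parameter of get_train_dataloader rather than a module-level constant, callers pass the window size explicitly (splam_test_model.py below calls it with 800 and one worker). A usage sketch with illustrative values:

```python
# Illustrative call of the updated signature; these values are examples,
# not the commit's settings.
BATCH_SIZE, MODEL_VERSION, SEQ_LEN, N_WORKERS = 100, "rsg_1__rsb_4/", 800, 1

train_loader, val_loader, test_loader = get_train_dataloader(
    BATCH_SIZE, MODEL_VERSION, SEQ_LEN, N_WORKERS)

for X, Y, seq_name in train_loader:
    print(X.shape, Y.shape)  # per-batch encoded sequences and per-base labels
    break
```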

44 changes: 25 additions & 19 deletions experiments/albation_study_subset/splam_test_model.py
@@ -61,11 +61,7 @@ def main(MODEL_VERSION):
junc_counter = 0
target = "test_juncs"
os.makedirs(MODEL_OUTPUT_BASE+target, exist_ok=True)
d_score_tsv_f = MODEL_OUTPUT_BASE+target+"/splam_all_seq.score.d."+target+".tsv"
a_score_tsv_f = MODEL_OUTPUT_BASE+target+"/splam_all_seq.score.a."+target+".tsv"
d_score_fw = open(d_score_tsv_f, "a")
a_score_fw = open(a_score_tsv_f, "a")
train_loader, val_loader, test_loader = get_train_dataloader(BATCH_SIZE, MODEL_VERSION, N_WORKERS)
train_loader, val_loader, test_loader = get_train_dataloader(BATCH_SIZE, MODEL_VERSION, 800, 1)
print(f"[Info]: Finish loading data!", flush = True)
print("valid_iterator: ", len(test_loader))
LOG_OUTPUT_TEST_BASE = MODEL_OUTPUT_BASE + "/" + target + "/LOG/"
@@ -100,7 +96,7 @@ def main(MODEL_VERSION):
fw_test_log_J_threshold_recall = open(test_log_J_threshold_recall, 'w')

for model_idx in range(0, 15):
MODEL = f'/ccb/cybertron/khchao/splam-analysis-results/results/albation_study/MODEL/subset_10000/{MODEL_VERSION}/splam_{model_idx}.pt'
MODEL = f'{project_root}results/albation_study/MODEL/subset_20000/{MODEL_VERSION}/splam_{model_idx}.pt'
model = torch.load(MODEL)
model = model.to(device)
print("########################################")
@@ -152,20 +148,32 @@ def main(MODEL_VERSION):
A_YL = labels[is_expr, 1, :].to('cpu').detach().numpy()
A_YP = yp[is_expr, 1, :].to('cpu').detach().numpy()
D_YL = labels[is_expr, 2, :].to('cpu').detach().numpy()
D_YP = yp[is_expr, 2, :].to('cpu').detach().numpy()
np.savetxt(d_score_fw, D_YP, delimiter=" ")
np.savetxt(a_score_fw, A_YP, delimiter=" ")
D_YP = yp[is_expr, 2, :].to('cpu').detach().numpy()

donor_labels, donor_scores, acceptor_labels, acceptor_scores = get_donor_acceptor_scores(D_YL, A_YL, D_YP, A_YP)
donor_labels, donor_scores, acceptor_labels, acceptor_scores = get_donor_acceptor_scores(
D_YL, A_YL, D_YP, A_YP, 800)

# donor site metric
A_G_TP, A_G_FN, A_G_FP, A_G_TN, A_TP, A_FN, A_FP, A_TN = print_splice_site_statistics(A_YL, A_YP, JUNC_THRESHOLD, A_G_TP, A_G_FN, A_G_FP, A_G_TN, 800, "acceptor")
D_G_TP, D_G_FN, D_G_FP, D_G_TN, D_TP, D_FN, D_FP, D_TN = print_splice_site_statistics(D_YL, D_YP, JUNC_THRESHOLD, D_G_TP, D_G_FN, D_G_FP, D_G_TN, 800, "donor")

# Junction statistics
J_G_TP, J_G_FN, J_G_FP, J_G_TN, J_TP, J_FN, J_FP, J_TN = print_junc_statistics(D_YL, A_YL, D_YP, A_YP, JUNC_THRESHOLD, J_G_TP, J_G_FN, J_G_FP, J_G_TN)
# Top-k statistics
# junction metric
J_G_TP, J_G_FN, J_G_FP, J_G_TN, J_TP, J_FN, J_FP, J_TN = print_junc_statistics(
D_YL, A_YL, D_YP, A_YP, JUNC_THRESHOLD, J_G_TP, J_G_FN, J_G_FP, J_G_TN, 800)
A_accuracy, A_auprc = print_top_1_statistics(Acceptor_YL, Acceptor_YP)
D_accuracy, D_auprc = print_top_1_statistics(Donor_YL, Donor_YP)
# Donor and Acceptor statistics
A_G_TP, A_G_FN, A_G_FP, A_G_TN, A_TP, A_FN, A_FP, A_TN = print_threshold_statistics(Acceptor_YL, Acceptor_YP, JUNC_THRESHOLD, A_G_TP, A_G_FN, A_G_FP, A_G_TN)
D_G_TP, D_G_FN, D_G_FP, D_G_TN, D_TP, D_FN, D_FP, D_TN = print_threshold_statistics(Donor_YL, Donor_YP, JUNC_THRESHOLD, D_G_TP, D_G_FN, D_G_FP, D_G_TN)


# donor_labels, donor_scores, acceptor_labels, acceptor_scores = get_donor_acceptor_scores(D_YL, A_YL, D_YP, A_YP, 800)

# # Junction statistics
# J_G_TP, J_G_FN, J_G_FP, J_G_TN, J_TP, J_FN, J_FP, J_TN = print_junc_statistics(D_YL, A_YL, D_YP, A_YP, JUNC_THRESHOLD, J_G_TP, J_G_FN, J_G_FP, J_G_TN)
# # Top-k statistics
# A_accuracy, A_auprc = print_top_1_statistics(Acceptor_YL, Acceptor_YP)
# D_accuracy, D_auprc = print_top_1_statistics(Donor_YL, Donor_YP)
# # Donor and Acceptor statistics
# A_G_TP, A_G_FN, A_G_FP, A_G_TN, A_TP, A_FN, A_FP, A_TN = print_threshold_statistics(Acceptor_YL, Acceptor_YP, JUNC_THRESHOLD, A_G_TP, A_G_FN, A_G_FP, A_G_TN)
# D_G_TP, D_G_FN, D_G_FP, D_G_TN, D_TP, D_FN, D_FP, D_TN = print_threshold_statistics(Donor_YL, Donor_YP, JUNC_THRESHOLD, D_G_TP, D_G_FN, D_G_FP, D_G_TN)
batch_loss = loss.item()
epoch_loss += loss.item()
epoch_donor_acc += D_accuracy
@@ -206,8 +214,6 @@ def main(MODEL_VERSION):
print(f'Acceptor Precision: {A_G_TP/(A_G_TP+A_G_FP):.5f} | Acceptor Recall: {A_G_TP/(A_G_TP+A_G_FN):.5f} | TP: {A_G_TP} | FN: {A_G_FN} | FP: {A_G_FP} | TN: {A_G_TN}')
print("\n\n")

d_score_fw.close()
a_score_fw.close()
fw_test_log_loss.close()
fw_test_log_A_acc.close()
fw_test_log_A_auprc.close()
Expand Down Expand Up @@ -251,6 +257,6 @@ def plot_roc_curve(true_y, y_prob, label):
# main(MODEL_VERSION)

rsb = 4
for rsg in range(5, 6):
for rsg in range(2, 6):
MODEL_VERSION = f'rsg_{rsg}__rsb_{rsb}/'
main(MODEL_VERSION)
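The per-model logs above reduce the accumulated counts to standard confusion-matrix ratios. A small helper illustrating the computation; the function name and sample counts are hypothetical:

```python
# Illustrative helper matching the precision/recall expressions logged above.
def precision_recall(tp: int, fn: int, fp: int) -> tuple[float, float]:
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

# e.g. TP=90, FN=10, FP=30 -> precision 0.75, recall 0.90
print(precision_recall(90, 10, 30))
```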
57 changes: 42 additions & 15 deletions experiments/albation_study_subset/splam_utils.py
@@ -7,6 +7,8 @@
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import Optimizer, AdamW
from torch.nn import CrossEntropyLoss, BCELoss, BatchNorm1d, ModuleList
# from torch.nn import Module, BatchNorm1d, LazyBatchNorm1d, ReLU, LeakyReLU, Conv1d, LazyConv1d, ModuleList, Softmax, Sigmoid, Flatten, Dropout2d, Linear

import numpy as np
import re
import math
@@ -15,7 +17,6 @@
from splam_constant import *
from SPLAM import *

SEQ_LEN = 800
# fix random seed
def same_seeds(seed):
torch.manual_seed(seed)
@@ -50,17 +51,17 @@ def one_hot_encode_classifier(Xd):
#######################################
# This is for Conformer model
#######################################
def create_datapoints(seq, strand):
def create_datapoints(seq, strand, segment_len):
# seq = 'N'*(CL_MAX//2) + seq + 'N'*(CL_MAX//2)
seq = seq.upper().replace('A', '1').replace('C', '2')
seq = seq.replace('G', '3').replace('T', '4').replace('N', '0').replace('K', '0').replace('R', '0')
jn_start = JUNC_START
jn_end = JUNC_END
jn_start = segment_len//4
jn_end = segment_len//4 * 3
#######################################
# predicting pb for every bp
#######################################
X0 = np.asarray(list(map(int, list(seq))))
Y0 = [np.zeros(SEQ_LEN) for t in range(1)]
Y0 = [np.zeros(segment_len) for t in range(1)]
if strand == '+':
for t in range(1):
Y0[t][jn_start] = 2
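The hard-coded JUNC_START/JUNC_END constants give way to positions derived from segment_len: the donor sits at the quarter mark and the acceptor at the three-quarter mark, so the 800 bp default reproduces the old 200/600 indices used elsewhere in this diff. A quick check:

```python
# Quick check of the window-relative junction-position arithmetic.
for segment_len in (400, 600, 800):
    jn_start = segment_len // 4      # donor index
    jn_end = segment_len // 4 * 3    # acceptor index
    print(segment_len, jn_start, jn_end)
# 400 -> 100 300 | 600 -> 150 450 | 800 -> 200 600 (the old constants)
```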
@@ -190,23 +191,49 @@ def print_threshold_statistics(y_true, y_pred, threshold, TOTAL_TP, TOTAL_FN, TO
return TOTAL_TP, TOTAL_FN, TOTAL_FP, TOTAL_TN, LCL_TOTAL_TP, LCL_TOTAL_FN, LCL_TOTAL_FP, LCL_TOTAL_TN


def get_donor_acceptor_scores(D_YL, A_YL, D_YP, A_YP):
return D_YL[:, 200], D_YP[:, 200], A_YL[:, 600], A_YP[:, 600]
def get_donor_acceptor_scores(D_YL, A_YL, D_YP, A_YP, seq_len):
return D_YL[:, seq_len//4], D_YP[:, seq_len//4], A_YL[:, seq_len//4*3], A_YP[:, seq_len//4*3]


def get_junc_scores(D_YL, A_YL, D_YP, A_YP, choice):
def get_junc_scores(D_YL, A_YL, D_YP, A_YP, choice, seq_len):
if choice == "min":
junc_labels = np.minimum(D_YL[:, 200], A_YL[:, 600])
junc_scores = np.minimum(D_YP[:, 200], A_YP[:, 600])
junc_labels = np.minimum(D_YL[:, seq_len//4], A_YL[:, seq_len//4*3])
junc_scores = np.minimum(D_YP[:, seq_len//4], A_YP[:, seq_len//4*3])
elif choice == "avg":
junc_labels = np.minimum(D_YL[:, 200], A_YL[:, 600])
junc_scores = np.mean([D_YP[:, 200], A_YP[:, 600]], axis=0)
junc_labels = np.minimum(D_YL[:, seq_len//4], A_YL[:, seq_len//4*3])
junc_scores = np.mean([D_YP[:, seq_len//4], A_YP[:, seq_len//4*3]], axis=0)
return junc_labels, junc_scores
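A toy example of the two aggregation modes, assuming seq_len = 800 so the donor and acceptor columns are 200 and 600; the score values are made up:

```python
# Toy example: "min" vs "avg" junction-score aggregation for two windows.
import numpy as np

seq_len = 800
D_YP = np.zeros((2, seq_len)); A_YP = np.zeros((2, seq_len))
D_YP[:, seq_len // 4] = [0.9, 0.2]        # donor scores at column 200
A_YP[:, seq_len // 4 * 3] = [0.7, 0.8]    # acceptor scores at column 600

print(np.minimum(D_YP[:, 200], A_YP[:, 600]))         # min -> [0.7 0.2]
print(np.mean([D_YP[:, 200], A_YP[:, 600]], axis=0))  # avg -> [0.8 0.5]
```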


def print_junc_statistics(D_YL, A_YL, D_YP, A_YP, threshold, TOTAL_TP, TOTAL_FN, TOTAL_FP, TOTAL_TN):
label_junc_idx = (D_YL[:, 200]==1) & (A_YL[:, 600]==1)
predict_junc_idx = (D_YP[:, 200]>=threshold) & (A_YP[:, 600]>=threshold)
def print_splice_site_statistics(YL, YP, threshold, TOTAL_TP, TOTAL_FN, TOTAL_FP, TOTAL_TN, seq_len, choice):
if choice == "donor":
label_junc_idx = (YL[:, seq_len//4]==1)
label_nonjunc_idx = (YL[:, seq_len//4]==0)
predict_junc_idx = (YP[:, seq_len//4]>=threshold)
predict_nonjunc_idx = (YP[:, seq_len//4]<threshold)
elif choice == "acceptor":
label_junc_idx = (YL[:, seq_len//4*3]==1)
label_nonjunc_idx = (YL[:, seq_len//4*3]==0)
predict_junc_idx = (YP[:, seq_len//4*3]>=threshold)
predict_nonjunc_idx = (YP[:, seq_len//4*3]<threshold)
idx_true = np.nonzero(label_junc_idx == True)[0]
idx_pred = np.nonzero(predict_junc_idx == True)[0]
LCL_TOTAL_TP = np.size(np.intersect1d(idx_true, idx_pred))
LCL_TOTAL_FN = len(idx_true) - LCL_TOTAL_TP
LCL_TOTAL_FP = len(idx_pred) - LCL_TOTAL_TP
LCL_TOTAL_TN = len(YL) - LCL_TOTAL_TP - LCL_TOTAL_FN - LCL_TOTAL_FP
TOTAL_TP += LCL_TOTAL_TP
TOTAL_FN += LCL_TOTAL_FN
TOTAL_FP += LCL_TOTAL_FP
TOTAL_TN += LCL_TOTAL_TN
return TOTAL_TP, TOTAL_FN, TOTAL_FP, TOTAL_TN, LCL_TOTAL_TP, LCL_TOTAL_FN, LCL_TOTAL_FP, LCL_TOTAL_TN


def print_junc_statistics(D_YL, A_YL, D_YP, A_YP, threshold, TOTAL_TP, TOTAL_FN, TOTAL_FP, TOTAL_TN, seq_len):
label_junc_idx = (D_YL[:, seq_len//4]==1) & (A_YL[:, seq_len//4*3]==1)
label_nonjunc_idx = (D_YL[:, seq_len//4]==0) & (A_YL[:, seq_len//4*3]==0)
predict_junc_idx = (D_YP[:, seq_len//4]>=threshold) & (A_YP[:, seq_len//4*3]>=threshold)
predict_nonjunc_idx = (D_YP[:, seq_len//4]<threshold) | (A_YP[:, seq_len//4*3]<threshold)
idx_true = np.nonzero(label_junc_idx == True)[0]
idx_pred = np.nonzero(predict_junc_idx == True)[0]
LCL_TOTAL_TP = np.size(np.intersect1d(idx_true, idx_pred))
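Both new statistics functions count via index sets: TP is the overlap of labeled and predicted positive indices, FN and FP are the leftovers on each side, and TN is the remainder of the batch. A toy check of that arithmetic, with made-up label and prediction vectors:

```python
# Toy check of the index-set counting used in the statistics functions above.
import numpy as np

labels = np.array([1, 1, 0, 0, 1])                # e.g. label_junc_idx
preds  = np.array([1, 0, 1, 0, 1])                # e.g. predict_junc_idx

idx_true = np.nonzero(labels == 1)[0]             # [0 1 4]
idx_pred = np.nonzero(preds == 1)[0]              # [0 2 4]
TP = np.size(np.intersect1d(idx_true, idx_pred))  # 2
FN = len(idx_true) - TP                           # 1
FP = len(idx_pred) - TP                           # 1
TN = len(labels) - TP - FN - FP                   # 1
print(TP, FN, FP, TN)
```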
