Showing 6 changed files with 758 additions and 0 deletions.

@@ -0,0 +1,82 @@
import os
import torch
import random
import time
import numpy as np
import pandas as pd
import scanpy as sc
import pickle
import scipy.sparse as sp
from torch.backends import cudnn
from preprocess import read_data

def load_data(args):
    """Load data, including ST and ATAC data."""
    # read data
    data = read_data(args)

    print('Data reading finished!')
    return data

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)

# ====== Graph preprocessing
def preprocess_graph(adj):
    """Symmetrically normalize A + I and return it as a torch sparse tensor."""
    adj = sp.coo_matrix(adj)
    adj_ = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_.sum(1))
    degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
    adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
    return sparse_mx_to_torch_sparse_tensor(adj_normalized)

def get_edge_index(matrix):
    """Return the indices of non-zero entries as a 2 x n_edges LongTensor."""
    edge_index = [[], []]
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            if matrix[i][j] != 0:
                edge_index[0].append(i)
                edge_index[1].append(j)
    return torch.LongTensor(edge_index)  # convert the nested list to a tensor

def fix_seed(seed):
    """Fix the random seeds of Python, NumPy and PyTorch (CPU and GPU) for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

def normalize_adj(adj):
    """Symmetrically normalize adjacency matrix."""
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1))
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    adj = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt)
    return adj.toarray()

def preprocess_adj(adj):
    """Preprocess the adjacency matrix for a simple GCN model: symmetric normalization plus self-loops."""
    adj_normalized = normalize_adj(adj) + np.eye(adj.shape[0])
    return adj_normalized

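# --------------------------------------------------------------------------
# A minimal usage sketch of the graph preprocessing helpers above (the toy
# adjacency matrix and this __main__ guard are illustrative additions, not
# part of the original module).
if __name__ == '__main__':
    # Toy 3-node path graph: symmetric adjacency without self-loops.
    toy_adj = np.array([[0., 1., 0.],
                        [1., 0., 1.],
                        [0., 1., 0.]])

    adj_dense = preprocess_adj(toy_adj)      # symmetrically normalized A plus I, dense (3, 3) array
    adj_sparse = preprocess_graph(toy_adj)   # normalized (A + I) as a torch sparse tensor
    edges = get_edge_index(toy_adj)          # tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    print(adj_dense.shape, adj_sparse.shape, edges.shape)
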
@@ -0,0 +1,91 @@
import os
import torch
import argparse
import warnings
import time
from train import Train
from inits import load_data, fix_seed
from utils import UMAP, plot_weight_value
import pickle

warnings.filterwarnings("ignore")
os.environ['R_HOME'] = '/scbio4/tools/R/R-4.0.3_openblas/R-4.0.3'

parser = argparse.ArgumentParser(description='PyTorch implementation of spatial multi-omics data integration')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='Initial learning rate.')  # 0.0001
parser.add_argument('--epochs', type=int, default=1500, help='Number of epochs to train.')  # 1500 Mouse Brain, 1000 SPOTS_spleen_rep1, 700 Thymus
parser.add_argument('--weight_decay', type=float, default=0.0000, help='Weight for L2 loss on embedding matrix.')  # 5e-4
parser.add_argument('--datatype', type=str, default='SPOTS', help='Data type.')
parser.add_argument('--input', type=str, default='/home/yahui/anaconda3/work/SpatialGlue_omics/data/', help='Input path.')
parser.add_argument('--output', type=str, default='/home/yahui/anaconda3/work/SpatialGlue_omics/output/', help='Output path.')
parser.add_argument('--random_seed', type=int, default=2022, help='Random seed.')  # 50
parser.add_argument('--dim_input', type=int, default=3000, help='Dimension of input features.')  # 100
parser.add_argument('--dim_output', type=int, default=64, help='Dimension of output features.')  # 64
parser.add_argument('--n_neighbors', type=int, default=6, help='Number of sampling neighbors.')  # 6
parser.add_argument('--n_clusters', type=int, default=9, help='Number of clusters.')  # mouse brain 15, thymus 9, spleen 5
args = parser.parse_args()

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
t = time.time()
fix_seed(args.random_seed)

args.dataset = 'SPOTS_spleen_rep1'

if args.datatype == 'Stereo-CITE-seq':
    args.n_clusters = 8
    args.epochs = 1500
elif args.datatype == 'Spatial-ATAC-RNA-seq':
    args.n_clusters = 15
    args.epochs = 1500
elif args.datatype == 'SPOTS':
    args.n_clusters = 6
    args.epochs = 900

print('>>>>>>>>>>>>>>>>> {} <<<<<<<<<<<<<<<<'.format(args.dataset))

data = load_data(args)
adata_omics1, adata_omics2 = data['adata_omics1'], data['adata_omics2']

# start to train the model
trainer = Train(args, device, data)
emb_omics1, emb_omics2, emb_combined, alpha = trainer.train()
print('time:', time.time() - t)

adata_omics1.obsm['emb'] = emb_omics1
adata_omics2.obsm['emb'] = emb_omics2
adata_omics1.obsm['emb_combined'] = emb_combined
adata_omics2.obsm['emb_combined'] = emb_combined

adata_omics1.obsm['alpha'] = alpha

# umap
adata_combined = UMAP(adata_omics1, adata_omics2, args)

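# A possible way to persist the integrated result, using the already-imported
# pickle module and the --output argument. This saving step is an illustrative
# addition, not part of the original script; the file name is arbitrary.
with open(os.path.join(args.output, args.dataset + '_adata_combined.pkl'), 'wb') as f:
    pickle.dump(adata_combined, f)
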
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import numpy as np
#from torch_geometric.nn import GCNConv, GATConv

class Encoder_omics(Module):
    """
    Overall encoder: modality-specific GCN encoders/decoders combined with
    within- and between-modality attention aggregation.
    """
    def __init__(self, dim_in_feat_omics1, dim_out_feat_omics1, dim_in_feat_omics2, dim_out_feat_omics2, dropout=0.0, act=F.relu):
        super(Encoder_omics, self).__init__()
        self.dim_in_feat_omics1 = dim_in_feat_omics1
        self.dim_in_feat_omics2 = dim_in_feat_omics2
        self.dim_out_feat_omics1 = dim_out_feat_omics1
        self.dim_out_feat_omics2 = dim_out_feat_omics2
        self.dropout = dropout
        self.act = act

        self.encoder_omics1 = Encoder(self.dim_in_feat_omics1, self.dim_out_feat_omics1)
        self.decoder_omics1 = Decoder(self.dim_out_feat_omics1, self.dim_in_feat_omics1)
        self.encoder_omics2 = Encoder(self.dim_in_feat_omics2, self.dim_out_feat_omics2)
        self.decoder_omics2 = Decoder(self.dim_out_feat_omics2, self.dim_in_feat_omics2)

        self.atten_omics1 = AttentionLayer(self.dim_out_feat_omics1, self.dim_out_feat_omics1)
        self.atten_omics2 = AttentionLayer(self.dim_out_feat_omics2, self.dim_out_feat_omics2)
        self.atten_cross = AttentionLayer(self.dim_out_feat_omics1, self.dim_out_feat_omics2)

        self.discriminator = Discriminator(self.dim_out_feat_omics1)

    #def forward(self, feat1, feat2, adj1_1, adj1_2, adj2_1, adj2_2):
    def forward(self, features_omics1, features_omics2, adj_spatial_omics1, adj_feature_omics1, adj_spatial_omics2, adj_feature_omics2):
        # graph 1: spatial neighbor graph
        emb_latent_spatial_omics1 = self.encoder_omics1(features_omics1, adj_spatial_omics1)
        emb_latent_spatial_omics2 = self.encoder_omics2(features_omics2, adj_spatial_omics2)

        # graph 2: feature neighbor graph
        emb_latent_feature_omics1 = self.encoder_omics1(features_omics1, adj_feature_omics1)
        emb_latent_feature_omics2 = self.encoder_omics2(features_omics2, adj_feature_omics2)

        # within-modality attention aggregation
        emb_latent_omics1, alpha_omics1 = self.atten_omics1(emb_latent_spatial_omics1, emb_latent_feature_omics1)
        emb_latent_omics2, alpha_omics2 = self.atten_omics2(emb_latent_spatial_omics2, emb_latent_feature_omics2)

        # between-modality attention aggregation
        emb_latent_combined, alpha_omics_1_2 = self.atten_cross(emb_latent_omics1, emb_latent_omics2)

        # reconstruct each expression matrix from the combined embedding with the modality-specific decoders
        emb_recon_omics1 = self.decoder_omics1(emb_latent_combined, adj_spatial_omics1)
        emb_recon_omics2 = self.decoder_omics2(emb_latent_combined, adj_spatial_omics2)

        # consistent encoding: decode into the other modality's space, then re-encode (dim=64)
        emb_latent_omics1_across_recon = self.encoder_omics2(self.decoder_omics2(emb_latent_omics1, adj_spatial_omics2), adj_spatial_omics2)
        emb_latent_omics2_across_recon = self.encoder_omics1(self.decoder_omics1(emb_latent_omics2, adj_spatial_omics1), adj_spatial_omics1)

        # discriminator scores for the two modality-specific latent spaces
        score_omics1 = self.discriminator(emb_latent_omics1)
        score_omics2 = self.discriminator(emb_latent_omics2)
        score_omics1 = torch.squeeze(score_omics1, dim=1)
        score_omics2 = torch.squeeze(score_omics2, dim=1)

        results = {'emb_latent_omics1': emb_latent_omics1,
                   'emb_latent_omics2': emb_latent_omics2,
                   'emb_latent_combined': emb_latent_combined,
                   'emb_recon_omics1': emb_recon_omics1,
                   'emb_recon_omics2': emb_recon_omics2,
                   'emb_latent_omics1_across_recon': emb_latent_omics1_across_recon,
                   'emb_latent_omics2_across_recon': emb_latent_omics2_across_recon,
                   'alpha': alpha_omics_1_2,
                   'score_omics1': score_omics1,
                   'score_omics2': score_omics2}

        return results

class Encoder(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_feat, out_feat, dropout=0.0, act=F.relu):
        super(Encoder, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.dropout = dropout
        self.act = act

        self.weight = Parameter(torch.FloatTensor(self.in_feat, self.out_feat))

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, feat, adj):
        x = torch.mm(feat, self.weight)
        x = torch.spmm(adj, x)

        return x

class Decoder(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_feat, out_feat, dropout=0.0, act=F.relu):
        super(Decoder, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.dropout = dropout
        self.act = act

        self.weight = Parameter(torch.FloatTensor(self.in_feat, self.out_feat))

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, feat, adj):
        x = torch.mm(feat, self.weight)
        x = torch.spmm(adj, x)

        return x

class Discriminator(nn.Module):
    """Latent space discriminator"""
    def __init__(self, dim_input, n_hidden=50, n_out=1):
        super(Discriminator, self).__init__()
        self.dim_input = dim_input
        self.n_hidden = n_hidden
        self.n_out = n_out

        self.net = nn.Sequential(
            nn.Linear(dim_input, n_hidden),
            nn.LeakyReLU(inplace=True),
            nn.Linear(n_hidden, 2*n_hidden),
            nn.LeakyReLU(inplace=True),
            #nn.Linear(n_hidden, n_hidden),
            #nn.ReLU(inplace=True),
            #nn.Linear(n_hidden, n_hidden),
            #nn.ReLU(inplace=True),
            #nn.Linear(n_hidden, n_hidden),
            #nn.ReLU(inplace=True),
            nn.Linear(2*n_hidden, n_out),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

class AttentionLayer(Module):
    """Attention layer that aggregates two embeddings into one and returns the attention weights."""
    def __init__(self, in_feat, out_feat, dropout=0.0, act=F.relu):
        super(AttentionLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat

        self.w_omega = Parameter(torch.FloatTensor(in_feat, out_feat))
        self.u_omega = Parameter(torch.FloatTensor(out_feat, 1))

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.w_omega)
        torch.nn.init.xavier_uniform_(self.u_omega)

    def forward(self, emb1, emb2):
        # stack the two embeddings along a new "view" dimension: (n_spots, 2, dim)
        emb = []
        emb.append(torch.unsqueeze(torch.squeeze(emb1), dim=1))
        emb.append(torch.unsqueeze(torch.squeeze(emb2), dim=1))
        self.emb = torch.cat(emb, dim=1)

        # attention scores over the two views (torch.tanh replaces the deprecated F.tanh)
        self.v = torch.tanh(torch.matmul(self.emb, self.w_omega))
        self.vu = torch.matmul(self.v, self.u_omega)
        self.alpha = F.softmax(torch.squeeze(self.vu) + 1e-6, dim=1)

        # weighted sum of the two views: (n_spots, dim)
        emb_combined = torch.matmul(torch.transpose(self.emb, 1, 2), torch.unsqueeze(self.alpha, -1))

        return torch.squeeze(emb_combined), self.alpha

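# --------------------------------------------------------------------------
# A minimal shape-check sketch for the full model (the toy sizes, identity
# graphs and this __main__ guard are illustrative additions, not part of the
# original module).
if __name__ == '__main__':
    n_spots, dim_in1, dim_in2, dim_out = 5, 30, 20, 64
    feat1 = torch.rand(n_spots, dim_in1)
    feat2 = torch.rand(n_spots, dim_in2)
    adj = torch.eye(n_spots).to_sparse()  # stand-in for the spatial / feature graphs

    model = Encoder_omics(dim_in1, dim_out, dim_in2, dim_out)
    out = model(feat1, feat2, adj, adj, adj, adj)
    print(out['emb_latent_combined'].shape)  # torch.Size([5, 64])
    print(out['alpha'].shape)                # torch.Size([5, 2])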