-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_scvi_integration.R
107 lines (95 loc) · 3.74 KB
/
run_scvi_integration.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#' Perform scVI integration on a Seurat object
#'
#' Subsets the object to its top variable features, converts the raw counts to
#' an AnnData object via sceasy, trains an scVI model through reticulate, and
#' stores the resulting latent representation in the Seurat object as a
#' dimensional reduction named "SCVI".
#'
#' @param object Seurat object after variable gene selection and normalization.
#'   Can be the object itself or a path to an RDS file. Note: raw gene counts
#'   are used in either case.
#' @param python_path path to desired python version for reticulate setup. If
#'   NA, `use_python()` is not called.
#' @param project_path path to project directory for loading R environment. If
#'   NA, does not use renv.
#' @param random_seed random seed for SCVI (cast to integer).
#' @param batch_col name of metadata column specifying batch. Defaults to
#'   orig.ident, where each sample is its own batch.
#' @param n_epochs maximum number of training epochs (passed to the model's
#'   `train()` as `max_epochs`). If NA, `max_epochs` is not supplied.
#' @param RDS_path location to save the resulting object as an RDS file. If NA,
#'   the object is not saved.
#' @param use_GPU toggle to true for running on GPU (passed to `train()` as
#'   `use_gpu`).
#' @param hvgs number of highly variable genes to use in the model.
#'   Default: 3000.
#' @param n_layers Number of hidden layers used for encoder and decoder NNs
#' @param dropout_rate Dropout rate for neural networks
#' @param early_stopping whether to enable early stopping during training
#'   (passed to `train()`). Default: FALSE.
#' @param n_latent dimensionality of the latent space (passed to the SCVI
#'   model constructor). Default: 10.
#' @return a Seurat object with new embeddings in the "SCVI" slot. If
#'   `RDS_path` is not NA, the object is also written to that path.
#' @import renv
#' @import Seurat
#' @import sceasy
#' @import reticulate
#' @export
#'
run_SCVI_integration <- function(object,
                                 python_path = NA,
                                 project_path = NA,
                                 random_seed = 12345,
                                 batch_col = "orig.ident",
                                 n_epochs = NA,
                                 RDS_path = NA,
                                 use_GPU = FALSE,
                                 hvgs = 3000,
                                 n_layers = 1,
                                 dropout_rate = 0.1,
                                 early_stopping = FALSE,
                                 n_latent = 10) {
  # If required, set up the R environment.
  if (!is.na(project_path)) {
    # renv needs to be installed outside of the project folder for this to work.
    if (!requireNamespace("renv", quietly = TRUE)) {
      install.packages("renv")
    }
    renv::load(project_path)
  }

  # Set up reticulate. Namespace-qualified calls fail loudly if the package is
  # missing, unlike require(), which only returns FALSE.
  if (!is.na(python_path)) {
    reticulate::use_python(python_path, required = TRUE)
  }
  # Resolve the Python configuration up front; the result itself is not used.
  reticulate::py_config()

  # Import python packages. `sc` is not referenced below; the import is kept so
  # a missing scanpy installation is still reported here.
  sc <- reticulate::import("scanpy", convert = FALSE)
  scvi <- reticulate::import("scvi", convert = FALSE)
  scvi$settings$progress_bar_style <- 'tqdm'
  # Causes an scvi error unless cast to integer.
  scvi$settings$seed <- as.integer(random_seed)

  # If a file path is given as input, load it as an RDS file.
  if (is.character(object)) {
    object <- readRDS(object)
  }

  # Fail early with a clear message instead of a Python-side error later.
  if (!(batch_col %in% colnames(object[[]]))) {
    stop("batch_col '", batch_col, "' was not found in the object metadata.",
         call. = FALSE)
  }

  # Get variable features (assumed already calculated with method of choice).
  top_n <- head(Seurat::VariableFeatures(object), hvgs)
  if (length(top_n) == 0) {
    stop("No variable features found; run variable feature selection before ",
         "calling run_SCVI_integration().", call. = FALSE)
  }
  object_scvi <- object[top_n, ]

  # Convert to an AnnData object using the raw counts.
  annData <- sceasy::convertFormat(object_scvi,
                                   from = "seurat",
                                   to = "anndata",
                                   main_layer = "counts",
                                   drop_single_values = FALSE)

  # Create the model. n_latent is now forwarded; previously it was accepted by
  # this function but never reached the model.
  scvi$model$SCVI$setup_anndata(annData, batch_key = batch_col)
  model <- scvi$model$SCVI(annData,
                           dropout_rate = as.double(dropout_rate),
                           n_layers = as.integer(n_layers),
                           n_latent = as.integer(n_latent))

  # Train the model; max_epochs is only supplied when n_epochs was given.
  # NOTE(review): newer scvi-tools releases appear to have replaced `use_gpu`
  # with `accelerator` -- confirm against the installed scvi-tools version.
  if (is.na(n_epochs)) {
    model$train(use_gpu = use_GPU,
                early_stopping = as.logical(early_stopping))
  } else {
    model$train(use_gpu = use_GPU,
                max_epochs = as.integer(n_epochs),
                early_stopping = as.logical(early_stopping))
  }

  # Get the latent representation and place it back into the Seurat object.
  # Rows are labelled with the cell names of the full input object.
  latent <- model$get_latent_representation()
  latent <- as.matrix(latent)
  rownames(latent) <- colnames(object)
  # Name the dimensions explicitly (key + index) so CreateDimReducObject does
  # not have to fill them in with a warning.
  if (is.null(colnames(latent))) {
    colnames(latent) <- paste0("SCVI_", seq_len(ncol(latent)))
  }
  object[["SCVI"]] <- Seurat::CreateDimReducObject(
    embeddings = latent,
    key = "SCVI_",
    assay = Seurat::DefaultAssay(object)
  )

  # Optionally persist the result to disk.
  if (!is.na(RDS_path)) {
    saveRDS(object, RDS_path)
  }

  # Return the integrated object, as documented. Previously the function ended
  # on the save step and returned NULL.
  object
}