Skip to content

Commit

Permalink
⬆️ update dependencies and interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
subercui committed Jun 1, 2022
1 parent e67d4db commit 5eaad6b
Show file tree
Hide file tree
Showing 15 changed files with 387 additions and 423 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ DeepVelo employs cell-specific kinetic rates and provides more accurate RNA velo
pip install deepvelo
```

The `dgl` package is required, the cpu version is installed by default. Feel free to install the [dgl cuda](https://www.dgl.ai/pages/start.html) version for GPU acceleration.
### Using GPU

The `dgl` cpu version is installed by default. For GPU acceleration, please install the proper [dgl gpu](https://www.dgl.ai/pages/start.html) version compatible with your CUDA environment.

```bash
pip uninstall dgl # [optional] remove the cpu version
pip install dgl-cu101>=0.4.3 # an example for CUDA 10.1

```

### Install the development version
Expand Down
4 changes: 1 addition & 3 deletions deepvelo/data_loader/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class VeloDataset(Dataset):
def __init__(
self,
data_source,
basis,
train=True,
type="average",
topC=30,
Expand Down Expand Up @@ -164,7 +163,6 @@ def __init__(
self,
data_source,
batch_size,
basis="raw",
shuffle=True,
validation_split=0.0,
num_workers=1,
Expand All @@ -175,7 +173,7 @@ def __init__(
):
self.data_source = data_source
self.dataset = VeloDataset(
data_source, basis, train=training, type=type, topC=topC, topG=topG
data_source, train=training, type=type, topC=topC, topG=topG
)
self.shuffle = shuffle
self.is_large_batch = batch_size == len(self.dataset)
Expand Down
15 changes: 13 additions & 2 deletions deepvelo/parse_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
from functools import reduce, partial
from operator import getitem
from datetime import datetime
from typing import Callable
from deepvelo.logger import setup_logging
from deepvelo.utils import read_json, write_json
from deepvelo.utils import read_json, write_json, validate_config


class ConfigParser:
def __init__(self, config, resume=None, modification=None, run_id=None):
def __init__(
self,
config,
resume=None,
modification=None,
run_id=None,
validator: Callable = validate_config,
):
"""
class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
and logging module.
Expand All @@ -22,6 +30,9 @@ class to parse configuration json file. Handles hyperparameters for training, in
self._config = _update_config(config, modification)
self.resume = resume

if validator:
self._config = validator(self._config)

# set save_dir where trained model and log will be saved.
save_dir = Path(self.config["trainer"]["save_dir"])

Expand Down
17 changes: 7 additions & 10 deletions deepvelo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,19 @@ def default_configs(cls):
class Constants(object, metaclass=MetaConstants):
_default_configs = {
"name": "DeepVelo_Base",
"n_gpu": 1,
"n_gpu": 1, # whether to use GPU
"arch": {
"type": "VeloGCN",
"args": {
"n_genes": 2001,
"layers": [64, 64],
"dropout": 0.2,
"fc_layer": False, # whther add an output fully connected layer
"fc_layer": False,
"pred_unspliced": False,
},
},
"data_loader": {
"type": "VeloDataLoader",
"args": {
"basis": "pca",
"batch_size": 128,
"shuffle": False,
"validation_split": 0.0,
"num_workers": 2,
Expand All @@ -46,7 +43,6 @@ class Constants(object, metaclass=MetaConstants):
"topG": 20,
},
},
"online_test": "velo_mat_E10-12.npz",
"optimizer": {
"type": "Adam",
"args": {"lr": 0.001, "weight_decay": 0, "amsgrad": True},
Expand All @@ -64,12 +60,11 @@ class Constants(object, metaclass=MetaConstants):
"metrics": ["mse"],
"lr_scheduler": {"type": "StepLR", "args": {"step_size": 1, "gamma": 0.97}},
"trainer": {
"epochs": 280,
"epochs": 100,
"save_dir": "saved/",
"save_period": 1000,
"verbosity": 1,
"monitor": "min mse",
"guided_epochs": 0, # epochs showing t+1 neighbors, only work w/ mle
"early_stop": 1000,
"tensorboard": True,
},
Expand All @@ -87,7 +82,6 @@ def train(
batch_size, n_genes = adata.layers["Ms"].shape
configs["arch"]["args"]["n_genes"] = n_genes
configs["data_loader"]["args"]["batch_size"] = batch_size
print(configs)
config = ConfigParser(configs)
logger = config.get_logger("train")

Expand All @@ -100,7 +94,10 @@ def train(
model = config.init_obj("arch", module_arch, g=data_loader.dataset.g)
else:
model = config.init_obj("arch", module_arch)
logger.info(model)
logger.info(f"Beginning training of {configs['name']} ...")
if verbose:
logger.info(configs)
logger.info(model)

# get function handles of loss and metrics
criterion = getattr(module_loss, configs["loss"]["type"])
Expand Down
48 changes: 14 additions & 34 deletions deepvelo/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# from torchvision.utils import make_grid
from deepvelo.base import BaseTrainer
from deepvelo.utils import validate_config, inf_loop, MetricTracker
from deepvelo.utils import inf_loop, MetricTracker
from deepvelo.logger import TensorboardWriter


Expand All @@ -28,7 +28,6 @@ def __init__(
len_epoch=None,
):
super().__init__(model, criterion, metric_ftns, optimizer, config)
self.config = validate_config(config)
self.data_loader = data_loader
if len_epoch is None:
# epoch-based training
Expand All @@ -50,31 +49,21 @@ def __init__(
)

def _compute_core(self, batch_data):
if isinstance(batch_data, dgl.nodeflow.NodeFlow):
nf = batch_data
nf.copy_from_parent()
nf.layers[0].data["Ux_sz"] = nf.layers[0].data["Ux_sz"].to(self.device)
nf.layers[0].data["Sx_sz"] = nf.layers[0].data["Sx_sz"].to(self.device)
nf.layers[-1].data["Ux_sz"] = nf.layers[-1].data["Ux_sz"].to(self.device)
nf.layers[-1].data["Sx_sz"] = nf.layers[-1].data["Sx_sz"].to(self.device)
target = nf.layers[-1].data["velo"].to(self.device)
output = self.model(nf)
else:
data_dict = batch_data
x_u, x_s, target = data_dict["Ux_sz"], data_dict["Sx_sz"], data_dict["velo"]
x_u, x_s, target = (
x_u.to(self.device),
x_s.to(self.device),
target.to(self.device),
)
data_dict = batch_data
x_u, x_s, target = data_dict["Ux_sz"], data_dict["Sx_sz"], data_dict["velo"]
x_u, x_s, target = (
x_u.to(self.device),
x_s.to(self.device),
target.to(self.device),
)

if self.config["arch"]["args"]["pred_unspliced"]:
target_u = data_dict["velo_u"]
target_u = target_u.to(self.device)
# concate target to (batch, 2*genes), be careful of the order
target = torch.cat([target, target_u], dim=1)
if self.config["arch"]["args"]["pred_unspliced"]:
target_u = data_dict["velo_u"]
target_u = target_u.to(self.device)
# concate target to (batch, 2*genes), be careful of the order
target = torch.cat([target, target_u], dim=1)

output = self.model(x_u, x_s)
output = self.model(x_u, x_s)
return output, target

def _smooth_constraint_step(self):
Expand Down Expand Up @@ -104,15 +93,6 @@ def _train_epoch(self, epoch):
"""
self.model.train()
self.train_metrics.reset()
if "mle" in self.config["loss"]["type"]:
if "t+1" not in self.config["data_loader"]["args"]["type"]:
if epoch <= self.config["trainer"]["guided_epochs"]:
self.data_loader.dataset.neighbor_time = 1
print(
f"This is a guided_epoch, use neighbors at t{self.data_loader.dataset.neighbor_time}."
)
else:
self.data_loader.dataset.neighbor_time = 0
if (not self.data_loader.shuffle) and self.data_loader.is_large_batch:
loader = self.data_loader.dataset.large_batch(self.device)
else:
Expand Down
3 changes: 3 additions & 0 deletions deepvelo/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .util import *
from .confidence import *
from .velocity import *
from .temporal import *
54 changes: 54 additions & 0 deletions deepvelo/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import scvelo as scv
from scvelo import logging as logg
from scvelo.core import sum as sum_
from anndata import AnnData

from deepvelo.utils.plot import dist_plot

Expand Down Expand Up @@ -116,3 +118,55 @@ def clip_and_norm_Ms_Mu(
logg.hint(f"replaced 'Mu' (adata.layers) with 'NMu'")

return scale_Ms, scale_Mu


def autoset_coeff_s(adata: AnnData, use_raw: bool = True) -> float:
"""
Automatically set the weighting for objective term of the spliced
read correlation. Modified from the scv.pl.proportions function.
Args:
adata (Anndata): Anndata object.
use_raw (bool): use raw data or processed data.
Returns:
float: weighting coefficient for objective term of the unpliced read
"""
layers = ["spliced", "unspliced", "ambigious"]
layers_keys = [key for key in layers if key in adata.layers.keys()]
counts_layers = [sum_(adata.layers[key], axis=1) for key in layers_keys]

if use_raw:
ikey, obs = "initial_size_", adata.obs
counts_layers = [
obs[ikey + layer_key] if ikey + layer_key in obs.keys() else c
for layer_key, c in zip(layers_keys, counts_layers)
]
counts_total = np.sum(counts_layers, 0)
counts_total += counts_total == 0
counts_layers = np.array([counts / counts_total for counts in counts_layers])
counts_layers = np.mean(counts_layers, axis=1)

spliced_counts = counts_layers[layers_keys.index("spliced")]
ratio = spliced_counts / counts_layers.sum()

if ratio < 0.7:
coeff_s = 0.5
print(
f"The ratio of spliced reads is {ratio*100:.1f}% (less than 70%). "
f"Suggest using coeff_s {coeff_s}."
)
elif ratio < 0.85:
coeff_s = 0.75
print(
f"The ratio of spliced reads is {ratio*100:.1f}% (between 70% and 85%). "
f"Suggest using coeff_s {coeff_s}."
)
else:
coeff_s = 1.0
print(
f"The ratio of spliced reads is {ratio*100:.1f}% (more than 85%). "
f"Suggest using coeff_s {coeff_s}."
)

return coeff_s
19 changes: 13 additions & 6 deletions deepvelo/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,19 @@ def validate_config(config: Mapping) -> Mapping:
"""
Return config if it is valid, otherwise raise an error.
"""
if config["trainer"]["guided_epochs"] > 0:
assert config["loss"]["type"] == "mle", "Using guided epochs requires MLE loss"
assert "t+1" not in config["data_loader"]["args"]["type"], (
"Using guided epochs requires prdecting neighbors at time t after "
"guided epochs"
)

# check if the gpu verion of dgl is installed
if config["n_gpu"] > 0:
import dgl

try:
dgl.graph([]).to("cuda")
except dgl.DGLError:
print(
"Config Warning: Set to use GPU, but GPU version of DGL is not "
"installed. Reset to use CPU instead."
)
config["n_gpu"] = 0

return config

Expand Down
Loading

0 comments on commit 5eaad6b

Please sign in to comment.