Skip to content

Commit

Permalink
Merge pull request #182 from helicalAI/main
Browse files Browse the repository at this point in the history
Bring main to release
  • Loading branch information
bputzeys authored Feb 5, 2025
2 parents 70f4ad3 + 2c4e89d commit 86fa264
Show file tree
Hide file tree
Showing 12 changed files with 521 additions and 467 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ or in case you're installing from the Helical repo cloned locally:
pip install .[mamba-ssm]
```

Note: make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d.
Note:
- Make sure your machine has GPU(s) and Cuda installed. Currently this is a requirement for the packages mamba-ssm and causal-conv1d.
- The package `causal_conv1d` requires `torch` to be installed already. First installing `helical` separately (without `[mamba-ssm]`) will install `torch` for you. A second installation (with `[mamba-ssm]`), installs the packages correctly.

### Singularity (Optional)
If you desire to run your code in a singularity file, you can use the [singularity.def](./singularity.def) file and build an apptainer with it:
Expand Down
6 changes: 6 additions & 0 deletions ci/tests/test_geneformer/test_geneformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def test_process_data_mapping_to_ensemble_ids(self, geneformer, mock_data):
assert mock_data.var[mock_data.var['gene_symbols'] == 'PLEKHN1']['ensembl_id'].values[0] == 'ENSG00000187583'
assert mock_data.var[mock_data.var['gene_symbols'] == 'HES4']['ensembl_id'].values[0] == 'ENSG00000188290'

def test_process_data_mapping_to_ensemble_ids_resulting_in_0_genes(self, geneformer, mock_data):
# provide a gene that does not exist in the ensembl database
mock_data.var['gene_symbols'] = ['1', '2', '3']
with pytest.raises(ValueError):
geneformer.process_data(mock_data, gene_names="gene_symbols")

@pytest.mark.parametrize("invalid_model_names", ["gf-12L-35M-i2048", "gf-34L-30M-i5000"])
def test_pass_invalid_model_name(self, invalid_model_names):
with pytest.raises(ValueError):
Expand Down
7 changes: 7 additions & 0 deletions ci/tests/test_scgpt/test_scgpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ def test_ensure_data_validity__value_error(self, data):
self.scgpt.ensure_data_validity(data, "index", False)
assert "total_counts" in data.obs

def test_process_data_no_matching_genes(self):
self.dummy_data.var['gene_ids'] = [1]*self.dummy_data.n_vars
model = scGPT()

with pytest.raises(ValueError):
model.process_data(self.dummy_data, gene_names='gene_ids')

np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
csr_data.X = csr_matrix(np.random.poisson(1, size=(100, 5)), dtype=np.float32)
Expand Down
39 changes: 39 additions & 0 deletions ci/tests/test_uce/test_gene_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from helical.models.uce.gene_embeddings import load_gene_embeddings_adata
from anndata import AnnData
import pandas as pd
import numpy as np
from pathlib import Path
import pytest
from pathlib import Path
CACHE_DIR_HELICAL = Path(Path.home(), '.cache', 'helical', 'models')

class TestUCEGeneEmbeddings:

adata = AnnData(
X=np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
obs=pd.DataFrame({"species": ["human", "mouse", "rat"]}),
var=pd.DataFrame({"gene": ["gene1", "gene2", "gene3"]})
)
species = ["human"]
embedding_model = "ESM2"
embeddings_path = Path(CACHE_DIR_HELICAL, 'uce', "protein_embeddings")

def test_load_gene_embeddings_adata_filtering_all_genes(self):
with pytest.raises(ValueError):
load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path)

def test_load_gene_embeddings_adata_filtering_no_genes(self):
self.adata.var_names = ['hoxa6', 'cav2', 'txk']
anndata, mapping_dict = load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path)
assert (anndata.var_names == ['hoxa6', 'cav2', 'txk']).all()
assert (anndata.obs == self.adata.obs).all().all()
assert (anndata.X == self.adata.X).all()
assert len(mapping_dict['human']) == 19790

def test_load_gene_embeddings_adata_filtering_some_genes(self):
self.adata.var_names = ['hoxa6', 'cav2', '1']
anndata, mapping_dict = load_gene_embeddings_adata(self.adata, self.species, self.embedding_model, self.embeddings_path)
assert (anndata.var_names == ['hoxa6', 'cav2']).all()
assert (anndata.obs == self.adata.obs).all().all()
assert (anndata.X == [[1, 2], [4, 5], [7, 8]]).all()
assert len(mapping_dict['human']) == 19790
Loading

0 comments on commit 86fa264

Please sign in to comment.