Merge main to release #171

Merged: 7 commits, Jan 17, 2025
2 changes: 1 addition & 1 deletion .gitignore
@@ -18,7 +18,7 @@ singularity/*
*.arrow
*zip
*.npy

*.json
*.pickle
*.pkl
*.bin
2 changes: 2 additions & 0 deletions README.md
@@ -147,6 +147,8 @@ A lot of our models have been published by talented authors developing these exciting
- [scanpy](https://github.com/scverse/scanpy)
- [transformers](https://github.com/huggingface/transformers)
- [scikit-learn](https://github.com/scikit-learn/scikit-learn)
- [GenePT](https://github.com/yiqunchen/GenePT)
- [Caduceus](https://github.com/kuleshov-group/caduceus)

### Licenses

230 changes: 230 additions & 0 deletions examples/notebooks/Genegpt-sample-run.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions helical/__init__.py
@@ -29,6 +29,7 @@ def filter(self, record):
from .models.uce.fine_tuning_model import UCEFineTuningModel
from .models.geneformer.model import Geneformer,GeneformerConfig
from .models.geneformer.fine_tuning_model import GeneformerFineTuningModel
from .models.genept.model import GenePT,GenePTConfig
from .models.scgpt.model import scGPT, scGPTConfig
from .models.scgpt.fine_tuning_model import scGPTFineTuningModel
from .models.hyena_dna.model import HyenaDNA, HyenaDNAConfig
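Since `helical/__init__.py` now re-exports the model, GenePT becomes available at the package root. A minimal import check, assuming the 0.0.1a19 release is installed with the files added in this PR:

```python
# Sketch only: exercises the new top-level export added in this PR.
from helical import GenePT, GenePTConfig

config = GenePTConfig(device="cpu")   # "gpt3.5" embeddings by default
model = GenePT(configurer=config)     # downloads genept_embeddings.json on first use
```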
1 change: 1 addition & 0 deletions helical/models/genept/__init__.py
@@ -0,0 +1 @@
from .model import GenePT,GenePTConfig
82 changes: 82 additions & 0 deletions helical/models/genept/genept_config.py
@@ -0,0 +1,82 @@
from typing import Optional
from pathlib import Path
from helical.constants.paths import CACHE_DIR_HELICAL
from typing import Literal

class GenePTConfig():
    """Configuration class to use the GenePT Model.

    Parameters
    ----------
    model_name : Literal["gpt3.5"], optional, default="gpt3.5"
        The name of the model for the embeddings.
    batch_size : int, optional, default=24
        The batch size.
    emb_layer : int, optional, default=-1
        The embedding layer.
    emb_mode : Literal["cls", "cell", "gene"], optional, default="cell"
        The embedding mode.
    device : Literal["cpu", "cuda"], optional, default="cpu"
        The device to use. Either use "cuda" or "cpu".
    accelerator : bool, optional, default=False
        The accelerator configuration. By default same device as model.
    nproc : int, optional, default=1
        Number of processes to use for data processing.
    custom_attr_name_dict : dict, optional, default=None
        A dictionary that contains the names of the custom attributes to be added to the dataset.
        The keys of the dictionary are the names of the custom attributes, and the values are the names of the columns in adata.obs.
        For example, if you want to add a custom attribute called "cell_type" to the dataset, you would pass custom_attr_name_dict = {"cell_type": "cell_type"}.
        If you do not want to add any custom attributes, you can leave this parameter as None.

    Returns
    -------
    GenePTConfig
        The GenePT configuration object
    """
    def __init__(
        self,
        model_name: Literal["gpt3.5"] = "gpt3.5",
        batch_size: int = 24,
        emb_layer: int = -1,
        emb_mode: Literal["cls", "cell", "gene"] = "cell",
        device: Literal["cpu", "cuda"] = "cpu",
        accelerator: Optional[bool] = False,
        nproc: int = 1,
        custom_attr_name_dict: Optional[dict] = None
    ):
        # model specific parameters
        self.model_map = {
            "gpt3.5": {
                'input_size': 4096,
                'special_token': True,
                'embsize': 512,
            }
        }
        if model_name not in self.model_map:
            raise ValueError(f"Model name {model_name} not found in available models: {self.model_map.keys()}")

        list_of_files_to_download = [
            "genept/genept_embeddings/genept_embeddings.json",
        ]

        embeddings_path = Path(CACHE_DIR_HELICAL, 'genept/genept_embeddings/genept_embeddings.json')

        self.config = {
            "embeddings_path": embeddings_path,
            "model_name": model_name,
            "batch_size": batch_size,
            "emb_layer": emb_layer,
            "emb_mode": emb_mode,
            "device": device,
            "accelerator": accelerator,
            "input_size": self.model_map[model_name]["input_size"],
            "special_token": self.model_map[model_name]["special_token"],
            "embsize": self.model_map[model_name]["embsize"],
            "nproc": nproc,
            "custom_attr_name_dict": custom_attr_name_dict,
            "list_of_files_to_download": list_of_files_to_download
        }
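For reference, a short sketch of what the resulting configuration dictionary holds; the exact path depends on `CACHE_DIR_HELICAL` on the local machine, and the printed values follow from the `model_map` entry above:

```python
# Sketch assuming the helical package is installed and its cache directory is writable.
from helical.models.genept import GenePTConfig

cfg = GenePTConfig(batch_size=32, device="cuda")

# GenePT consumes this plain dict rather than the config object's attributes:
print(cfg.config["embeddings_path"])            # <CACHE_DIR_HELICAL>/genept/genept_embeddings/genept_embeddings.json
print(cfg.config["embsize"])                    # 512, taken from the "gpt3.5" entry in model_map
print(cfg.config["list_of_files_to_download"])  # files the Downloader fetches when GenePT is instantiated
```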



149 changes: 149 additions & 0 deletions helical/models/genept/model.py
@@ -0,0 +1,149 @@
from helical.models.base_models import HelicalRNAModel
import logging
import json
import numpy as np
import scanpy as sc
import torch
from anndata import AnnData
from helical.utils.downloader import Downloader
from helical.models.genept.genept_config import GenePTConfig
from helical.utils.mapping import map_ensembl_ids_to_gene_symbols

LOGGER = logging.getLogger(__name__)

class GenePT(HelicalRNAModel):
    """GenePT Model.

    Parameters
    ----------
    configurer : GenePTConfig, optional, default = default_configurer
        The model configuration
    """
    default_configurer = GenePTConfig()

    def __init__(self, configurer: GenePTConfig = default_configurer):
        super().__init__()
        self.configurer = configurer
        self.config = configurer.config

        # fetch the precomputed GenePT gene embeddings into the local cache
        downloader = Downloader()
        for file in self.config["list_of_files_to_download"]:
            downloader.download_via_name(file)

        with open(self.config['embeddings_path'], "r") as f:
            self.embeddings = json.load(f)

        LOGGER.info("GenePT initialized successfully.")

    def process_data(self,
                     adata: AnnData,
                     gene_names: str = "index",
                     use_raw_counts: bool = True,
                     ) -> AnnData:
        """
        Processes the data for the GenePT model.

        Parameters
        ----------
        adata : AnnData
            The AnnData object containing the data to be processed. GenePT uses Ensembl IDs to identify genes
            and currently supports only human genes. If the AnnData object already has an 'ensembl_id' column,
            the mapping step can be skipped.
        gene_names : str, optional, default="index"
            The column in `adata.var` that contains the gene names. If set to a value other than "ensembl_id",
            the gene symbols in that column will be mapped to Ensembl IDs using the 'pyensembl' package,
            which retrieves mappings from the Ensembl FTP server and loads them into a local database.
            - If set to "index", the index of the AnnData object will be used and mapped to Ensembl IDs.
            - If set to "ensembl_id", no mapping will occur.
            Special case:
                If the index of `adata` already contains Ensembl IDs, setting this to "index" will result in
                invalid mappings. In such cases, create a new column containing Ensembl IDs and pass "ensembl_id"
                as the value of `gene_names`.
        use_raw_counts : bool, optional, default=True
            Determines whether raw counts should be used.

        Returns
        -------
        AnnData
            The processed AnnData object, restricted to the selected highly variable genes.
        """
        LOGGER.info("Processing data for GenePT.")
        self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

        # map Ensembl IDs to gene symbols if provided
        if gene_names == "ensembl_id":
            if (adata.var[gene_names].str.startswith("ENS").all()) or (adata.var[gene_names].str.startswith("None").any()):
                message = "It seems an anndata with 'ensembl ids' and/or 'None' was passed. " \
                          "Please set gene_names='ensembl_id' and remove 'None's to skip mapping."
                LOGGER.info(message)
                raise ValueError(message)
            adata = map_ensembl_ids_to_gene_symbols(adata, gene_names)

        n_top_genes = 1000
        LOGGER.info(f"Filtering the top {n_top_genes} highly variable genes.")
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor='seurat_v3')
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)

        genes_names = adata.var_names[adata.var['highly_variable']].tolist()
        adata = adata[:, genes_names]

        LOGGER.info("Successfully processed the data for GenePT.")
        return adata

    def get_text_embeddings(self, dataset: AnnData) -> torch.Tensor:
        """Builds cell embeddings by weighting the GenePT gene text embeddings with the expression values.

        Parameters
        ----------
        dataset : AnnData
            The processed AnnData object.

        Returns
        -------
        torch.Tensor
            The expression-weighted cell embeddings.
        """
        raw_embeddings = dataset.var_names
        weights = []
        count_missed = 0
        gene_list = []
        # collect the text embedding of every gene present in the GenePT vocabulary
        for emb in raw_embeddings:
            gene = self.embeddings.get(emb.upper(), None)
            if gene is not None:
                weights.append(gene['embeddings'])
                gene_list.append(emb)
            else:
                count_missed += 1
        LOGGER.info("Couldn't find {} genes in embeddings".format(count_missed))

        # (cells x genes) @ (genes x embedding_dim) -> (cells x embedding_dim)
        weights = torch.Tensor(weights)
        embeddings = torch.matmul(torch.Tensor(dataset[:, gene_list].X.toarray()), weights)
        return embeddings

    def get_embeddings(self, dataset: AnnData) -> torch.Tensor:
        """Gets the cell embeddings from the GenePT model.

        Parameters
        ----------
        dataset : AnnData
            The processed AnnData object.

        Returns
        -------
        torch.Tensor
            The L2-normalized cell embeddings.
        """
        LOGGER.info("Inference started:")
        embeddings = self.get_text_embeddings(dataset)
        # normalize each cell embedding to unit length
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1).reshape(-1, 1)
        return embeddings
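Taken together, `process_data` keeps the 1000 most highly variable genes and log-normalizes the counts, `get_text_embeddings` stacks the GenePT text embedding of every kept gene found in `genept_embeddings.json` into a weight matrix and multiplies the cell-by-gene expression matrix with it, and `get_embeddings` L2-normalizes each resulting row. A hedged end-to-end sketch; the `.h5ad` file name is a placeholder, not something shipped with this PR:

```python
# End-to-end usage sketch of the new GenePT model class.
import anndata as ad
from helical.models.genept import GenePT, GenePTConfig

model = GenePT(GenePTConfig(device="cpu"))

adata = ad.read_h5ad("raw_counts.h5ad")                 # hypothetical dataset with raw counts
adata = model.process_data(adata, gene_names="index")   # HVG selection, normalize_total, log1p

# Each row of the result is X_cell @ W, where X_cell is the cell's expression over the
# genes found in the GenePT vocabulary and W stacks their text embeddings; rows are then
# scaled to unit length.
embeddings = model.get_embeddings(adata)
print(embeddings.shape)                                 # (n_cells, embedding_dim)
```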
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "helical"
version = "0.0.1a18"
version = "0.0.1a19"
authors = [
{ name="Helical Team", email="support@helical-ai.com" },
]