diff --git a/src/scimilarity/cell_annotation.py b/src/scimilarity/cell_annotation.py index 7b4806e..85cea07 100644 --- a/src/scimilarity/cell_annotation.py +++ b/src/scimilarity/cell_annotation.py @@ -10,9 +10,7 @@ def __init__( self, model_path: str, use_gpu: bool = False, - parameters: Optional[dict] = None, filenames: Optional[dict] = None, - residual: bool = False, ): """Constructor. @@ -22,12 +20,8 @@ def __init__( Path to the directory containing model files. use_gpu: bool, default: False Use GPU instead of CPU. - parameters: dict, optional, default: None - Use a dictionary of custom model parameters instead of infering from model files. filenames: dict, optional, default: None - Use a dictionary of custom filenames for model files instead default. - residual: bool, default: False - Use residual connections. + Use a dictionary of custom filenames for files instead default. Examples -------- @@ -39,17 +33,14 @@ def __init__( super().__init__( model_path=model_path, use_gpu=use_gpu, - parameters=parameters, - filenames=filenames, - residual=residual, ) - if filenames is None: - filenames = {} - self.annotation_path = os.path.join(model_path, "annotation") os.makedirs(self.annotation_path, exist_ok=True) + if filenames is None: + filenames = {} + self.filenames["knn"] = os.path.join( self.annotation_path, filenames.get("knn", "labelled_kNN.bin") ) @@ -220,10 +211,10 @@ def blocklist_celltypes(self, labels: Union[List[str], Set[str]]): >>> ca.blocklist_celltypes(["T cell"]) """ + self.reset_knn() self.blocklist = set(labels) self.safelist = None - self.reset_knn() for i, celltype_name in self.idx2label.items(): if celltype_name in self.blocklist: self.knn.mark_deleted(i) # mark blocklist @@ -321,9 +312,10 @@ def get_predictions_knn( embeddings=embeddings, k=k, ef=ef ) end_time = time.time() - print( - f"Get nearest neighbors finished in: {float(end_time - start_time) / 60} min" - ) + if not disable_progress: + print( + f"Get nearest neighbors finished in: 
{float(end_time - start_time) / 60} min" + ) stats = { "hits": [], "hits_weighted": [], @@ -345,8 +337,10 @@ def get_predictions_knn( celltype = defaultdict(float) celltype_weighted = defaultdict(float) for neighbor, dist in zip(nns, d_nns): - celltype[self.idx2label[neighbor]] += 1 - celltype_weighted[self.idx2label[neighbor]] += 1 / max(dist, 1e-6) + celltype[self.idx2label[neighbor]] += 1.0 + celltype_weighted[self.idx2label[neighbor]] += 1.0 / float( + max(dist, 1e-6) + ) # predict based on consensus max occurrence if weighting: predictions.append( diff --git a/src/scimilarity/cell_query.py b/src/scimilarity/cell_query.py index 90a36d8..1d186a9 100644 --- a/src/scimilarity/cell_query.py +++ b/src/scimilarity/cell_query.py @@ -10,11 +10,9 @@ def __init__( self, model_path: str, use_gpu: bool = False, - parameters: Optional[dict] = None, filenames: Optional[dict] = None, metadata_tiledb_uri: str = "cell_metadata", embedding_tiledb_uri: str = "cell_embedding", - residual: bool = False, load_knn: bool = True, ): """Constructor. @@ -25,16 +23,12 @@ def __init__( Path to the model directory. use_gpu: bool, default: False Use GPU instead of CPU. - parameters: dict, optional, default: None - Use a dictionary of custom model parameters instead of infering from model files. filenames: dict, optional, default: None Use a dictionary of custom filenames for model files instead default. metadata_tiledb_uri: str, default: "cell_metadata" Relative path to the directory containing the tiledb cell metadata storage. embedding_tiledb_uri: str, default: "cell_embedding" Relative path to the directory containing the tiledb cell embedding storage. - residual: bool, default: False - Use residual connections. load_knn: bool, default: True Load the knn index. Set to False if knn is not needed. 
@@ -51,11 +45,10 @@ def __init__( super().__init__( model_path=model_path, use_gpu=use_gpu, - parameters=parameters, - filenames=filenames, - residual=residual, ) + self.cellsearch_path = os.path.join(model_path, "cellsearch") + os.makedirs(self.cellsearch_path, exist_ok=True) if filenames is None: filenames = {} @@ -105,7 +98,7 @@ def __init__( cell_metadata[c] = cell_metadata[c].replace("NA", np.nan) cell_metadata = cell_metadata.astype(convert_dict) tiledb.from_pandas(metadata_tiledb_uri, cell_metadata) - self.cell_metadata = tiledb.open_dataframe(metadata_tiledb_uri) + self.cell_metadata = tiledb.open(metadata_tiledb_uri, "r").df[:] # get cell embeddings: create tiledb storage if it does not exist embedding_tiledb_uri = os.path.join(self.cellsearch_path, embedding_tiledb_uri) @@ -599,6 +592,9 @@ def search_exhaustive( import pandas as pd from scipy.spatial.distance import cdist + if embeddings.ndim == 1: + embeddings = embeddings.reshape(1, -1) + nn_idxs = [[] for _ in range(embeddings.shape[0])] nn_dists = [[] for _ in range(embeddings.shape[0])] n_cells = self.cell_metadata.shape[0] @@ -624,7 +620,7 @@ def search_exhaustive( # sort by lowest distance for row in range(len(nn_idxs)): - nn_idxs[row] = np.hstack(nn_idxs[0]) + nn_idxs[row] = np.hstack(nn_idxs[row]) nn_dists[row] = np.hstack(nn_dists[row]) sorted_indices = np.argsort(nn_dists[row]) nn_idxs[row] = nn_idxs[row][sorted_indices] diff --git a/src/scimilarity/nn_models.py b/src/scimilarity/nn_models.py index e96a982..e2c0c30 100644 --- a/src/scimilarity/nn_models.py +++ b/src/scimilarity/nn_models.py @@ -113,9 +113,11 @@ def load_state(self, filename: str, use_gpu: bool = False): """ if not use_gpu: - ckpt = torch.load(filename, map_location=torch.device("cpu")) + ckpt = torch.load( + filename, map_location=torch.device("cpu"), weights_only=False + ) else: - ckpt = torch.load(filename) + ckpt = torch.load(filename, weights_only=False) self.load_state_dict(ckpt["state_dict"]) @@ -218,7 +220,9 @@ def 
def find_most_viable_parent(graph, node, node_list):
    """Pick a single coarse-grained parent for a node from a restricted node list.

    The choice depends on how many parents of ``node`` fall inside
    ``node_list``:

    * none: if ``node`` has exactly one parent overall, and that parent has
      exactly one parent inside ``node_list``, that grandparent is returned;
    * exactly one: that parent is returned;
    * several: the first candidate parent for which ``get_all_ancestors``,
      restricted to the candidate parents, gives a non-empty (truthy) result
      is returned.

    Parameters
    ----------
    graph: networkx.DiGraph
        Node graph.
    node: str
        ID of given node.
    node_list: list, set
        A restricted node list for filtering candidate parents.

    Returns
    -------
    str, optional
        ID of the selected parent (or grandparent) node, or None when no
        single viable parent can be determined.

    Examples
    --------
    >>> coarse_grained = find_most_viable_parent(onto, node_id, celltype_list)
    """

    # Default result: stays None unless one of the branches below selects a node.
    coarse_grained = None

    parents = get_parents(graph, node, node_list=node_list)
    n_parents = len(parents)

    if n_parents == 0:
        # No direct parent is in the restricted list: look one level up, but
        # only when the path is unambiguous (a single parent with a single
        # restricted grandparent).
        all_parents = list(get_parents(graph, node))
        if len(all_parents) == 1:
            grandparents = get_parents(graph, all_parents[0], node_list=node_list)
            if len(grandparents) == 1:
                (coarse_grained,) = grandparents
    elif n_parents == 1:
        (coarse_grained,) = parents
    else:
        # Several candidates: take the first one that has an ancestor among
        # the other candidates (presumably the more specific term -- this
        # relies on get_all_ancestors honouring node_list; confirm there).
        candidates = pd.Index(parents)  # loop-invariant, build once
        for parent in parents:
            if get_all_ancestors(graph, parent, node_list=candidates):
                coarse_grained = parent
                break

    return coarse_grained
diff --git a/src/scimilarity/training_models.py b/src/scimilarity/training_models.py index 217bd5b..92ecdd3 100644 --- a/src/scimilarity/training_models.py +++ b/src/scimilarity/training_models.py @@ -658,7 +658,7 @@ def save_all( meta_data["cell_tdb_uri"] = self.trainer.datamodule.cell_tdb_uri meta_data["gene_tdb_uri"] = self.trainer.datamodule.gene_tdb_uri meta_data["counts_tdb_uri"] = self.trainer.datamodule.counts_tdb_uri - self.trainer.datamodule.data_df.to_csv( + self.trainer.datamodule.train_df.to_csv( os.path.join(model_path, "train_cells.csv") ) if self.trainer.datamodule.val_df is not None: @@ -669,7 +669,7 @@ def save_all( with open(os.path.join(model_path, "reference_labels.tsv"), "w") as f: f.write( "\n".join( - self.trainer.datamodule.data_df["cellTypeName"].values.tolist() + self.trainer.datamodule.train_df["cellTypeName"].values.tolist() ) ) with open(os.path.join(model_path, "metadata.json"), "w") as f: diff --git a/src/scimilarity/utils.py b/src/scimilarity/utils.py index d6ac1c3..1796f1e 100644 --- a/src/scimilarity/utils.py +++ b/src/scimilarity/utils.py @@ -468,7 +468,14 @@ def write_csr_to_tiledb( vals.extend(val_slice) -def optimize_tiledb_array(tiledb_array_uri: str, verbose: bool = True): +def optimize_tiledb_array( + tiledb_array_uri: str, + steps=100000, + step_max_frags: int = 10, + buffer_size: int = 1000000000, # 1GB + total_budget: int = 200000000000, # 200GB + verbose: bool = True, +): """Optimize TileDB Array. 
Parameters @@ -489,8 +496,11 @@ def optimize_tiledb_array(tiledb_array_uri: str, verbose: bool = True): print("Fragments before consolidation: {}".format(len(frags))) cfg = tiledb.Config() - cfg["sm.consolidation.step_min_frags"] = 1 - cfg["sm.consolidation.step_max_frags"] = 200 + cfg["sm.consolidation.steps"] = steps + cfg["sm.consolidation.step_min_frags"] = 2 + cfg["sm.consolidation.step_max_frags"] = step_max_frags + cfg["sm.consolidation.buffer_size"] = buffer_size + cfg["sm.mem.total_budget"] = total_budget tiledb.consolidate(tiledb_array_uri, config=cfg) tiledb.vacuum(tiledb_array_uri) @@ -578,7 +588,6 @@ def pseudobulk_anndata( import numpy as np import pandas as pd import scanpy as sc - from scipy.sparse import csr_matrix if "counts" not in adata.layers: raise ValueError(f"Raw counts matrix not found in layers['counts'].") @@ -785,8 +794,10 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "Fat", "omental fat pad", "white adipose tissue", + "subcutaneous adipose tissue", + "visceral fat", }, - "adrenal gland": {"adrenal gland", "visceral fat"}, + "adrenal gland": {"adrenal gland"}, "airway": { "trachea", "trachea;bronchus", @@ -798,11 +809,45 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "inferior nasal concha", "nose", "nasal turbinal", + "respiratory airway", + "trachea;respiratory airway", + "bronchial epithelial cell", + "tonsil", + "dental pulp", + "gingiva", + "olfactory epithelium", + "periodontium", + "nasal cavity", + }, + "biliary system": { + "bile duct", + "mucosa of gallbladder", + }, + "bladder": { + "urinary bladder", + "Bladder", + "bladder", + "urothelium", + "ureter", + "urine", + }, + "blood": { + "blood", + "umbilical cord blood", + "peripheral blood", + "Blood", + "venous blood", + }, + "bone": { + "bone", + "bone tissue", + "head of femur", + "bone spine", + }, + "bone marrow": { + "bone marrow", + "Bone_Marrow", }, - "bone": {"bone", "bone tissue", "head of femur", "synovial fluid"}, - 
"bladder": {"urinary bladder", "Bladder", "bladder"}, - "blood": {"blood", "umbilical cord blood", "peripheral blood", "Blood"}, - "bone marrow": {"bone marrow", "Bone_Marrow"}, "brain": { "brain", "cortex", @@ -819,16 +864,56 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "brain white matter", "cerebellum", "hypothalamus", + "dorsal root ganglion", + "Brodmann (1909) area 9", + "choroid plexus", + "striatum", + "dorsolateral prefrontal cortex", + "putamen", + "middle temporal gyrus", + "frontal cortex", + "substantia nigra", + "primary somatosensory cortex", + "temporal cortex", + "primary visual cortex", + "central nervous system", }, - "breast": {"breast", "Mammary", "mammary gland"}, + "breast": { + "breast", + "Mammary", + "mammary gland", + "upper outer quadrant of breast", + }, + "ear": {"tympanic membrane"}, + # "embryo": { + # "amniotic fluid", + # "embryo", + # "blastocyst", + # "yolk sac", + # "ureteric bud", + # "placenta", + # }, "esophagus": { "esophagus", "esophagusmucosa", "esophagusmuscularis", "esophagus mucosa", "esophagus muscularis mucosa", + "epithelium of esophagus", + }, + "eye": { + "eye", + "uvea", + "corneal epithelium", + "retina", + "Eye", + "sclera", + "lacrimal gland", + "macula lutea proper", + "peripheral region of retina", + "fovea centralis", + "pigment epithelium of eye", }, - "eye": {"eye", "uvea", "corneal epithelium", "retina", "Eye"}, "stomach": {"stomach"}, "gut": { "colon", @@ -848,6 +933,9 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "jejunum", "jejunum ", "descending colon", + "rectum", + "colonic mucosa", + "mucosa of descending colon", }, "heart": { "heart", @@ -856,6 +944,19 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "Heart", "heart left ventricle", "pulmonary artery", + "cardiac ventricle", + "heart right ventricle", + "left cardiac atrium", + "right cardiac atrium", + "apex of heart", + "interventricular septum", + }, + "joint": { + "synovial fluid", + 
"cartilage tissue", + "portion of cartilage tissue in tibia", + "layer of synovial tissue", + "synovial membrane of synovial joint", }, "kidney": { "adult mammalian kidney", @@ -863,14 +964,23 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "Kidney", "inner medulla of kidney", "outer cortex of kidney", + "renal medulla", + "cortex of kidney", + "renal pelvis", + "kidney blood vessel", + "renal papilla", + }, + "liver": { + "liver", + "Liver", + "caudate lobe of liver", + "right lobe of liver", + "left lobe of liver", }, - "liver": {"liver", "Liver", "caudate lobe of liver"}, "lung": { "lung", "alveolar system", "lung parenchyma", - "respiratory airway", - "trachea;respiratory airway", "BAL", "Lung", "Parenchymal lung tissue", @@ -879,12 +989,21 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "Intermediate", "lower lobe of lung", "upper lobe of lung", + "upper lobe of left lung", + "upper lobe of right lung", + "lower lobe of right lung", + "lower lobe of left lung", + "left lung", + "right lung", + "lingula of left lung", }, "lymph node": { "lymph node", "axillary lymph node", "Lymph_Node", "craniocervical lymph node", + "thoracic lymph node", + "mesenteric lymph node", }, "male reproduction": { "male reproductive gland", @@ -894,6 +1013,8 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "Prostate", "prostate", "peripheral zone of prostate", + "transition zone of prostate;urethra", + "transition zone of prostate", }, "female reproduction": { "ovary", @@ -906,7 +1027,21 @@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "uterus", "Uterus", }, - "pancreas": {"pancreas", "Pancreas", "islet of Langerhans"}, + "muscle": { + "psoas muscle", + "muscle tissue", + "gastrocnemius", + }, + "pancreas": { + "pancreas", + "Pancreas", + "islet of Langerhans", + "exocrine pancreas", + }, + "peritoneum": { + "peritoneum", + "pleural effusion", + }, "skin": { "skin of body", "skin epidermis", @@ -914,22 +1049,32 
@@ def clean_tissues(tissues: "pandas.Series") -> "pandas.Series": "scrotum skin", "Skin", "skin", + "skin of leg", + "zone of skin", + }, + "spleen": { + "spleen", + "Spleen", + }, + "thymus": { + "thymus", + "Thymus", }, - "spleen": {"spleen", "Spleen"}, - "thymus": {"thymus", "Thymus"}, "vasculature": { "vasculature", "mesenteric artery", "umbilical vein", "Vasculature", + "carotid artery segment", + "posterior vena cava", }, } term2simple = {} for tissue_simplified, children in tissue_mapper.items(): for child in children: - term2simple[child] = tissue_simplified + term2simple[child.lower()] = tissue_simplified - return tissues.map(term2simple) + return tissues.str.lower().map(term2simple) def clean_diseases(diseases: "pandas.Series") -> "pandas.Series": @@ -952,12 +1097,8 @@ def clean_diseases(diseases: "pandas.Series") -> "pandas.Series": disease_mapper = { "healthy": {"healthy", "", "NA"}, - "Alzheimer's": { - "Alzheimer's disease", - }, - "COVID-19": { - "COVID-19", - }, + "Alzheimer's": {"Alzheimer's disease"}, + "COVID-19": {"COVID-19"}, "ILD": { "pulmonary fibrosis", "idiopathic pulmonary fibrosis", @@ -965,6 +1106,7 @@ def clean_diseases(diseases: "pandas.Series") -> "pandas.Series": "systemic scleroderma;interstitial lung disease", "fibrosis", "hypersensitivity pneumonitis", + "Idiopathic pulmonary arterial hypertension", }, "cancer": { "head and neck squamous cell carcinoma", @@ -978,47 +1120,100 @@ def clean_diseases(diseases: "pandas.Series") -> "pandas.Series": "melanoma", "multiple myeloma", "Gastrointestinal stromal tumor", - "neuroblastoma" "nasopharyngeal neoplasm", "adenocarcinoma", "pancreatic ductal adenocarcinoma", "chronic lymphocytic leukemia", "Uveal Melanoma", "Myelofibrosis", + "acute myeloid leukemia", + "acute lymphoblastic leukemia", + "precursor B-cell acute lymphoblastic leukemia", + "T-cell acute lymphoblastic leukemia", + "chronic myelogenous leukemia", + "B-cell lymphoma", + "precursor T-cell lymphoblastic leukemia-lymphoma", + 
"human papilloma virus infection;head and neck squamous cell carcinoma", + "squamous cell carcinoma", + "Tonsillar Squamous Cell Carcinoma", + "invasive breast ductal carcinoma", + "basal cell carcinoma", + "brain glioblastoma;non-small cell lung carcinoma", + "renal cell carcinoma", + "non-small cell lung carcinoma", + "colorectal cancer", + "esophageal carcinoma", + "liver neoplasm;Uveal Melanoma", + "glioblastoma multiforme", + "Ewing sarcoma", + "medulloblastoma", + "brain glioblastoma", + "breast neoplasm", + "lung adenocarcinoma", + "lung cancer", + "nasopharyngeal neoplasm", + "small cell lung carcinoma", + "breast cancer", + "prostate cancer", + "gastric cancer", + "gastric carcinoma", + "bladder carcinoma", + "urinary bladder cancer", + "Pleuropulmonary blastoma", + "cutaneous squamous cell carcinoma", + "Merkel cell skin cancer", + "urothelial neoplasm", + "alveolar rhabdomyosarcoma", + "myeloid neoplasm", + "Sezary's disease", + "essential thrombocythemia", }, - "MS": { - "multiple sclerosis", - }, - "dengue": { - "dengue disease", + "MS": {"multiple sclerosis"}, + "dengue": {"dengue disease"}, + "HIV": { + "HIV enteropathy", + "HIV infection", }, "IBD": { "Crohn's disease", + "ulcerative colitis", }, "SLE": {"systemic lupus erythematosus"}, "scleroderma": {"scleroderma"}, "LCH": {"Langerhans Cell Histiocytosis"}, - "NAFLD": {"non-alcoholic fatty liver disease"}, + "NAFLD": {"non-alcoholic fatty liver disease", "non-alcoholic steatohepatitis"}, "Kawasaki disease": {"mucocutaneous lymph node syndrome"}, "eczema": {"atopic eczema"}, "sepsis": {"septic shock"}, "obesity": {"obesity"}, "DRESS": {"drug hypersensitivity syndrome"}, "hidradenitis suppurativa": {"hidradenitis suppurativa"}, - "T2 diabetes": {"type II diabetes mellitus"}, - "non-alcoholic steatohepatitis": {"non-alcoholic steatohepatitis"}, - "Biliary atresia": {"Biliary atresia"}, - "essential thrombocythemia": {"essential thrombocythemia"}, - "HIV": {"HIV enteropathy"}, + "diabetes": { + "type 
II diabetes mellitus", + "type 2 diabetes mellitus", + "diabetes mellitus", + "Wolfram syndrome", + }, + "biliary atresia": {"Biliary atresia"}, "monoclonal gammopathy": {"monoclonal gammopathy"}, "psoriatic arthritis": {"psoriatic arthritis"}, "RA": {"rheumatoid arthritis"}, "osteoarthritis": {"osteoarthritis"}, "periodontitis": {"periodontitis"}, - "Lymphangioleiomyomatosis": {"Lymphangioleiomyomatosis"}, + "LAM": {"Lymphangioleiomyomatosis"}, + "Parkinson's": { + "Parkinson's disease", + "Parkinson's Disease", + }, + "cardiomyopathy": { + "cardiomyopathy", + "arrhythmogenic right ventricular cardiomyopathy", + "dilated cardiomyopathy", + }, } + term2simple = {} for disease_simplified, children in disease_mapper.items(): for child in children: - term2simple[child] = disease_simplified + term2simple[child.lower()] = disease_simplified - return diseases.map(term2simple) + return diseases.str.lower().map(term2simple)