diff --git a/docs/notebooks/cell_annotation_tutorial.ipynb b/docs/notebooks/cell_annotation_tutorial.ipynb index b1cad94..0b73b8a 100644 --- a/docs/notebooks/cell_annotation_tutorial.ipynb +++ b/docs/notebooks/cell_annotation_tutorial.ipynb @@ -303,7 +303,7 @@ "\n", "Two methods within the CellAnnotation class:\n", " 1. `annotate_dataset` - automatically computes embeddings.\n", - " 2. `get_predictions` - more detailed control of annotation.\n", + " 2. `get_predictions_knn` - more detailed control of annotation.\n", "\n", "*Description of inputs*\n", " - `X_scimilarity`: embeddings from the model, which can be used to generate UMAPs in lieu of PCA and is generalized across datasets. \n", @@ -313,7 +313,7 @@ " - `nn_idxs`: indicies of cells in the SCimilarity reference. \n", " - `nn_dists`: the minimum distance within k=50 nearest neighbors.\n", " - `nn_stats`: a dataframe containing useful metrics such as: \n", - " - `hits`: the distribution of celltypes in k=50 nearest neighbors." + " - `hits`: the distribution of celltypes in k=50 nearest neighbors." ] }, { @@ -349,7 +349,7 @@ } ], "source": [ - "predictions, nn_idxs, nn_dists, nn_stats = ca.get_predictions_kNN(\n", + "predictions, nn_idxs, nn_dists, nn_stats = ca.get_predictions_knn(\n", " adams.obsm[\"X_scimilarity\"]\n", ")\n", "adams.obs[\"predictions_unconstrained\"] = predictions.values" diff --git a/docs/notebooks/cell_search_tutorial_2.ipynb b/docs/notebooks/cell_search_tutorial_2.ipynb index 62197e0..51bfa7b 100644 --- a/docs/notebooks/cell_search_tutorial_2.ipynb +++ b/docs/notebooks/cell_search_tutorial_2.ipynb @@ -352,7 +352,7 @@ "source": [ "# Fibroblast query\n", "centroid_embedding, nn_idxs, nn_dists, fibro_results_metadata, qc_stats = (\n", - " cq.search_centroid_kNN(adams, \"used_in_fibro_query\")\n", + " cq.search_centroid_nearest(adams, \"used_in_fibro_query\")\n", ")\n", "print(qc_stats)" ] @@ -374,7 +374,7 @@ "source": [ "# Non-specific query\n", "centroid_embedding, nn_idxs, nn_dists, nonspecific_results_metadata, qc_stats = (\n", - " cq.search_centroid_kNN(adams, \"used_in_nonspecific_query\")\n", + " cq.search_centroid_nearest(adams, \"used_in_nonspecific_query\")\n", ")\n", "print(qc_stats)" ] diff --git a/docs/notebooks/cell_search_tutorial_3.ipynb b/docs/notebooks/cell_search_tutorial_3.ipynb index 2b32245..a6258db 100644 --- a/docs/notebooks/cell_search_tutorial_3.ipynb +++ b/docs/notebooks/cell_search_tutorial_3.ipynb @@ -301,7 +301,7 @@ ], "source": [ "# DC_Mature query\n", - "centroid_embedding, nn_idxs, nn_dists, dcmature_metadata, qc_stats = cq.search_centroid_kNN(\n", + "centroid_embedding, nn_idxs, nn_dists, dcmature_metadata, qc_stats = cq.search_centroid_nearest(\n", " adams, \"dcmature_query\"\n", ")\n", "print(qc_stats)" @@ -323,7 +323,7 @@ ], "source": [ "# Non-specific T cell query\n", - "centroid_embedding, nn_idxs, nn_dists, t_metadata, qc_stats = cq.search_centroid_kNN(\n", + "centroid_embedding, nn_idxs, nn_dists, t_metadata, qc_stats = cq.search_centroid_nearest(\n", " adams, \"t_query\"\n", ")\n", "print(qc_stats)" @@ -345,7 +345,7 @@ ], "source": [ "# Broad macrophage query\n", - "centroid_embedding, nn_idxs, nn_dists, macrophage_metadata, qc_stats = cq.search_centroid_kNN(\n", + "centroid_embedding, nn_idxs, nn_dists, macrophage_metadata, qc_stats = cq.search_centroid_nearest(\n", " adams, \"macrophage_query\"\n", ")\n", "print(qc_stats)" diff --git a/src/scimilarity/cell_annotation.py b/src/scimilarity/cell_annotation.py index b3ce45d..7b4806e 100644 --- a/src/scimilarity/cell_annotation.py +++ b/src/scimilarity/cell_annotation.py @@ -4,7 +4,7 @@ class CellAnnotation(CellSearchKNN): - """A class that annotates cells using a cell embedding and then kNN search.""" + """A class that annotates cells using a cell embedding and then knn search.""" def __init__( self, @@ -72,7 +72,7 @@ def __init__( self.safelist = None self.blocklist = None - def build_kNN( + def build_knn( self, input_data: Union["anndata.AnnData", List[str]], knn_filename: str = "labelled_kNN.bin", @@ -82,7 +82,7 @@ def build_kNN( M: int = 80, target_labels: Optional[List[str]] = None, ): - """Build and save a kNN index from a h5ad data file or directory of aligned.zarr stores. + """Build and save a knn index from a h5ad data file or directory of aligned.zarr stores. Parameters ---------- @@ -92,7 +92,7 @@ def build_kNN( Otherwise, the annotated data matrix with rows for cells and columns for genes. NOTE: The data should be curated to only contain valid cell ontology labels. knn_filename: str, default: "labelled_kNN.bin" - Filename of the kNN index. + Filename of the knn index. celltype_labels_filename: str, default: "reference_labels.tsv" Filename of the cell type reference labels. obs_field: str, default: "celltype_name" @@ -184,12 +184,12 @@ def build_kNN( with open(celltype_labels_fullpath, "r") as fh: self.idx2label = {i: line.strip() for i, line in enumerate(fh)} - def reset_kNN(self): - """Reset the kNN such that nothing is marked deleted. + def reset_knn(self): + """Reset the knn such that nothing is marked deleted. Examples -------- - >>> ca.reset_kNN() + >>> ca.reset_knn() """ self.blocklist = None @@ -223,7 +223,7 @@ def blocklist_celltypes(self, labels: Union[List[str], Set[str]]): self.blocklist = set(labels) self.safelist = None - self.reset_kNN() + self.reset_knn() for i, celltype_name in self.idx2label.items(): if celltype_name in self.blocklist: self.knn.mark_deleted(i) # mark blocklist @@ -258,7 +258,7 @@ def safelist_celltypes(self, labels: Union[List[str], Set[str]]): if celltype_name in self.safelist: self.knn.unmark_deleted(i) # unmark safelist - def get_predictions_kNN( + def get_predictions_knn( self, embeddings: "numpy.ndarray", k: int = 50, @@ -266,7 +266,7 @@ def get_predictions_kNN( weighting: bool = False, disable_progress: bool = False, ) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray", "pandas.DataFrame"]: - """Get predictions from kNN search results. + """Get predictions from knn search results. Parameters ---------- @@ -305,7 +305,7 @@ def get_predictions_kNN( -------- >>> ca = CellAnnotation(model_path="/opt/data/model") >>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X) - >>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_kNN(embeddings) + >>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_knn(embeddings) """ from collections import defaultdict @@ -418,7 +418,7 @@ def annotate_dataset( embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X) data.obsm["X_scimilarity"] = embeddings - predictions, _, _, nn_stats = self.get_predictions_kNN(embeddings) + predictions, _, _, nn_stats = self.get_predictions_knn(embeddings) data.obs["celltype_hint"] = predictions.values data.obs["min_dist"] = nn_stats["min_dist"].values data.obs["celltype_hits"] = nn_stats["hits"].values diff --git a/src/scimilarity/cell_query.py b/src/scimilarity/cell_query.py index f243d65..90a36d8 100644 --- a/src/scimilarity/cell_query.py +++ b/src/scimilarity/cell_query.py @@ -36,7 +36,7 @@ def __init__( residual: bool, default: False Use residual connections. load_knn: bool, default: True - Load the knn index. Set to False if kNN is not needed. + Load the knn index. Set to False if knn is not needed. Examples -------- @@ -296,7 +296,7 @@ def search_nearest( ef: int = None, max_dist: Optional[float] = None, ) -> Tuple[List["numpy.ndarray"], List["numpy.ndarray"], "pandas.DataFrame"]: - """Performs a nearest neighbors search against the kNN. + """Performs a nearest neighbors search against the knn. Parameters ---------- @@ -360,7 +360,7 @@ def search_nearest( return nn_idxs, nn_dists, metadata - def search_centroid_kNN( + def search_centroid_nearest( self, adata: "anndata.AnnData", centroid_key: str, @@ -421,7 +421,7 @@ def search_centroid_kNN( -------- >>> cells_used_in_query = adata.obs["celltype_name"] == "macrophage" >>> adata.obs["used_in_query"] = cells_used_in_query.astype(int) - >>> centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats = cq.search_centroid_kNN(adata, 'used_in_query') + >>> centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats = cq.search_centroid_nearest(adata, 'used_in_query') """ import numpy as np @@ -460,7 +460,7 @@ def search_centroid_kNN( return centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats - def search_cluster_centroids_kNN( + def search_cluster_centroids_nearest( self, adata: "anndata.AnnData", cluster_key: str, @@ -476,7 +476,7 @@ def search_cluster_centroids_kNN( Dict[str, "numpy.ndarray"], "pandas.DataFrame", ]: - """Performs a nearest neighbors search for cluster centroids against the kNN. + """Performs a nearest neighbors search for cluster centroids against the knn. Parameters ---------- @@ -516,7 +516,7 @@ def search_cluster_centroids_kNN( Examples -------- - >>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_cluster_centroids_kNN(adata, "leidan") + >>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_cluster_centroids_nearest(adata, "leidan") """ from .utils import get_cluster_centroids @@ -646,6 +646,7 @@ def search_centroid_exhaustive( adata: "anndata.AnnData", centroid_key: str, max_dist: float = 0.03, + metadata_filter: Optional[dict] = None, qc: bool = True, qc_params: dict = {"k_clusters": 10}, buffer_size: int = 100000, @@ -668,6 +669,9 @@ def search_centroid_exhaustive( The obs column key that marks cells to centroid as 1, otherwise 0. max_dist: float, default: 0.03 Filter for cells that are within the max distance to the query. + metadata_filter: dict, optional, default: None + A dictionary where keys represent column names and values + represent valid terms in the columns. qc: bool, default: True Whether to perform QC on the query qc_params: dict, default: {'k_clusters': 10} @@ -711,6 +715,7 @@ def search_centroid_exhaustive( nn_idxs, nn_dists, metadata = self.search_exhaustive( centroid_embedding, max_dist=max_dist, + metadata_filter=metadata_filter, buffer_size=buffer_size, ) @@ -728,3 +733,94 @@ def search_centroid_exhaustive( qc_stats["query_coherence"] = np.mean(query_overlap) return centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats + + def search_cluster_centroids_exhaustive( + self, + adata: "anndata.AnnData", + cluster_key: str, + cluster_label: Optional[str] = None, + max_dist: float = 0.03, + metadata_filter: Optional[dict] = None, + buffer_size: int = 100000, + skip_null: bool = True, + ) -> Tuple[ + "numpy.ndarray", + list, + Dict[str, "numpy.ndarray"], + Dict[str, "numpy.ndarray"], + "pandas.DataFrame", + ]: + """Performs a nearest neighbors search for cluster centroids against the knn. + + Parameters + ---------- + adata: anndata.AnnData + Annotated data matrix with rows for cells and columns for genes. + Requires a layers["counts"]. + cluster_key: str + The obs column key that contains cluster labels. + cluster_label: str, optional, default: None + The cluster label of interest. If None, then get the centroids of + all clusters, otherwise get only the centroid for the cluster + of interest + max_dist: float, default: 0.03 + Filter for cells that are within the max distance to the query. + metadata_filter: dict, optional, default: None + A dictionary where keys represent column names and values + represent valid terms in the columns. + buffer_size: int, default: 100000 + Batch size for processing cells. + skip_null: bool, default: True + Whether to skip cells with null/nan cluster labels. + + Returns + ------- + centroid_embeddings: numpy.ndarray + A 2D numpy array of the log normalized (1e4) cluster centroid embeddings. + cluster_idx: list + A list of cluster labels corresponding to the order returned in centroids. + nn_idxs: Dict[str, numpy.ndarray] + A 2D numpy array of nearest neighbor indices [num_cells x k]. + nn_dists: Dict[str, numpy.ndarray] + A 2D numpy array of nearest neighbor distances [num_cells x k]. + all_metadata: pandas.DataFrame + A pandas dataframe containing cell metadata for nearest neighbors + for all centroids. + + Examples + -------- + >>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_cluster_centroids_exhaustive(adata, "leidan") + """ + + from .utils import get_cluster_centroids + + centroids, cluster_idx = get_cluster_centroids( + adata, self.gene_order, cluster_key, cluster_label, skip_null=skip_null + ) + + centroid_embeddings = self.get_embeddings(centroids) + + nn_idxs, nn_dists, metadata = self.search_exhaustive( + centroid_embeddings, + max_dist=max_dist, + metadata_filter=metadata_filter, + buffer_size=buffer_size, + ) + + metadata["centroid"] = metadata["embedding_idx"].map( + {i: x for i, x in enumerate(cluster_idx)} + ) + + nn_idxs_dict = {} + nn_dists_dict = {} + for i in range(len(cluster_idx)): + nn_idxs_dict[cluster_idx[i]] = [nn_idxs[i]] + nn_dists_dict[cluster_idx[i]] = [nn_dists[i]] + + return ( + centroid_embeddings, + cluster_idx, + nn_idxs_dict, + nn_dists_dict, + metadata, + )