add search_cluster_centroids_exhaustive, fixed parameters, styling pass
tony-kuo committed Sep 14, 2024
1 parent dfaab85 commit 0681cd4
Showing 5 changed files with 123 additions and 27 deletions.
6 changes: 3 additions & 3 deletions docs/notebooks/cell_annotation_tutorial.ipynb
@@ -303,7 +303,7 @@
"\n",
"Two methods within the CellAnnotation class:\n",
" 1. `annotate_dataset` - automatically computes embeddings.\n",
" 2. `get_predictions` - more detailed control of annotation.\n",
" 2. `get_predictions_knn` - more detailed control of annotation.\n",
"\n",
"*Description of inputs*\n",
" - `X_scimilarity`: embeddings from the model, which can be used to generate UMAPs in lieu of PCA and is generalized across datasets. \n",
@@ -313,7 +313,7 @@
" - `nn_idxs`: indicies of cells in the SCimilarity reference. \n",
" - `nn_dists`: the minimum distance within k=50 nearest neighbors.\n",
" - `nn_stats`: a dataframe containing useful metrics such as: \n",
" - `hits`: the distribution of celltypes in k=50 nearest neighbors."
" - `hits`: the distribution of celltypes in k=50 nearest neighbors."
]
},
{
@@ -349,7 +349,7 @@
}
],
"source": [
"predictions, nn_idxs, nn_dists, nn_stats = ca.get_predictions_kNN(\n",
"predictions, nn_idxs, nn_dists, nn_stats = ca.get_predictions_knn(\n",
" adams.obsm[\"X_scimilarity\"]\n",
")\n",
"adams.obs[\"predictions_unconstrained\"] = predictions.values"
4 changes: 2 additions & 2 deletions docs/notebooks/cell_search_tutorial_2.ipynb
@@ -352,7 +352,7 @@
"source": [
"# Fibroblast query\n",
"centroid_embedding, nn_idxs, nn_dists, fibro_results_metadata, qc_stats = (\n",
" cq.search_centroid_kNN(adams, \"used_in_fibro_query\")\n",
" cq.search_centroid_nearest(adams, \"used_in_fibro_query\")\n",
")\n",
"print(qc_stats)"
]
@@ -374,7 +374,7 @@
"source": [
"# Non-specific query\n",
"centroid_embedding, nn_idxs, nn_dists, nonspecific_results_metadata, qc_stats = (\n",
" cq.search_centroid_kNN(adams, \"used_in_nonspecific_query\")\n",
" cq.search_centroid_nearest(adams, \"used_in_nonspecific_query\")\n",
")\n",
"print(qc_stats)"
]
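As a hedged aside, search_centroid_nearest expects its second argument to name a 0/1 obs column marking the cells to average into the centroid. The sketch below mirrors the docstring pattern shown later in this commit; the "celltype_name" column and the "fibroblast" label are illustrative, not taken from this dataset.

# Hedged sketch: build the flag column consumed by search_centroid_nearest.
# The selection criterion is illustrative only.
adams.obs["used_in_fibro_query"] = (
    adams.obs["celltype_name"] == "fibroblast"
).astype(int)

centroid_embedding, nn_idxs, nn_dists, fibro_results_metadata, qc_stats = (
    cq.search_centroid_nearest(adams, "used_in_fibro_query")
)
print(qc_stats)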
6 changes: 3 additions & 3 deletions docs/notebooks/cell_search_tutorial_3.ipynb
@@ -301,7 +301,7 @@
],
"source": [
"# DC_Mature query\n",
"centroid_embedding, nn_idxs, nn_dists, dcmature_metadata, qc_stats = cq.search_centroid_kNN(\n",
"centroid_embedding, nn_idxs, nn_dists, dcmature_metadata, qc_stats = cq.search_centroid_nearest(\n",
" adams, \"dcmature_query\"\n",
")\n",
"print(qc_stats)"
@@ -323,7 +323,7 @@
],
"source": [
"# Non-specific T cell query\n",
"centroid_embedding, nn_idxs, nn_dists, t_metadata, qc_stats = cq.search_centroid_kNN(\n",
"centroid_embedding, nn_idxs, nn_dists, t_metadata, qc_stats = cq.search_centroid_nearest(\n",
" adams, \"t_query\"\n",
")\n",
"print(qc_stats)"
@@ -345,7 +345,7 @@
],
"source": [
"# Broad macrophage query\n",
"centroid_embedding, nn_idxs, nn_dists, macrophage_metadata, qc_stats = cq.search_centroid_kNN(\n",
"centroid_embedding, nn_idxs, nn_dists, macrophage_metadata, qc_stats = cq.search_centroid_nearest(\n",
" adams, \"macrophage_query\"\n",
")\n",
"print(qc_stats)"
24 changes: 12 additions & 12 deletions src/scimilarity/cell_annotation.py
@@ -4,7 +4,7 @@


class CellAnnotation(CellSearchKNN):
"""A class that annotates cells using a cell embedding and then kNN search."""
"""A class that annotates cells using a cell embedding and then knn search."""

def __init__(
self,
@@ -72,7 +72,7 @@ def __init__(
self.safelist = None
self.blocklist = None

def build_kNN(
def build_knn(
self,
input_data: Union["anndata.AnnData", List[str]],
knn_filename: str = "labelled_kNN.bin",
@@ -82,7 +82,7 @@ def build_kNN(
M: int = 80,
target_labels: Optional[List[str]] = None,
):
"""Build and save a kNN index from a h5ad data file or directory of aligned.zarr stores.
"""Build and save a knn index from a h5ad data file or directory of aligned.zarr stores.
Parameters
----------
@@ -92,7 +92,7 @@ def build_kNN(
Otherwise, the annotated data matrix with rows for cells and columns for genes.
NOTE: The data should be curated to only contain valid cell ontology labels.
knn_filename: str, default: "labelled_kNN.bin"
Filename of the kNN index.
Filename of the knn index.
celltype_labels_filename: str, default: "reference_labels.tsv"
Filename of the cell type reference labels.
obs_field: str, default: "celltype_name"
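A hedged usage sketch for the renamed index builder: parameter names and defaults are taken from the signature and docstring above, the h5ad path is hypothetical, and callers of the old build_kNN need to switch to build_knn.

# Hedged sketch: rebuild the labelled reference index with the renamed method.
ca.build_knn(
    input_data="/path/to/curated_reference.h5ad",  # hypothetical curated dataset
    knn_filename="labelled_kNN.bin",
    celltype_labels_filename="reference_labels.tsv",
    obs_field="celltype_name",
    M=80,
)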
@@ -184,12 +184,12 @@ def build_kNN(
with open(celltype_labels_fullpath, "r") as fh:
self.idx2label = {i: line.strip() for i, line in enumerate(fh)}

def reset_kNN(self):
"""Reset the kNN such that nothing is marked deleted.
def reset_knn(self):
"""Reset the knn such that nothing is marked deleted.
Examples
--------
>>> ca.reset_kNN()
>>> ca.reset_knn()
"""

self.blocklist = None
@@ -223,7 +223,7 @@ def blocklist_celltypes(self, labels: Union[List[str], Set[str]]):
self.blocklist = set(labels)
self.safelist = None

self.reset_kNN()
self.reset_knn()
for i, celltype_name in self.idx2label.items():
if celltype_name in self.blocklist:
self.knn.mark_deleted(i) # mark blocklist
@@ -258,15 +258,15 @@ def safelist_celltypes(self, labels: Union[List[str], Set[str]]):
if celltype_name in self.safelist:
self.knn.unmark_deleted(i) # unmark safelist

def get_predictions_kNN(
def get_predictions_knn(
self,
embeddings: "numpy.ndarray",
k: int = 50,
ef: int = 100,
weighting: bool = False,
disable_progress: bool = False,
) -> Tuple["numpy.ndarray", "numpy.ndarray", "numpy.ndarray", "pandas.DataFrame"]:
"""Get predictions from kNN search results.
"""Get predictions from knn search results.
Parameters
----------
@@ -305,7 +305,7 @@ def get_predictions_kNN(
--------
>>> ca = CellAnnotation(model_path="/opt/data/model")
>>> embeddings = ca.get_embeddings(align_dataset(data, ca.gene_order).X)
>>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_kNN(embeddings)
>>> predictions, nn_idxs, nn_dists, stats = ca.get_predictions_knn(embeddings)
"""

from collections import defaultdict
@@ -418,7 +418,7 @@ def annotate_dataset(
embeddings = self.get_embeddings(align_dataset(data, self.gene_order).X)
data.obsm["X_scimilarity"] = embeddings

predictions, _, _, nn_stats = self.get_predictions_kNN(embeddings)
predictions, _, _, nn_stats = self.get_predictions_knn(embeddings)
data.obs["celltype_hint"] = predictions.values
data.obs["min_dist"] = nn_stats["min_dist"].values
data.obs["celltype_hits"] = nn_stats["hits"].values
110 changes: 103 additions & 7 deletions src/scimilarity/cell_query.py
@@ -36,7 +36,7 @@ def __init__(
residual: bool, default: False
Use residual connections.
load_knn: bool, default: True
Load the knn index. Set to False if kNN is not needed.
Load the knn index. Set to False if knn is not needed.
Examples
--------
@@ -296,7 +296,7 @@ def search_nearest(
ef: int = None,
max_dist: Optional[float] = None,
) -> Tuple[List["numpy.ndarray"], List["numpy.ndarray"], "pandas.DataFrame"]:
"""Performs a nearest neighbors search against the kNN.
"""Performs a nearest neighbors search against the knn.
Parameters
----------
@@ -360,7 +360,7 @@ def search_nearest(

return nn_idxs, nn_dists, metadata

def search_centroid_kNN(
def search_centroid_nearest(
self,
adata: "anndata.AnnData",
centroid_key: str,
@@ -421,7 +421,7 @@
--------
>>> cells_used_in_query = adata.obs["celltype_name"] == "macrophage"
>>> adata.obs["used_in_query"] = cells_used_in_query.astype(int)
>>> centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats = cq.search_centroid_kNN(adata, 'used_in_query')
>>> centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats = cq.search_centroid_nearest(adata, 'used_in_query')
"""

import numpy as np
@@ -460,7 +460,7 @@ def search_centroid_kNN(

return centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats

def search_cluster_centroids_kNN(
def search_cluster_centroids_nearest(
self,
adata: "anndata.AnnData",
cluster_key: str,
@@ -476,7 +476,7 @@
Dict[str, "numpy.ndarray"],
"pandas.DataFrame",
]:
"""Performs a nearest neighbors search for cluster centroids against the kNN.
"""Performs a nearest neighbors search for cluster centroids against the knn.
Parameters
----------
@@ -516,7 +516,7 @@
Examples
--------
>>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_cluster_centroids_kNN(adata, "leidan")
>>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_cluster_centroids_nearest(adata, "leidan")
"""

from .utils import get_cluster_centroids
@@ -646,6 +646,7 @@ def search_centroid_exhaustive(
adata: "anndata.AnnData",
centroid_key: str,
max_dist: float = 0.03,
metadata_filter: Optional[dict] = None,
qc: bool = True,
qc_params: dict = {"k_clusters": 10},
buffer_size: int = 100000,
@@ -668,6 +668,9 @@
The obs column key that marks cells to centroid as 1, otherwise 0.
max_dist: float, default: 0.03
Filter for cells that are within the max distance to the query.
metadata_filter: dict, optional, default: None
A dictionary where keys represent column names and values
represent valid terms in the columns.
qc: bool, default: True
Whether to perform QC on the query
qc_params: dict, default: {'k_clusters': 10}
@@ -711,6 +715,7 @@
nn_idxs, nn_dists, metadata = self.search_exhaustive(
centroid_embedding,
max_dist=max_dist,
metadata_filter=metadata_filter,
buffer_size=buffer_size,
)

@@ -728,3 +733,94 @@
qc_stats["query_coherence"] = np.mean(query_overlap)

return centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats
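A hedged sketch of the new metadata_filter argument on search_centroid_exhaustive: the filter columns and values below are hypothetical and depend on what metadata the reference actually stores, and passing the valid terms as lists is an assumption; the "used_in_query" flag column follows the docstring pattern used elsewhere in this diff.

# Hedged sketch: restrict exhaustive centroid search hits by reference metadata.
# The "tissue" and "disease" columns and their values are hypothetical.
metadata_filter = {
    "tissue": ["lung"],
    "disease": ["idiopathic pulmonary fibrosis"],
}

centroid_embedding, nn_idxs, nn_dists, metadata, qc_stats = cq.search_centroid_exhaustive(
    adata,
    centroid_key="used_in_query",
    max_dist=0.03,
    metadata_filter=metadata_filter,
)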

def search_cluster_centroids_exhaustive(
self,
adata: "anndata.AnnData",
cluster_key: str,
cluster_label: Optional[str] = None,
max_dist: float = 0.03,
metadata_filter: Optional[dict] = None,
buffer_size: int = 100000,
skip_null: bool = True,
) -> Tuple[
"numpy.ndarray",
list,
Dict[str, "numpy.ndarray"],
Dict[str, "numpy.ndarray"],
"pandas.DataFrame",
]:
"""Performs a nearest neighbors search for cluster centroids against the knn.
Parameters
----------
adata: anndata.AnnData
Annotated data matrix with rows for cells and columns for genes.
Requires a layers["counts"].
cluster_key: str
The obs column key that contains cluster labels.
cluster_label: str, optional, default: None
The cluster label of interest. If None, then get the centroids of
all clusters, otherwise get only the centroid for the cluster
of interest
max_dist: float, default: 0.03
Filter for cells that are within the max distance to the query.
metadata_filter: dict, optional, default: None
A dictionary where keys represent column names and values
represent valid terms in the columns.
buffer_size: int, default: 100000
Batch size for processing cells.
skip_null: bool, default: True
Whether to skip cells with null/nan cluster labels.
Returns
-------
centroid_embeddings: numpy.ndarray
A 2D numpy array of the log normalized (1e4) cluster centroid embeddings.
cluster_idx: list
A list of cluster labels corresponding to the order returned in centroids.
nn_idxs: Dict[str, numpy.ndarray]
A dict mapping each cluster label to the nearest neighbor indices for that centroid.
nn_dists: Dict[str, numpy.ndarray]
A dict mapping each cluster label to the nearest neighbor distances for that centroid.
all_metadata: pandas.DataFrame
A pandas dataframe containing cell metadata for nearest neighbors
for all centroids.
Examples
--------
>>> centroid_embeddings, cluster_idx, nn_idx, nn_dists, all_metadata = cq.search_cluster_centroids_exhaustive(adata, "leidan")
"""

from .utils import get_cluster_centroids

centroids, cluster_idx = get_cluster_centroids(
adata, self.gene_order, cluster_key, cluster_label, skip_null=skip_null
)

centroid_embeddings = self.get_embeddings(centroids)

nn_idxs, nn_dists, metadata = self.search_exhaustive(
centroid_embeddings,
max_dist=max_dist,
metadata_filter=metadata_filter,
buffer_size=buffer_size,
)

metadata["centroid"] = metadata["embedding_idx"].map(
{i: x for i, x in enumerate(cluster_idx)}
)

nn_idxs_dict = {}
nn_dists_dict = {}
for i in range(len(cluster_idx)):
nn_idxs_dict[cluster_idx[i]] = [nn_idxs[i]]
nn_dists_dict[cluster_idx[i]] = [nn_dists[i]]

return (
centroid_embeddings,
cluster_idx,
nn_idxs_dict,
nn_dists_dict,
metadata,
)
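A hedged usage sketch for the new method: the "leiden" cluster key and the metadata_filter values are assumptions about the caller's data, while the "centroid" column is added to the returned metadata by the method itself as shown above.

# Hedged sketch: exhaustive search for every cluster centroid, then count hits
# per centroid using the "centroid" column added to the returned metadata.
centroid_embeddings, cluster_idx, nn_idxs, nn_dists, all_metadata = (
    cq.search_cluster_centroids_exhaustive(
        adata,
        cluster_key="leiden",  # assumed obs column holding cluster labels
        max_dist=0.03,
        metadata_filter={"tissue": ["lung"]},  # hypothetical filter
    )
)

print(all_metadata.groupby("centroid").size())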
