Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 47 additions & 9 deletions src/squidpy/gr/_niche.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def calculate_niche(
n_components: int | None = None,
random_state: int = 42,
spatial_connectivities_key: str = "spatial_connectivities",
use_rep: str | None = None,
inplace: bool = True,
) -> AnnData:
"""
Expand All @@ -58,7 +59,7 @@ def calculate_niche(
Method to use for niche calculation. Available options are:
- `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile.
- `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication).
- `{fla.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach.
- `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA for dimensionality reduction by default. To use an existing embedding instead of PCA, set `use_rep` to the key under which it is stored in `adata.obsm`.
%(library_key)s
If provided, niches will be calculated separately for each unique value in this column.
Each niche will be prefixed with the library identifier.
Expand Down Expand Up @@ -103,6 +104,9 @@ def calculate_niche(
Optional if flavor == `{fla.CELLCHARTER.s!r}`.
spatial_connectivities_key
Key in `adata.obsp` where spatial connectivities are stored.
use_rep
Key in `adata.obsm` where the embedding is stored. If provided, this embedding is used instead of PCA for dimensionality reduction; it must have at least `n_components` columns, and only the first `n_components` columns are used.
Optional if flavor == `{fla.CELLCHARTER.s!r}`.
inplace
If 'True', perform the operation in place.
If 'False', return a new AnnData object with the niche labels.
Expand All @@ -111,6 +115,12 @@ def calculate_niche(
if flavor == "cellcharter" and aggregation is None:
aggregation = "mean"

if distance is None:
distance = 3 if flavor == "cellcharter" else 1

if flavor == "cellcharter" and n_components is None:
n_components = 10

_validate_niche_args(
data,
flavor,
Expand All @@ -127,15 +137,13 @@ def calculate_niche(
aggregation,
n_components,
random_state,
use_rep,
inplace,
)

if resolutions is None:
resolutions = [0.5]

if distance is None:
distance = 1

if isinstance(data, SpatialData):
orig_adata = data.tables[table_key]
adata = orig_adata.copy()
Expand Down Expand Up @@ -225,6 +233,7 @@ def calculate_niche(
n_components,
random_state,
spatial_connectivities_key,
use_rep,
)

if not inplace:
Expand Down Expand Up @@ -293,6 +302,7 @@ def _calculate_niches(
n_components: int | None,
random_state: int,
spatial_connectivities_key: str,
use_rep: str | None,
) -> None:
"""Calculate niches using the specified flavor and parameters."""
if flavor == "neighborhood":
Expand Down Expand Up @@ -321,6 +331,7 @@ def _calculate_niches(
n_components,
random_state,
spatial_connectivities_key,
use_rep,
)


Expand Down Expand Up @@ -470,6 +481,7 @@ def _get_cellcharter_niches(
n_components: int,
random_state: int,
spatial_connectivities_key: str,
use_rep: str | None = None,
) -> None:
"""adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py
and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py"""
Expand All @@ -494,11 +506,32 @@ def _get_cellcharter_niches(

concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally
arr = concatenated_matrix.toarray() # Densify
arr_ad = ad.AnnData(X=arr)
sc.tl.pca(arr_ad)

if use_rep is not None:
# Use provided embedding from adata.obsm
if use_rep not in adata.obsm:
raise KeyError(
f"Embedding key '{use_rep}' not found in adata.obsm. Available keys: {list(adata.obsm.keys())}"
)
embedding = adata.obsm[use_rep]
# Ensure embedding has the right number of components
if embedding.shape[1] < n_components:
raise ValueError(
f"Embedding has {embedding.shape[1]} components, but n_components={n_components}. Please provide an embedding with at least {n_components} components."
)
# Use only the first n_components
embedding = embedding[:, :n_components]
else:
logg.warning(
"CellCharter recommends to use a dimensionality reduced embedding of the data, e.g. a scVI embedding. Since 'use_rep' is not provided, PCA will be used as proxy - performance may be suboptimal."
)

arr_ad = ad.AnnData(X=arr)
sc.tl.pca(arr_ad)
embedding = arr_ad.obsm["X_pca"]

# cluster concatenated matrix with GMM, each cluster label equals to a niche label
niches = _get_GMM_clusters(arr_ad.obsm["X_pca"], n_components, random_state)
niches = _get_GMM_clusters(embedding, n_components, random_state)

adata.obs["cellcharter_niche"] = pd.Categorical(niches)
return
Expand Down Expand Up @@ -681,6 +714,7 @@ def _validate_niche_args(
aggregation: str | None,
n_components: int | None,
random_state: int,
use_rep: str | None,
inplace: bool,
) -> None:
"""
Expand Down Expand Up @@ -761,8 +795,8 @@ def _validate_niche_args(
],
},
"cellcharter": {
"required": ["distance", "aggregation", "n_components", "random_state"],
"optional": [],
"required": ["distance", "aggregation", "random_state"],
"optional": ["n_components", "use_rep"],
"unused": [
"groups",
"min_niche_size",
Expand Down Expand Up @@ -794,6 +828,7 @@ def _validate_niche_args(
"aggregation": aggregation,
"n_components": n_components,
"random_state": random_state,
"use_rep": use_rep,
},
flavor_param_specs[flavor],
)
Expand Down Expand Up @@ -828,6 +863,9 @@ def _validate_niche_args(
if not isinstance(random_state, int):
raise TypeError(f"'random_state' must be an integer, got {type(random_state).__name__}")

if use_rep is not None and not isinstance(use_rep, str):
raise TypeError(f"'use_rep' must be a string, got {type(use_rep).__name__}")

# for mypy
if resolutions is None:
resolutions = [0.0]
Expand Down