diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 13c7afef..99717051 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -45,6 +45,7 @@ def calculate_niche( n_components: int | None = None, random_state: int = 42, spatial_connectivities_key: str = "spatial_connectivities", + use_rep: str | None = None, inplace: bool = True, ) -> AnnData: """ @@ -58,7 +59,7 @@ def calculate_niche( Method to use for niche calculation. Available options are: - `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. - `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication). - - `{fla.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. + - `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA for dimensionality reduction. An arbitrary embedding can be used instead of PCA by setting the `use_rep` parameter which will try to find the embedding in `adata.obsm`. %(library_key)s If provided, niches will be calculated separately for each unique value in this column. Each niche will be prefixed with the library identifier. @@ -103,6 +104,9 @@ def calculate_niche( Optional if flavor == `{fla.CELLCHARTER.s!r}`. spatial_connectivities_key Key in `adata.obsp` where spatial connectivities are stored. + use_rep + Key in `adata.obsm` where the embedding is stored. If provided, this embedding will be used instead of PCA for dimensionality reduction. + Optional if flavor == `{fla.CELLCHARTER.s!r}`. inplace If 'True', perform the operation in place. If 'False', return a new AnnData object with the niche labels. @@ -111,6 +115,12 @@ def calculate_niche( if flavor == "cellcharter" and aggregation is None: aggregation = "mean" + if distance is None: + distance = 3 if flavor == "cellcharter" else 1 + + if flavor == "cellcharter" and n_components is None: + n_components = 10 + _validate_niche_args( data, flavor, @@ -127,15 +137,13 @@ def calculate_niche( aggregation, n_components, random_state, + use_rep, inplace, ) if resolutions is None: resolutions = [0.5] - if distance is None: - distance = 1 - if isinstance(data, SpatialData): orig_adata = data.tables[table_key] adata = orig_adata.copy() @@ -225,6 +233,7 @@ def calculate_niche( n_components, random_state, spatial_connectivities_key, + use_rep, ) if not inplace: @@ -293,6 +302,7 @@ def _calculate_niches( n_components: int | None, random_state: int, spatial_connectivities_key: str, + use_rep: str | None, ) -> None: """Calculate niches using the specified flavor and parameters.""" if flavor == "neighborhood": @@ -321,6 +331,7 @@ def _calculate_niches( n_components, random_state, spatial_connectivities_key, + use_rep, ) @@ -470,6 +481,7 @@ def _get_cellcharter_niches( n_components: int, random_state: int, spatial_connectivities_key: str, + use_rep: str | None = None, ) -> None: """adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py""" @@ -494,11 +506,32 @@ def _get_cellcharter_niches( concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally arr = concatenated_matrix.toarray() # Densify - arr_ad = ad.AnnData(X=arr) - sc.tl.pca(arr_ad) + + if use_rep is not None: + # Use provided embedding from adata.obsm + if use_rep not in adata.obsm: + raise KeyError( + f"Embedding key '{use_rep}' not found in adata.obsm. Available keys: {list(adata.obsm.keys())}" + ) + embedding = adata.obsm[use_rep] + # Ensure embedding has the right number of components + if embedding.shape[1] < n_components: + raise ValueError( + f"Embedding has {embedding.shape[1]} components, but n_components={n_components}. Please provide an embedding with at least {n_components} components." + ) + # Use only the first n_components + embedding = embedding[:, :n_components] + else: + logg.warning( + "CellCharter recommends to use a dimensionality reduced embedding of the data, e.g. a scVI embedding. Since 'use_rep' is not provided, PCA will be used as proxy - performance may be suboptimal." + ) + + arr_ad = ad.AnnData(X=arr) + sc.tl.pca(arr_ad) + embedding = arr_ad.obsm["X_pca"] # cluster concatenated matrix with GMM, each cluster label equals to a niche label - niches = _get_GMM_clusters(arr_ad.obsm["X_pca"], n_components, random_state) + niches = _get_GMM_clusters(embedding, n_components, random_state) adata.obs["cellcharter_niche"] = pd.Categorical(niches) return @@ -681,6 +714,7 @@ def _validate_niche_args( aggregation: str | None, n_components: int | None, random_state: int, + use_rep: str | None, inplace: bool, ) -> None: """ @@ -761,8 +795,8 @@ def _validate_niche_args( ], }, "cellcharter": { - "required": ["distance", "aggregation", "n_components", "random_state"], - "optional": [], + "required": ["distance", "aggregation", "random_state"], + "optional": ["n_components", "use_rep"], "unused": [ "groups", "min_niche_size", @@ -794,6 +828,7 @@ def _validate_niche_args( "aggregation": aggregation, "n_components": n_components, "random_state": random_state, + "use_rep": use_rep, }, flavor_param_specs[flavor], ) @@ -828,6 +863,9 @@ def _validate_niche_args( if not isinstance(random_state, int): raise TypeError(f"'random_state' must be an integer, got {type(random_state).__name__}") + if use_rep is not None and not isinstance(use_rep, str): + raise TypeError(f"'use_rep' must be a string, got {type(use_rep).__name__}") + # for mypy if resolutions is None: resolutions = [0.0]