From 66734991fbe06e868eb60edaea7ce9a14c821564 Mon Sep 17 00:00:00 2001 From: Marco Varrone Date: Mon, 24 Mar 2025 10:29:40 +0100 Subject: [PATCH 1/3] Change niche flavor to cellcharter_simple and default distance = 3 --- src/squidpy/_constants/_constants.py | 2 +- src/squidpy/gr/_niche.py | 26 +++++++++++++------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/squidpy/_constants/_constants.py b/src/squidpy/_constants/_constants.py index 403f072ba..add11ecc2 100644 --- a/src/squidpy/_constants/_constants.py +++ b/src/squidpy/_constants/_constants.py @@ -129,6 +129,6 @@ class TenxVersions(str, ModeEnum): class NicheDefinitions(ModeEnum): NEIGHBORHOOD = "neighborhood" UTAG = "utag" - CELLCHARTER = "cellcharter" + CELLCHARTER = "cellcharter_simple" SPOT = "spot" BANKSY = "banksy" diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index c9752885c..d4244befc 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -29,7 +29,7 @@ @inject_docs(fla=NicheDefinitions) def calculate_niche( data: AnnData | SpatialData, - flavor: Literal["neighborhood", "utag", "cellcharter"], + flavor: Literal["neighborhood", "utag", "cellcharter_simple"], library_key: str | None = None, table_key: str | None = None, mask: pd.core.series.Series = None, @@ -58,7 +58,7 @@ def calculate_niche( Method to use for niche calculation. Available options are: - `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. - `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication). - - `{fla.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. + - `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA instead of scVI for dimensionality reduction. %(library_key)s If provided, niches will be calculated separately for each unique value in this column. Each niche will be prefixed with the library identifier. @@ -108,7 +108,7 @@ def calculate_niche( If 'False', return a new AnnData object with the niche labels. """ - if flavor == "cellcharter" and aggregation is None: + if flavor == "cellcharter_simple" and aggregation is None: aggregation = "mean" _validate_niche_args( @@ -134,7 +134,7 @@ def calculate_niche( resolutions = [0.5] if distance is None: - distance = 1 + distance = 3 if flavor == "cellcharter_simple" else 1 if isinstance(data, SpatialData): orig_adata = data.tables[table_key] @@ -187,7 +187,7 @@ def calculate_niche( mask=lib_mask, groups=groups, n_neighbors=n_neighbors, - resolutions=None if flavor == "cellcharter" else resolutions, + resolutions=None if flavor == "cellcharter_simple" else resolutions, min_niche_size=min_niche_size, scale=scale, abs_nhood=abs_nhood, @@ -258,7 +258,7 @@ def _get_result_columns( library_str = f"_{library_key}" if library_key is not None else "" - if flavor == "cellcharter": + if flavor == "cellcharter_simple": base_column = "cellcharter_niche" if library_key is None: return [base_column] @@ -311,7 +311,7 @@ def _calculate_niches( ) elif flavor == "utag": _get_utag_niches(adata, n_neighbors, resolutions, spatial_connectivities_key) - elif flavor == "cellcharter": + elif flavor == "cellcharter_simple": assert isinstance(aggregation, str) # for mypy assert isinstance(n_components, int) # for mypy _get_cellcharter_niches( @@ -667,7 +667,7 @@ def _jensen_shannon_divergence(adata: AnnData, niche_key: str, library_key: str) def _validate_niche_args( data: AnnData | SpatialData, - flavor: Literal["neighborhood", "utag", "cellcharter"], + flavor: Literal["neighborhood", "utag", "cellcharter_simple"], library_key: str | None, table_key: str | None, groups: str | None, @@ -697,8 +697,8 @@ def _validate_niche_args( if not isinstance(data, AnnData | SpatialData): raise TypeError(f"'data' must be an AnnData or SpatialData object, got {type(data).__name__}") - if flavor not in ["neighborhood", "utag", "cellcharter"]: - raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter'.") + if flavor not in ["neighborhood", "utag", "cellcharter_simple"]: + raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter_simple'.") if library_key is not None: if not isinstance(library_key, str): @@ -760,7 +760,7 @@ def _validate_niche_args( "random_state", ], }, - "cellcharter": { + "cellcharter_simple": { "required": ["distance", "aggregation", "n_components", "random_state"], "optional": [], "unused": [ @@ -809,7 +809,7 @@ def _validate_niche_args( if distance is not None and isinstance(distance, int) and distance < 1: raise ValueError(f"'distance' must be at least 1, got {distance}") - elif flavor == "cellcharter": + elif flavor == "cellcharter_simple": if distance is not None and not isinstance(distance, int): raise TypeError(f"'distance' must be an integer, got {type(distance).__name__}") if distance is not None and distance < 1: @@ -843,7 +843,7 @@ def _check_unnecessary_args(flavor: str, param_dict: dict[str, Any], param_specs Parameters ---------- flavor - The flavor being used ('neighborhood', 'utag', or 'cellcharter') + The flavor being used ('neighborhood', 'utag', or 'cellcharter_simple') param_dict Dictionary of parameter names to their values param_specs From 69d7d1beb76213308eb67a342e7a84be3727bf35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Mar 2025 09:34:41 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index d4244befc..7e5bcbc95 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -698,7 +698,9 @@ def _validate_niche_args( raise TypeError(f"'data' must be an AnnData or SpatialData object, got {type(data).__name__}") if flavor not in ["neighborhood", "utag", "cellcharter_simple"]: - raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter_simple'.") + raise ValueError( + f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter_simple'." + ) if library_key is not None: if not isinstance(library_key, str): From 98a3a78ed1b5c3ef2f48fc35fb22d7e3646bfae9 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 23 Oct 2025 21:16:44 +0200 Subject: [PATCH 3/3] added warning message --- src/squidpy/_constants/_constants.py | 2 +- src/squidpy/gr/_niche.py | 80 ++++++++++++++++++++-------- 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/src/squidpy/_constants/_constants.py b/src/squidpy/_constants/_constants.py index add11ecc2..403f072ba 100644 --- a/src/squidpy/_constants/_constants.py +++ b/src/squidpy/_constants/_constants.py @@ -129,6 +129,6 @@ class TenxVersions(str, ModeEnum): class NicheDefinitions(ModeEnum): NEIGHBORHOOD = "neighborhood" UTAG = "utag" - CELLCHARTER = "cellcharter_simple" + CELLCHARTER = "cellcharter" SPOT = "spot" BANKSY = "banksy" diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 474eb2a38..99717051b 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -29,7 +29,7 @@ @inject_docs(fla=NicheDefinitions) def calculate_niche( data: AnnData | SpatialData, - flavor: Literal["neighborhood", "utag", "cellcharter_simple"], + flavor: Literal["neighborhood", "utag", "cellcharter"], library_key: str | None = None, table_key: str | None = None, mask: pd.core.series.Series = None, @@ -45,6 +45,7 @@ def calculate_niche( n_components: int | None = None, random_state: int = 42, spatial_connectivities_key: str = "spatial_connectivities", + use_rep: str | None = None, inplace: bool = True, ) -> AnnData: """ @@ -58,7 +59,7 @@ def calculate_niche( Method to use for niche calculation. Available options are: - `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. - `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication). - - `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA instead of scVI for dimensionality reduction. + - `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA for dimensionality reduction. An arbitrary embedding can be used instead of PCA by setting the `use_rep` parameter which will try to find the embedding in `adata.obsm`. %(library_key)s If provided, niches will be calculated separately for each unique value in this column. Each niche will be prefixed with the library identifier. @@ -103,14 +104,23 @@ def calculate_niche( Optional if flavor == `{fla.CELLCHARTER.s!r}`. spatial_connectivities_key Key in `adata.obsp` where spatial connectivities are stored. + use_rep + Key in `adata.obsm` where the embedding is stored. If provided, this embedding will be used instead of PCA for dimensionality reduction. + Optional if flavor == `{fla.CELLCHARTER.s!r}`. inplace If 'True', perform the operation in place. If 'False', return a new AnnData object with the niche labels. """ - if flavor == "cellcharter_simple" and aggregation is None: + if flavor == "cellcharter" and aggregation is None: aggregation = "mean" + if distance is None: + distance = 3 if flavor == "cellcharter" else 1 + + if flavor == "cellcharter" and n_components is None: + n_components = 10 + _validate_niche_args( data, flavor, @@ -127,15 +137,13 @@ def calculate_niche( aggregation, n_components, random_state, + use_rep, inplace, ) if resolutions is None: resolutions = [0.5] - if distance is None: - distance = 3 if flavor == "cellcharter_simple" else 1 - if isinstance(data, SpatialData): orig_adata = data.tables[table_key] adata = orig_adata.copy() @@ -187,7 +195,7 @@ def calculate_niche( mask=lib_mask, groups=groups, n_neighbors=n_neighbors, - resolutions=None if flavor == "cellcharter_simple" else resolutions, + resolutions=None if flavor == "cellcharter" else resolutions, min_niche_size=min_niche_size, scale=scale, abs_nhood=abs_nhood, @@ -225,6 +233,7 @@ def calculate_niche( n_components, random_state, spatial_connectivities_key, + use_rep, ) if not inplace: @@ -258,7 +267,7 @@ def _get_result_columns( library_str = f"_{library_key}" if library_key is not None else "" - if flavor == "cellcharter_simple": + if flavor == "cellcharter": base_column = "cellcharter_niche" if library_key is None: return [base_column] @@ -293,6 +302,7 @@ def _calculate_niches( n_components: int | None, random_state: int, spatial_connectivities_key: str, + use_rep: str | None, ) -> None: """Calculate niches using the specified flavor and parameters.""" if flavor == "neighborhood": @@ -311,7 +321,7 @@ def _calculate_niches( ) elif flavor == "utag": _get_utag_niches(adata, n_neighbors, resolutions, spatial_connectivities_key) - elif flavor == "cellcharter_simple": + elif flavor == "cellcharter": assert isinstance(aggregation, str) # for mypy assert isinstance(n_components, int) # for mypy _get_cellcharter_niches( @@ -321,6 +331,7 @@ def _calculate_niches( n_components, random_state, spatial_connectivities_key, + use_rep, ) @@ -470,6 +481,7 @@ def _get_cellcharter_niches( n_components: int, random_state: int, spatial_connectivities_key: str, + use_rep: str | None = None, ) -> None: """adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py""" @@ -494,11 +506,32 @@ def _get_cellcharter_niches( concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally arr = concatenated_matrix.toarray() # Densify - arr_ad = ad.AnnData(X=arr) - sc.tl.pca(arr_ad) + + if use_rep is not None: + # Use provided embedding from adata.obsm + if use_rep not in adata.obsm: + raise KeyError( + f"Embedding key '{use_rep}' not found in adata.obsm. Available keys: {list(adata.obsm.keys())}" + ) + embedding = adata.obsm[use_rep] + # Ensure embedding has the right number of components + if embedding.shape[1] < n_components: + raise ValueError( + f"Embedding has {embedding.shape[1]} components, but n_components={n_components}. Please provide an embedding with at least {n_components} components." + ) + # Use only the first n_components + embedding = embedding[:, :n_components] + else: + logg.warning( + "CellCharter recommends to use a dimensionality reduced embedding of the data, e.g. a scVI embedding. Since 'use_rep' is not provided, PCA will be used as proxy - performance may be suboptimal." + ) + + arr_ad = ad.AnnData(X=arr) + sc.tl.pca(arr_ad) + embedding = arr_ad.obsm["X_pca"] # cluster concatenated matrix with GMM, each cluster label equals to a niche label - niches = _get_GMM_clusters(arr_ad.obsm["X_pca"], n_components, random_state) + niches = _get_GMM_clusters(embedding, n_components, random_state) adata.obs["cellcharter_niche"] = pd.Categorical(niches) return @@ -667,7 +700,7 @@ def _jensen_shannon_divergence(adata: AnnData, niche_key: str, library_key: str) def _validate_niche_args( data: AnnData | SpatialData, - flavor: Literal["neighborhood", "utag", "cellcharter_simple"], + flavor: Literal["neighborhood", "utag", "cellcharter"], library_key: str | None, table_key: str | None, groups: str | None, @@ -681,6 +714,7 @@ def _validate_niche_args( aggregation: str | None, n_components: int | None, random_state: int, + use_rep: str | None, inplace: bool, ) -> None: """ @@ -697,10 +731,8 @@ def _validate_niche_args( if not isinstance(data, AnnData | SpatialData): raise TypeError(f"'data' must be an AnnData or SpatialData object, got {type(data).__name__}") - if flavor not in ["neighborhood", "utag", "cellcharter_simple"]: - raise ValueError( - f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter_simple'." - ) + if flavor not in ["neighborhood", "utag", "cellcharter"]: + raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter'.") if library_key is not None: if not isinstance(library_key, str): @@ -762,9 +794,9 @@ def _validate_niche_args( "random_state", ], }, - "cellcharter_simple": { - "required": ["distance", "aggregation", "n_components", "random_state"], - "optional": [], + "cellcharter": { + "required": ["distance", "aggregation", "random_state"], + "optional": ["n_components", "use_rep"], "unused": [ "groups", "min_niche_size", @@ -796,6 +828,7 @@ def _validate_niche_args( "aggregation": aggregation, "n_components": n_components, "random_state": random_state, + "use_rep": use_rep, }, flavor_param_specs[flavor], ) @@ -811,7 +844,7 @@ def _validate_niche_args( if distance is not None and isinstance(distance, int) and distance < 1: raise ValueError(f"'distance' must be at least 1, got {distance}") - elif flavor == "cellcharter_simple": + elif flavor == "cellcharter": if distance is not None and not isinstance(distance, int): raise TypeError(f"'distance' must be an integer, got {type(distance).__name__}") if distance is not None and distance < 1: @@ -830,6 +863,9 @@ def _validate_niche_args( if not isinstance(random_state, int): raise TypeError(f"'random_state' must be an integer, got {type(random_state).__name__}") + if use_rep is not None and not isinstance(use_rep, str): + raise TypeError(f"'use_rep' must be a string, got {type(use_rep).__name__}") + # for mypy if resolutions is None: resolutions = [0.0] @@ -845,7 +881,7 @@ def _check_unnecessary_args(flavor: str, param_dict: dict[str, Any], param_specs Parameters ---------- flavor - The flavor being used ('neighborhood', 'utag', or 'cellcharter_simple') + The flavor being used ('neighborhood', 'utag', or 'cellcharter') param_dict Dictionary of parameter names to their values param_specs