Skip to content

Commit a3061c2

Browse files
marcovarrone, pre-commit-ci[bot], timtreis
authored
Change niche flavor to cellcharter_simple and default distance = 3 (#978)
* Change niche flavor to cellcharter_simple and default distance = 3
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* added warning message

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Treis <tim.treis@helmholtz-munich.de>
Co-authored-by: Tim Treis <tim.treis@stud.uni-heidelberg.de>
1 parent c32e7fd commit a3061c2

File tree

1 file changed

+47
-9
lines changed

1 file changed

+47
-9
lines changed

src/squidpy/gr/_niche.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def calculate_niche(
4545
n_components: int | None = None,
4646
random_state: int = 42,
4747
spatial_connectivities_key: str = "spatial_connectivities",
48+
use_rep: str | None = None,
4849
inplace: bool = True,
4950
) -> AnnData:
5051
"""
@@ -58,7 +59,7 @@ def calculate_niche(
5859
Method to use for niche calculation. Available options are:
5960
- `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile.
6061
- `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication).
61-
- `{fla.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach.
62+
- `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA for dimensionality reduction. An arbitrary embedding can be used instead of PCA by setting the `use_rep` parameter which will try to find the embedding in `adata.obsm`.
6263
%(library_key)s
6364
If provided, niches will be calculated separately for each unique value in this column.
6465
Each niche will be prefixed with the library identifier.
@@ -103,6 +104,9 @@ def calculate_niche(
103104
Optional if flavor == `{fla.CELLCHARTER.s!r}`.
104105
spatial_connectivities_key
105106
Key in `adata.obsp` where spatial connectivities are stored.
107+
use_rep
108+
Key in `adata.obsm` where the embedding is stored. If provided, this embedding will be used instead of PCA for dimensionality reduction.
109+
Optional if flavor == `{fla.CELLCHARTER.s!r}`.
106110
inplace
107111
If 'True', perform the operation in place.
108112
If 'False', return a new AnnData object with the niche labels.
@@ -111,6 +115,12 @@ def calculate_niche(
111115
if flavor == "cellcharter" and aggregation is None:
112116
aggregation = "mean"
113117

118+
if distance is None:
119+
distance = 3 if flavor == "cellcharter" else 1
120+
121+
if flavor == "cellcharter" and n_components is None:
122+
n_components = 10
123+
114124
_validate_niche_args(
115125
data,
116126
flavor,
@@ -127,15 +137,13 @@ def calculate_niche(
127137
aggregation,
128138
n_components,
129139
random_state,
140+
use_rep,
130141
inplace,
131142
)
132143

133144
if resolutions is None:
134145
resolutions = [0.5]
135146

136-
if distance is None:
137-
distance = 1
138-
139147
if isinstance(data, SpatialData):
140148
orig_adata = data.tables[table_key]
141149
adata = orig_adata.copy()
@@ -225,6 +233,7 @@ def calculate_niche(
225233
n_components,
226234
random_state,
227235
spatial_connectivities_key,
236+
use_rep,
228237
)
229238

230239
if not inplace:
@@ -293,6 +302,7 @@ def _calculate_niches(
293302
n_components: int | None,
294303
random_state: int,
295304
spatial_connectivities_key: str,
305+
use_rep: str | None,
296306
) -> None:
297307
"""Calculate niches using the specified flavor and parameters."""
298308
if flavor == "neighborhood":
@@ -321,6 +331,7 @@ def _calculate_niches(
321331
n_components,
322332
random_state,
323333
spatial_connectivities_key,
334+
use_rep,
324335
)
325336

326337

@@ -470,6 +481,7 @@ def _get_cellcharter_niches(
470481
n_components: int,
471482
random_state: int,
472483
spatial_connectivities_key: str,
484+
use_rep: str | None = None,
473485
) -> None:
474486
"""adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py
475487
and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py"""
@@ -494,11 +506,32 @@ def _get_cellcharter_niches(
494506

495507
concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally
496508
arr = concatenated_matrix.toarray() # Densify
497-
arr_ad = ad.AnnData(X=arr)
498-
sc.tl.pca(arr_ad)
509+
510+
if use_rep is not None:
511+
# Use provided embedding from adata.obsm
512+
if use_rep not in adata.obsm:
513+
raise KeyError(
514+
f"Embedding key '{use_rep}' not found in adata.obsm. Available keys: {list(adata.obsm.keys())}"
515+
)
516+
embedding = adata.obsm[use_rep]
517+
# Ensure embedding has the right number of components
518+
if embedding.shape[1] < n_components:
519+
raise ValueError(
520+
f"Embedding has {embedding.shape[1]} components, but n_components={n_components}. Please provide an embedding with at least {n_components} components."
521+
)
522+
# Use only the first n_components
523+
embedding = embedding[:, :n_components]
524+
else:
525+
logg.warning(
526+
"CellCharter recommends to use a dimensionality reduced embedding of the data, e.g. a scVI embedding. Since 'use_rep' is not provided, PCA will be used as proxy - performance may be suboptimal."
527+
)
528+
529+
arr_ad = ad.AnnData(X=arr)
530+
sc.tl.pca(arr_ad)
531+
embedding = arr_ad.obsm["X_pca"]
499532

500533
# cluster concatenated matrix with GMM, each cluster label equals to a niche label
501-
niches = _get_GMM_clusters(arr_ad.obsm["X_pca"], n_components, random_state)
534+
niches = _get_GMM_clusters(embedding, n_components, random_state)
502535

503536
adata.obs["cellcharter_niche"] = pd.Categorical(niches)
504537
return
@@ -681,6 +714,7 @@ def _validate_niche_args(
681714
aggregation: str | None,
682715
n_components: int | None,
683716
random_state: int,
717+
use_rep: str | None,
684718
inplace: bool,
685719
) -> None:
686720
"""
@@ -761,8 +795,8 @@ def _validate_niche_args(
761795
],
762796
},
763797
"cellcharter": {
764-
"required": ["distance", "aggregation", "n_components", "random_state"],
765-
"optional": [],
798+
"required": ["distance", "aggregation", "random_state"],
799+
"optional": ["n_components", "use_rep"],
766800
"unused": [
767801
"groups",
768802
"min_niche_size",
@@ -794,6 +828,7 @@ def _validate_niche_args(
794828
"aggregation": aggregation,
795829
"n_components": n_components,
796830
"random_state": random_state,
831+
"use_rep": use_rep,
797832
},
798833
flavor_param_specs[flavor],
799834
)
@@ -828,6 +863,9 @@ def _validate_niche_args(
828863
if not isinstance(random_state, int):
829864
raise TypeError(f"'random_state' must be an integer, got {type(random_state).__name__}")
830865

866+
if use_rep is not None and not isinstance(use_rep, str):
867+
raise TypeError(f"'use_rep' must be a string, got {type(use_rep).__name__}")
868+
831869
# for mypy
832870
if resolutions is None:
833871
resolutions = [0.0]

0 commit comments

Comments
 (0)