diff --git a/src/palantir/plot.py b/src/palantir/plot.py index b5c458e5..1b371eae 100644 --- a/src/palantir/plot.py +++ b/src/palantir/plot.py @@ -1155,10 +1155,13 @@ def plot_stats( } scatter_kwargs.update(kwargs) - try: - cmap = matplotlib.colormaps[cmap] if isinstance(cmap, str) else cmap - except KeyError: + if cmap is None: cmap = matplotlib.colormaps["viridis"] + else: + try: + cmap = matplotlib.colormaps[cmap] if isinstance(cmap, str) else cmap + except KeyError: + cmap = matplotlib.colormaps["viridis"] cmap = copy(cmap) cmap.set_bad(na_color) diff --git a/src/palantir/utils.py b/src/palantir/utils.py index ef8feb2c..077926af 100644 --- a/src/palantir/utils.py +++ b/src/palantir/utils.py @@ -454,10 +454,10 @@ def run_diffusion_maps( and returned. """ - if isinstance(data, pd.DataFrame): - data_df = data - else: + if isinstance(data, sc.AnnData): data_df = pd.DataFrame(data.obsm[pca_key], index=data.obs_names) + else: + data_df = data if not isinstance(data_df, pd.DataFrame) and not issparse(data_df): raise ValueError("'data_df' should be a pd.DataFrame or sc.AnnData") diff --git a/tests/plot.py b/tests/plot.py index 6a8805b3..8a669803 100644 --- a/tests/plot.py +++ b/tests/plot.py @@ -550,7 +550,7 @@ def test_plot_stats_optional_parameters(mock_anndata): def test_plot_stats_masking(mock_anndata): # Create a condition here that you want to mask mask_condition = mock_anndata.obs["palantir_pseudotime"] > 0.5 - mock_anndata.obsm["branch_masks"] = mask_condition + mock_anndata.obsm["branch_masks"] = pd.DataFrame({"mock_branch": mask_condition}) fig, ax = plot_stats( mock_anndata, x="palantir_pseudotime",