From 5c319e81d1f2fedb96c232c7c9e4532cce87babd Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Thu, 25 Dec 2025 12:42:05 +0000 Subject: [PATCH] fix(pyplot): respect registered cmap key --- lib/matplotlib/pyplot.py | 9 +++++--- lib/matplotlib/tests/test_pyplot.py | 32 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 58ce4c03fa87..4a919cdd4beb 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -2278,13 +2278,16 @@ def set_cmap(cmap: Colormap | str) -> None: matplotlib.cm.register_cmap matplotlib.cm.get_cmap """ - cmap = get_cmap(cmap) + cmap_obj = get_cmap(cmap) - rc('image', cmap=cmap.name) + if isinstance(cmap, str): + rc('image', cmap=cmap) + else: + rc('image', cmap=cmap_obj.name) im = gci() if im is not None: - im.set_cmap(cmap) + im.set_cmap(cmap_obj) @_copy_docstring_and_deprecators(matplotlib.image.imread) diff --git a/lib/matplotlib/tests/test_pyplot.py b/lib/matplotlib/tests/test_pyplot.py index 16a927e9e154..1575be5a2f29 100644 --- a/lib/matplotlib/tests/test_pyplot.py +++ b/lib/matplotlib/tests/test_pyplot.py @@ -9,6 +9,7 @@ import matplotlib as mpl from matplotlib.testing import subprocess_run_for_testing from matplotlib import pyplot as plt +from matplotlib.colors import LinearSegmentedColormap def test_pyplot_up_to_date(tmpdir): @@ -80,6 +81,37 @@ def test_stackplot_smoke(): plt.stackplot([1, 2, 3], [1, 2, 3]) +def test_set_cmap_uses_registered_key_for_rcparams(): + original_cmap = mpl.rcParams['image.cmap'] + cmap = LinearSegmentedColormap.from_list( + 'some_cmap_name', ['black', 'white'] + ) + mpl.colormaps.register(cmap, name='my_cmap_name', force=True) + try: + plt.set_cmap('my_cmap_name') + assert mpl.rcParams['image.cmap'] == 'my_cmap_name' + plt.imshow(np.arange(4).reshape(2, 2)) + finally: + mpl.rcParams['image.cmap'] = original_cmap + mpl.colormaps.unregister('my_cmap_name') + plt.close('all') + + +def test_set_cmap_colormap_object_rcparams_name_passthrough(): + original_cmap = mpl.rcParams['image.cmap'] + cmap = LinearSegmentedColormap.from_list( + 'unregistered_internal_name', ['black', 'white'] + ) + mpl.colormaps.register(cmap, name='alias_for_object', force=True) + try: + plt.set_cmap(cmap) + assert mpl.rcParams['image.cmap'] == 'unregistered_internal_name' + finally: + mpl.rcParams['image.cmap'] = original_cmap + mpl.colormaps.unregister('alias_for_object') + plt.close('all') + + def test_nrows_error(): with pytest.raises(TypeError): plt.subplot(nrows=1)