From 92287c857a82770a797fc66b43e244fd2708a923 Mon Sep 17 00:00:00 2001 From: Casey Brooks Date: Thu, 25 Dec 2025 12:46:41 +0000 Subject: [PATCH] fix(figure): serialize align label groups for pickle --- lib/matplotlib/figure.py | 35 +++++++++++++++++++++++++ lib/matplotlib/tests/test_figure.py | 40 +++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index b4c38368bfe0..a1adb1ce0cb3 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -3164,6 +3164,24 @@ def __getstate__(self): # re-attached to another. state.pop("canvas") + serialized_align = {axis: [] for axis in ("x", "y")} + align_groups = getattr(self, "_align_label_groups", None) + if align_groups is not None: + axes = self.axes + axis_lookup = {id(ax): idx for idx, ax in enumerate(axes)} + for axis_name, grouper in align_groups.items(): + for group in grouper: + indices = [ + axis_lookup[id(ax)] + for ax in group + if ax is not None and id(ax) in axis_lookup + ] + if indices: + serialized_align[axis_name].append(indices) + + state["_serialized_align_label_groups"] = serialized_align + state.pop("_align_label_groups", None) + # discard any changes to the dpi due to pixel ratio changes state["_dpi"] = state.get('_original_dpi', state['_dpi']) @@ -3179,6 +3197,7 @@ def __getstate__(self): def __setstate__(self, state): version = state.pop('__mpl_version__') restore_to_pylab = state.pop('_restore_to_pylab', False) + serialized_align = state.pop('_serialized_align_label_groups', None) if version != mpl.__version__: _api.warn_external( @@ -3187,6 +3206,22 @@ def __setstate__(self, state): self.__dict__ = state + self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper()} + if serialized_align is not None: + axes = self.axes + for axis_name, groups in serialized_align.items(): + grouper = self._align_label_groups.get(axis_name) + if grouper is None: + continue + for group in groups: + axes_group = [ + axes[idx] for idx in group + if 0 <= idx < len(axes) + ] + if not axes_group: + continue + grouper.join(axes_group[0], *axes_group[1:]) + # re-initialise some of the unstored state information FigureCanvasBase(self) # Set self.canvas. diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 80d80f969163..757844c3de31 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -101,6 +101,46 @@ def test_align_labels_stray_axes(): np.testing.assert_allclose(yn[::2], yn[1::2]) +def _serialized_alignment_groups(fig, axis): + axes = fig.axes + lookup = {id(ax): idx for idx, ax in enumerate(axes)} + groups = [] + for group in fig._align_label_groups[axis]: + indices = [ + lookup[id(ax)] + for ax in group + if ax is not None and id(ax) in lookup + ] + if indices: + groups.append(tuple(sorted(indices))) + return sorted(groups) + + +def test_align_labels_pickle_roundtrip(): + fig, axs = plt.subplots(2, 1) + axs[0].plot([0, 1]) + axs[0].set_ylabel('speed') + axs[0].set_xlabel('time') + axs[1].plot([0, 1]) + axs[1].set_ylabel('accel') + axs[1].set_xlabel('time') + + fig.align_labels() + + serialized = pickle.dumps(fig) + reloaded = pickle.loads(serialized) + + original_x = _serialized_alignment_groups(fig, 'x') + original_y = _serialized_alignment_groups(fig, 'y') + reloaded_x = _serialized_alignment_groups(reloaded, 'x') + reloaded_y = _serialized_alignment_groups(reloaded, 'y') + + assert original_x + assert original_y + assert original_x == reloaded_x + assert original_y == reloaded_y + + def test_figure_label(): # pyplot figure creation, selection, and closing with label/number/instance plt.close('all')