Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3164,6 +3164,24 @@ def __getstate__(self):
# re-attached to another.
state.pop("canvas")

serialized_align = {axis: [] for axis in ("x", "y")}
align_groups = getattr(self, "_align_label_groups", None)
if align_groups is not None:
axes = self.axes
axis_lookup = {id(ax): idx for idx, ax in enumerate(axes)}
for axis_name, grouper in align_groups.items():
for group in grouper:
indices = [
axis_lookup[id(ax)]
for ax in group
if ax is not None and id(ax) in axis_lookup
]
if indices:
serialized_align[axis_name].append(indices)

state["_serialized_align_label_groups"] = serialized_align
state.pop("_align_label_groups", None)

# discard any changes to the dpi due to pixel ratio changes
state["_dpi"] = state.get('_original_dpi', state['_dpi'])

Expand All @@ -3179,6 +3197,7 @@ def __getstate__(self):
def __setstate__(self, state):
version = state.pop('__mpl_version__')
restore_to_pylab = state.pop('_restore_to_pylab', False)
serialized_align = state.pop('_serialized_align_label_groups', None)

if version != mpl.__version__:
_api.warn_external(
Expand All @@ -3187,6 +3206,22 @@ def __setstate__(self, state):

self.__dict__ = state

self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper()}
if serialized_align is not None:
axes = self.axes
for axis_name, groups in serialized_align.items():
grouper = self._align_label_groups.get(axis_name)
if grouper is None:
continue
for group in groups:
axes_group = [
axes[idx] for idx in group
if 0 <= idx < len(axes)
]
if not axes_group:
continue
grouper.join(axes_group[0], *axes_group[1:])

# re-initialise some of the unstored state information
FigureCanvasBase(self) # Set self.canvas.

Expand Down
40 changes: 40 additions & 0 deletions lib/matplotlib/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,46 @@ def test_align_labels_stray_axes():
np.testing.assert_allclose(yn[::2], yn[1::2])


def _serialized_alignment_groups(fig, axis):
axes = fig.axes
lookup = {id(ax): idx for idx, ax in enumerate(axes)}
groups = []
for group in fig._align_label_groups[axis]:
indices = [
lookup[id(ax)]
for ax in group
if ax is not None and id(ax) in lookup
]
if indices:
groups.append(tuple(sorted(indices)))
return sorted(groups)


def test_align_labels_pickle_roundtrip():
fig, axs = plt.subplots(2, 1)
axs[0].plot([0, 1])
axs[0].set_ylabel('speed')
axs[0].set_xlabel('time')
axs[1].plot([0, 1])
axs[1].set_ylabel('accel')
axs[1].set_xlabel('time')

fig.align_labels()

serialized = pickle.dumps(fig)
reloaded = pickle.loads(serialized)

original_x = _serialized_alignment_groups(fig, 'x')
original_y = _serialized_alignment_groups(fig, 'y')
reloaded_x = _serialized_alignment_groups(reloaded, 'x')
reloaded_y = _serialized_alignment_groups(reloaded, 'y')

assert original_x
assert original_y
assert original_x == reloaded_x
assert original_y == reloaded_y


def test_figure_label():
# pyplot figure creation, selection, and closing with label/number/instance
plt.close('all')
Expand Down