diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 90b0bf117e16..6a2e7b3e9e8c 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -1307,6 +1307,15 @@ def __clear(self): self._get_lines = _process_plot_var_args(self) self._get_patches_for_fill = _process_plot_var_args(self, 'fill') + legend = getattr(self, "legend_", None) + if legend is not None: + legend.remove() + + for child in list(getattr(self, "_children", ())): + if child is legend: + continue + child.remove() + self._gridOn = mpl.rcParams['axes.grid'] self._children = [] self._mouseover_set = _OrderedSet() diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 6c18ba1a643e..77c717eb58c8 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -947,7 +947,19 @@ def clear(self, keep_observers=False): for ax in tuple(self.axes): # Iterate over the copy. ax.clear() - self.delaxes(ax) # Remove ax from self._axstack. + ax.remove() + + for artist_list in ( + self.artists, + self.lines, + self.patches, + self.texts, + self.images, + self.legends): + for artist in list(artist_list): + artist.remove() + if getattr(artist, "figure", None) is not None: + artist.figure = None self.artists = [] self.lines = [] diff --git a/lib/matplotlib/tests/test_artist.py b/lib/matplotlib/tests/test_artist.py index 7e302784c332..2ca43febe730 100644 --- a/lib/matplotlib/tests/test_artist.py +++ b/lib/matplotlib/tests/test_artist.py @@ -214,6 +214,41 @@ def test_remove(): assert ax.stale +@pytest.mark.backend("Agg") +def test_cla_unsets_artist_axes(): + fig, ax = plt.subplots() + line, = ax.plot([0, 1], [0, 1]) + patch = ax.add_patch(mpatches.Rectangle((0, 0), 1, 1)) + image = ax.imshow(np.arange(4).reshape(2, 2)) + text = ax.text(0.5, 0.5, "hi") + collection = ax.scatter([0], [0]) + legend = ax.legend([line], ["line"]) + + ax.cla() + + for artist in (line, patch, image, text, collection, legend): + assert artist.axes is None + + +@pytest.mark.backend("Agg") +def test_clf_unsets_figure_and_axes_parents(): + fig, ax = plt.subplots() + line, = ax.plot([0, 1], [0, 1]) + image = ax.imshow(np.arange(4).reshape(2, 2)) + fig_text = fig.text(0.5, 0.5, "hi") + fig_legend = fig.legend([line], ["line"]) + colorbar = fig.colorbar(image, ax=ax) + + fig.clf() + + assert fig.axes == [] + assert line.axes is None + assert image.axes is None + assert fig_text.figure is None + assert fig_legend.figure is None + assert colorbar.ax.figure is None + + @image_comparison(["default_edges.png"], remove_text=True, style='default') def test_default_edges(): # Remove this line when this test image is regenerated.