diff --git a/src/plottools/common.py b/src/plottools/common.py index 9a6a736..5217e17 100644 --- a/src/plottools/common.py +++ b/src/plottools/common.py @@ -27,7 +27,7 @@ import matplotlib.ticker as ticker -def common_xlabels(fig, axes=None): +def common_xlabels(fig, *axes): """ Reduce common xlabels. Remove all xlabels except for one that is centered at the bottommost axes. @@ -36,14 +36,16 @@ def common_xlabels(fig, axes=None): ---------- fig: matplotlib figure The figure containing the axes. - axes: None or sequence of matplotlib axes + axes: Sequence of matplotlib axes Axes whose xlabels should be merged. - If None take all axes of the figure. + If not specified, take all axes of the figure. """ - if axes is None: + if len(axes) == 0: axes = fig.get_axes() - if isinstance(axes, np.ndarray): - axes = axes.ravel() + if len(axes) == 0: + return + if len(axes) == 1 and isinstance(axes[0], np.ndarray): + axes = axes[0].ravel() coords = np.array([ax.get_position().get_points().ravel() for ax in axes]) minx = np.min(coords[:,0]) maxx = np.max(coords[:,2]) @@ -68,7 +70,7 @@ def common_xlabels(fig, axes=None): done = True -def common_ylabels(fig, axes=None): +def common_ylabels(fig, *axes): """ Reduce common ylabels. Remove all ylabels except for one that is centered at the leftmost axes. @@ -79,12 +81,14 @@ def common_ylabels(fig, axes=None): The figure containing the axes. axes: None or sequence of matplotlib axes Axes whose ylabels should be merged. - If None take all axes of the figure. + If not specified, take all axes of the figure. """ - if axes is None: + if len(axes) == 0: axes = fig.get_axes() - if isinstance(axes, np.ndarray): - axes = axes.ravel() + if len(axes) == 0: + return + if len(axes) == 1 and isinstance(axes[0], np.ndarray): + axes = axes[0].ravel() coords = np.array([ax.get_position().get_points().ravel() for ax in axes]) # center common ylabel: minx = np.min(coords[:,0]) @@ -108,7 +112,7 @@ def common_ylabels(fig, axes=None): done = True -def common_xticks(fig, axes=None): +def common_xticks(fig, *axes): """ Reduce common xtick labels and xlabels. Keep xtick labels only at the lowest axes and center the common xlabel. @@ -119,12 +123,14 @@ def common_xticks(fig, axes=None): The figure containing the axes. axes: None or sequence of matplotlib axes Axes whose xticks should be combined. - If None take all axes of the figure. + If not specified, take all axes of the figure. """ - if axes is None: + if len(axes) == 0: axes = fig.get_axes() - if isinstance(axes, np.ndarray): - axes = axes.ravel() + if len(axes) == 0: + return + if len(axes) == 1 and isinstance(axes[0], np.ndarray): + axes = axes[0].ravel() coords = np.array([ax.get_position().get_points().ravel() for ax in axes]) minx = np.min(coords[:,0]) maxx = np.max(coords[:,2]) @@ -150,7 +156,7 @@ def common_xticks(fig, axes=None): done = True -def common_yticks(fig, axes=None): +def common_yticks(fig, *axes): """ Reduce common ytick labels and ylabels. Keep ytick labels only at the leftmost axes and center the common ylabel. @@ -161,12 +167,14 @@ def common_yticks(fig, axes=None): The figure containing the axes. axes: None or sequence of matplotlib axes Axes whose yticks should be combined. - If None take all axes of the figure. + If not specified, take all axes of the figure. """ - if axes is None: + if len(axes) == 0: axes = fig.get_axes() - if isinstance(axes, np.ndarray): - axes = axes.ravel() + if len(axes) == 0: + return + if len(axes) == 1 and isinstance(axes[0], np.ndarray): + axes = axes[0].ravel() coords = np.array([ax.get_position().get_points().ravel() for ax in axes]) minx = np.min(coords[:,0]) maxx = np.max(coords[:,2]) @@ -190,7 +198,7 @@ def common_yticks(fig, axes=None): done = True -def common_xspines(fig, axes=None): +def common_xspines(fig, *axes): """ Reduce common x-spines, xtick labels, and xlabels. Keep spine and xtick labels only at the lowest axes and center the common xlabel. @@ -201,12 +209,14 @@ def common_xspines(fig, axes=None): The figure containing the axes. axes: None or sequence of matplotlib axes Axes whose xticks should be combined. - If None take all axes of the figure. + If not specified, take all axes of the figure. """ - if axes is None: + if len(axes) == 0: axes = fig.get_axes() - if isinstance(axes, np.ndarray): - axes = axes.ravel() + if len(axes) == 0: + return + if len(axes) == 1 and isinstance(axes[0], np.ndarray): + axes = axes[0].ravel() coords = np.array([ax.get_position().get_points().ravel() for ax in axes]) minx = np.min(coords[:,0]) maxx = np.max(coords[:,2]) @@ -233,7 +243,7 @@ def common_xspines(fig, axes=None): done = True -def common_yspines(fig, axes=None): +def common_yspines(fig, *axes): """ Reduce common y-spines, ytick labels, and ylabels. Keep spine and ytick labels only at the lowest axes and center the common ylabel. @@ -244,12 +254,14 @@ def common_yspines(fig, axes=None): The figure containing the axes. axes: None or sequence of matplotlib axes Axes whose yticks should be combined. - If None take all axes of the figure. + If not specified, take all axes of the figure. """ - if axes is None: + if len(axes) == 0: axes = fig.get_axes() - if isinstance(axes, np.ndarray): - axes = axes.ravel() + if len(axes) == 0: + return + if len(axes) == 1 and isinstance(axes[0], np.ndarray): + axes = axes[0].ravel() coords = np.array([ax.get_position().get_points().ravel() for ax in axes]) minx = np.min(coords[:,0]) maxx = np.max(coords[:,2])