Skip to content

Commit

Permalink
[common] improved axes argument handling
Browse files Browse the repository at this point in the history
  • Loading branch information
janscience committed Nov 2, 2024
1 parent c2580ab commit ad71d2a
Showing 1 changed file with 54 additions and 12 deletions.
66 changes: 54 additions & 12 deletions src/plottools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,15 @@ def common_xlabels(fig, *axes):
axes = fig.get_axes()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
axs = []
for ax in axes:
if isinstance(ax, np.ndarray):
axs.extend(ax.ravel())
elif isinstance(ax, (tuple, list)):
axs.extend(ax)
else:
axs.append(ax)
axes = axs
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand Down Expand Up @@ -87,8 +94,15 @@ def common_ylabels(fig, *axes):
axes = fig.get_axes()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
axs = []
for ax in axes:
if isinstance(ax, np.ndarray):
axs.extend(ax.ravel())
elif isinstance(ax, (tuple, list)):
axs.extend(ax)
else:
axs.append(ax)
axes = axs
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
# center common ylabel:
minx = np.min(coords[:,0])
Expand Down Expand Up @@ -129,8 +143,15 @@ def common_xticks(fig, *axes):
axes = fig.get_axes()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
axs = []
for ax in axes:
if isinstance(ax, np.ndarray):
axs.extend(ax.ravel())
elif isinstance(ax, (tuple, list)):
axs.extend(ax)
else:
axs.append(ax)
axes = axs
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand Down Expand Up @@ -173,8 +194,15 @@ def common_yticks(fig, *axes):
axes = fig.get_axes()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
axs = []
for ax in axes:
if isinstance(ax, np.ndarray):
axs.extend(ax.ravel())
elif isinstance(ax, (tuple, list)):
axs.extend(ax)
else:
axs.append(ax)
axes = axs
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand Down Expand Up @@ -215,8 +243,15 @@ def common_xspines(fig, *axes):
axes = fig.get_axes()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
axs = []
for ax in axes:
if isinstance(ax, np.ndarray):
axs.extend(ax.ravel())
elif isinstance(ax, (tuple, list)):
axs.extend(ax)
else:
axs.append(ax)
axes = axs
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand Down Expand Up @@ -260,8 +295,15 @@ def common_yspines(fig, *axes):
axes = fig.get_axes()
if len(axes) == 0:
return
if len(axes) == 1 and isinstance(axes[0], np.ndarray):
axes = axes[0].ravel()
axs = []
for ax in axes:
if isinstance(ax, np.ndarray):
axs.extend(ax.ravel())
elif isinstance(ax, (tuple, list)):
axs.extend(ax)
else:
axs.append(ax)
axes = axs
coords = np.array([ax.get_position().get_points().ravel() for ax in axes])
minx = np.min(coords[:,0])
maxx = np.max(coords[:,2])
Expand Down

0 comments on commit ad71d2a

Please sign in to comment.