Skip to content

Commit 0bf4455

Browse files
committed
[axes] improved set_axes_spyle()
1 parent 7ea3fe4 commit 0bf4455

File tree

1 file changed

+46
-16
lines changed

1 file changed

+46
-16
lines changed

src/plottools/axes.py

+46-16
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import matplotlib.pyplot as plt
3636

3737

38-
def set_arrow_style(ax, spines='lb'):
38+
def set_arrow_style(ax, spines='lb', xpos=0, ypos=0):
3939
"""Turn the axes into arrows through the origin.
4040
4141
Note: Call this function *after* you have set the xlabel and ylabel.
@@ -46,6 +46,12 @@ def set_arrow_style(ax, spines='lb'):
4646
Axes whose style is changed.
4747
If figure, then apply manipulations on all axes of the figure.
4848
If list of axes, apply manipulations on each of the given axes.
49+
spines: str
50+
String specifying which spines should be turned into arrows ('lrbt').
51+
xpos: float
52+
Position of the verical axis ('lr') on the x-axis.
53+
ypos: float
54+
Position of the horizontal axis ('bt') on the y-axis.
4955
"""
5056
# collect axes:
5157
if isinstance(ax, (list, tuple, np.ndarray)):
@@ -58,27 +64,51 @@ def set_arrow_style(ax, spines='lb'):
5864
if not isinstance(axs, (list, tuple)):
5965
axs = [axs]
6066
for ax in axs:
61-
ax.show_spines(spines)
67+
show_spines = ''
68+
if ax.spines['top'].get_visible():
69+
show_spines += 't'
70+
if ax.spines['bottom'].get_visible():
71+
show_spines += 'b'
72+
if ax.spines['left'].get_visible():
73+
show_spines += 'l'
74+
if ax.spines['right'].get_visible():
75+
show_spines += 'r'
6276
ax.set_spines_outward(spines, 0)
63-
ax.set_spines_zero(spines, 0)
77+
xspines = ''
78+
yspines = ''
6479
for s in spines:
6580
if s in 'bt':
81+
xspines += s
82+
if s == 'b':
83+
show_spines = show_spines.replace('t', '')
84+
elif s == 't':
85+
show_spines = show_spines.replace('b', '')
6686
ax.arrow_spines(s, flush=0, extend=0)
6787
elif s in 'lr':
88+
yspines += s
89+
if s == 'l':
90+
show_spines = show_spines.replace('r', '')
91+
elif s == 'r':
92+
show_spines = show_spines.replace('l', '')
6893
ax.arrow_spines(s, flush=0, extend=1)
69-
label = ax.xaxis.get_label()
70-
x, y = label.get_position()
71-
label.set_position([1, y])
72-
label.set_horizontalalignment('right')
73-
label = ax.yaxis.get_label()
74-
x, y = label.get_position()
75-
label.set_position([x, 1])
76-
label.set_rotation(0)
77-
label.set_horizontalalignment('right')
78-
ax.xaxis.set_tick_params(direction='inout',
79-
length=2*plt.rcParams['xtick.major.size'])
80-
ax.yaxis.set_tick_params(direction='inout', labelrotation=0,
81-
length=2*plt.rcParams['ytick.major.size'])
94+
ax.show_spines(show_spines)
95+
if xspines:
96+
ax.set_spines_zero(xspines, ypos)
97+
label = ax.xaxis.get_label()
98+
x, y = label.get_position()
99+
label.set_position([1, y])
100+
label.set_horizontalalignment('right')
101+
ax.xaxis.set_tick_params(direction='inout',
102+
length=2*plt.rcParams['xtick.major.size'])
103+
if yspines:
104+
ax.set_spines_zero(yspines, xpos)
105+
label = ax.yaxis.get_label()
106+
x, y = label.get_position()
107+
label.set_position([x, 1])
108+
label.set_rotation(0)
109+
label.set_horizontalalignment('right')
110+
ax.yaxis.set_tick_params(direction='inout', labelrotation=0,
111+
length=2*plt.rcParams['ytick.major.size'])
82112

83113

84114
def axes_params(xmargin=None, ymargin=None, zmargin=None, color=None,

0 commit comments

Comments
 (0)