35
35
import matplotlib .pyplot as plt
36
36
37
37
38
- def set_arrow_style (ax , spines = 'lb' ):
38
+ def set_arrow_style (ax , spines = 'lb' , xpos = 0 , ypos = 0 ):
39
39
"""Turn the axes into arrows through the origin.
40
40
41
41
Note: Call this function *after* you have set the xlabel and ylabel.
@@ -46,6 +46,12 @@ def set_arrow_style(ax, spines='lb'):
46
46
Axes whose style is changed.
47
47
If figure, then apply manipulations on all axes of the figure.
48
48
If list of axes, apply manipulations on each of the given axes.
49
+ spines: str
50
+ String specifying which spines should be turned into arrows ('lrbt').
51
+ xpos: float
52
+ Position of the verical axis ('lr') on the x-axis.
53
+ ypos: float
54
+ Position of the horizontal axis ('bt') on the y-axis.
49
55
"""
50
56
# collect axes:
51
57
if isinstance (ax , (list , tuple , np .ndarray )):
@@ -58,27 +64,51 @@ def set_arrow_style(ax, spines='lb'):
58
64
if not isinstance (axs , (list , tuple )):
59
65
axs = [axs ]
60
66
for ax in axs :
61
- ax .show_spines (spines )
67
+ show_spines = ''
68
+ if ax .spines ['top' ].get_visible ():
69
+ show_spines += 't'
70
+ if ax .spines ['bottom' ].get_visible ():
71
+ show_spines += 'b'
72
+ if ax .spines ['left' ].get_visible ():
73
+ show_spines += 'l'
74
+ if ax .spines ['right' ].get_visible ():
75
+ show_spines += 'r'
62
76
ax .set_spines_outward (spines , 0 )
63
- ax .set_spines_zero (spines , 0 )
77
+ xspines = ''
78
+ yspines = ''
64
79
for s in spines :
65
80
if s in 'bt' :
81
+ xspines += s
82
+ if s == 'b' :
83
+ show_spines = show_spines .replace ('t' , '' )
84
+ elif s == 't' :
85
+ show_spines = show_spines .replace ('b' , '' )
66
86
ax .arrow_spines (s , flush = 0 , extend = 0 )
67
87
elif s in 'lr' :
88
+ yspines += s
89
+ if s == 'l' :
90
+ show_spines = show_spines .replace ('r' , '' )
91
+ elif s == 'r' :
92
+ show_spines = show_spines .replace ('l' , '' )
68
93
ax .arrow_spines (s , flush = 0 , extend = 1 )
69
- label = ax .xaxis .get_label ()
70
- x , y = label .get_position ()
71
- label .set_position ([1 , y ])
72
- label .set_horizontalalignment ('right' )
73
- label = ax .yaxis .get_label ()
74
- x , y = label .get_position ()
75
- label .set_position ([x , 1 ])
76
- label .set_rotation (0 )
77
- label .set_horizontalalignment ('right' )
78
- ax .xaxis .set_tick_params (direction = 'inout' ,
79
- length = 2 * plt .rcParams ['xtick.major.size' ])
80
- ax .yaxis .set_tick_params (direction = 'inout' , labelrotation = 0 ,
81
- length = 2 * plt .rcParams ['ytick.major.size' ])
94
+ ax .show_spines (show_spines )
95
+ if xspines :
96
+ ax .set_spines_zero (xspines , ypos )
97
+ label = ax .xaxis .get_label ()
98
+ x , y = label .get_position ()
99
+ label .set_position ([1 , y ])
100
+ label .set_horizontalalignment ('right' )
101
+ ax .xaxis .set_tick_params (direction = 'inout' ,
102
+ length = 2 * plt .rcParams ['xtick.major.size' ])
103
+ if yspines :
104
+ ax .set_spines_zero (yspines , xpos )
105
+ label = ax .yaxis .get_label ()
106
+ x , y = label .get_position ()
107
+ label .set_position ([x , 1 ])
108
+ label .set_rotation (0 )
109
+ label .set_horizontalalignment ('right' )
110
+ ax .yaxis .set_tick_params (direction = 'inout' , labelrotation = 0 ,
111
+ length = 2 * plt .rcParams ['ytick.major.size' ])
82
112
83
113
84
114
def axes_params (xmargin = None , ymargin = None , zmargin = None , color = None ,
0 commit comments