forked from SarahMeurer/ECG-classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_utils.py
175 lines (139 loc) · 5.49 KB
/
plot_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
### ps backend params
#ps.papersize : letter ## auto, letter, legal, ledger, A0-A10, B0-B10
#ps.useafm : False ## use of afm fonts, results in small files
#ps.usedistiller : False ## can be: None, ghostscript or xpdf
## Experimental: may produce smaller files.
## xpdf intended for production of publication quality files,
## but requires ghostscript, xpdf and ps2eps
#ps.distiller.res : 6000 ## dpi
#ps.fonttype : 3 ## Output Type 3 (Type3) or Type 42 (TrueType)
## See https://matplotlib.org/users/customizing.html#the-matplotlibrc-file
## for more details on the paths which are checked for the configuration file.
import pathlib
from matplotlib.figure import Figure
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.font_manager as mfont
import matplotlib.colors as mcolors
# PostScript/EPS output settings (see the matplotlibrc notes above): page size
# follows the figure, no AFM fonts, Ghostscript distiller at 6000 dpi, and
# Type 42 (TrueType) fonts embedded in the output.
mpl.rcParams.update({
    'ps.papersize': 'auto',
    'ps.useafm': False,
    'ps.usedistiller': 'ghostscript',
    'ps.distiller.res': 6000,
    'ps.fonttype': 42,
})
# Easy 10-color list: the names of the Tableau palette colors (e.g. 'tab:blue')
color_list = [color_name for color_name in mcolors.TABLEAU_COLORS]
# Requires a LaTeX distribution to be installed (tested with MiKTeX on Windows 10)
# mpl.rcParams['text.usetex'] = True
# mpl.rcParams['pgf.texsystem'] = "xelatex"
# preamble = "\n".join(
# [r"\usepackage{amsmath}"])
# mpl.rcParams['text.latex.preamble'] = preamble
# mpl.rcParams['pgf.preamble'] = preamble
# figsize presets (width x height):
# 9cm x 4.5cm (paper)
# 10cm x 9cm (square)
def format_figure(
    fig: Figure, figsize='paper', times='Times New Roman', arial='Arial',
    tight_scale='x', custom=None, tight_kws=None
):
    """Restyle a single-axis figure for publication and return it.

    Fonts, figure size, tight axis scaling, scientific-notation tick labels
    and layout padding are all adjusted in place on ``fig``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to format. Only ``fig.axes[0]`` is styled; subplot grids are
        not supported.
    figsize : {'paper', 'square'} or (width, height), optional
        'paper' is 9 cm x 4.5 cm with 8/10/8 pt tick/label/legend fonts.
        'square' is 10 cm x 9 cm with 10/12/10 pt fonts and no forced
        scientific notation. A ``(width, height)`` pair in centimetres uses
        the 'paper' font sizes.
    times : str, optional
        Font name resolved with ``findfont`` for axis labels and the legend.
    arial : str, optional
        Font name resolved with ``findfont`` for tick labels and offset text.
    tight_scale : {'x', 'y', 'both'} or None, optional
        Axis to autoscale tightly to the data; any other value skips this.
    custom : callable, optional
        Hook called as ``custom(fig, ax)`` after tick formatting and before
        the offset-text fonts and layout are applied. Non-callable values are
        ignored.
    tight_kws : dict, optional
        Extra keyword arguments for ``Figure.tight_layout``. They override the
        default ``pad=0.1``.

    Returns
    -------
    matplotlib.figure.Figure
        The same ``fig`` object.

    Raises
    ------
    ValueError
        If ``fig`` has no axes, or ``figsize`` is neither a known preset nor a
        ``(width, height)`` pair.
    """
    import matplotlib.ticker as mticker  # only needed for the formatter check

    if not fig.axes:
        raise ValueError('format_figure() needs a figure with at least one axis')
    # Only the first axis is styled (no support for plt.subplots grids).
    ax = fig.axes[0]

    # Figure size in cm and (tick, label, legend) font sizes in pt.
    bad_figsize = (
        "figsize must be 'paper', 'square' or a (width, height) pair in cm, "
        f"got {figsize!r}"
    )
    is_square = isinstance(figsize, str) and figsize == 'square'
    if isinstance(figsize, str):
        presets = {
            'paper': ((9, 4.5), (8, 10, 8)),
            'square': ((10, 9), (10, 12, 10)),
        }
        if figsize not in presets:
            raise ValueError(bad_figsize)
        size_cm, (tick_pt, label_pt, legend_pt) = presets[figsize]
    else:
        try:
            width_cm, height_cm = figsize
        except (TypeError, ValueError) as err:
            raise ValueError(bad_figsize) from err
        size_cm = (width_cm, height_cm)
        tick_pt, label_pt, legend_pt = 8, 10, 8

    # Resolve each font name to a concrete file once; fname pins that file.
    sans = dict(fname=mfont.findfont(arial), family='sans-serif')
    serif = dict(fname=mfont.findfont(times), family='serif')
    ticks_font = mfont.FontProperties(size=tick_pt, **sans)
    labels_font = mfont.FontProperties(size=label_pt, **serif)
    legend_font = mfont.FontProperties(size=legend_pt, **serif)

    # Convert from cm to inch
    fig.set_size_inches(np.asarray(size_cm, dtype=float) / 2.54)

    # Tight axis scaling
    if tight_scale in ('x', 'y', 'both'):
        ax.autoscale(enable=True, axis=tight_scale, tight=True)

    # Re-apply the current axis label texts with the label font.
    ax.set_xlabel(ax.get_xlabel(), fontproperties=labels_font)
    ax.set_ylabel(ax.get_ylabel(), fontproperties=labels_font)

    # Rebuild the legend with the new font, keeping location and border padding.
    # NOTE(review): other legend settings (ncol, title, explicit handles) are
    # not carried over; ax.legend() re-collects labelled artists.
    leg = ax.get_legend()
    if leg is not None:
        ax.legend(
            prop=legend_font,
            loc=leg._get_loc(),
            borderaxespad=leg.borderaxespad
        )

    # Tick-label fonts (minor tick labels only matter on log axes).
    for label in (*ax.get_xticklabels(), *ax.get_yticklabels()):
        label.set_fontproperties(ticks_font)
    xscale, yscale = ax.get_xscale(), ax.get_yscale()
    if xscale == 'log':
        for label in ax.get_xminorticklabels():
            label.set_fontproperties(ticks_font)
    if yscale == 'log':
        for label in ax.get_yminorticklabels():
            label.set_fontproperties(ticks_font)

    # Scientific notation for values outside 10**-2 .. 10**3 on linear axes
    # (skipped for 'square'). ticklabel_format() raises AttributeError unless
    # the axis uses a ScalarFormatter (e.g. date/categorical axes), so check.
    if not is_square:
        sci_options = dict(
            style='scientific', scilimits=(-2, 3),
            useMathText=True, useOffset=True
        )
        for axis_name, scale, axis in (
            ('x', xscale, ax.xaxis), ('y', yscale, ax.yaxis)
        ):
            if scale == 'linear' and isinstance(
                axis.get_major_formatter(), mticker.ScalarFormatter
            ):
                ax.ticklabel_format(axis=axis_name, **sci_options)

    # User hook for figure-specific tweaks.
    if callable(custom):
        custom(fig, ax)

    # Font for the axis offset text (the scientific-notation multiplier).
    ax.xaxis.offsetText.set_fontproperties(ticks_font)
    ax.yaxis.offsetText.set_fontproperties(ticks_font)

    # Really tight layout (default pad is ~1.0); user kwargs win over pad=0.1.
    layout_kws = {'pad': 0.1}
    if tight_kws:
        layout_kws.update(tight_kws)
    fig.tight_layout(**layout_kws)
    return fig
def save_fig(fig, name, path=None, format=None, dpi=600, close=False, usetex=True, **kwargs):
    """Format ``fig`` with format_figure() and save it in one or more formats.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to format and save.
    name : str
        Base file name, without extension.
    path : str or path-like, optional
        Output directory, created (with parents) if missing. Defaults to
        ``'figures'``.
    format : str or iterable of str, optional
        Output format(s); defaults to ``['pdf', 'png']``. 'pdf', 'eps', 'png'
        and 'pgf' have dedicated handling ('pgf' is only written when
        ``usetex`` is true). Any other name (e.g. 'svg') is passed straight
        to ``Figure.savefig``.
    dpi : int, optional
        Resolution for png output and for any extra format.
    close : bool, optional
        Close the figure with ``plt.close`` after saving.
    usetex : bool, optional
        Save the png through the 'pgf' backend and allow 'pgf' output.
    **kwargs
        Forwarded to format_figure().

    Returns
    -------
    matplotlib.figure.Figure
        The formatted figure (already closed if ``close`` is true).
    """
    # Make sure to save on a folder that exists
    out_dir = pathlib.Path('figures' if path is None else path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Normalize to an ordered, de-duplicated list of lowercase names; a bare
    # string previously only worked through substring matching.
    if format is None:
        formats = ['pdf', 'png']
    elif isinstance(format, str):
        formats = [format]
    else:
        formats = list(format)
    formats = list(dict.fromkeys(str(fmt).lower() for fmt in formats))

    fig = format_figure(fig, **kwargs)
    # NOTE(review): only the png is routed through the pgf backend when usetex
    # is set; pdf/eps use the default backend (pgf for pdf was disabled).
    backend = 'pgf' if usetex else None

    def target(ext):
        return out_dir / f'{name}.{ext}'

    if 'pdf' in formats:
        fig.savefig(target('pdf'), format='pdf', transparent=True)
    if 'eps' in formats:  # no support for pgf backend in eps
        fig.savefig(target('eps'), format='eps', transparent=True)
    if usetex and 'pgf' in formats:
        fig.savefig(target('pgf'), format='pgf', transparent=True)
    if 'png' in formats:
        fig.savefig(target('png'), format='png', transparent=True, dpi=dpi, backend=backend)
    # Any other format matplotlib supports (e.g. 'svg') was silently ignored
    # before; save it with the default backend.
    for ext in formats:
        if ext not in ('pdf', 'eps', 'pgf', 'png'):
            fig.savefig(target(ext), format=ext, transparent=True, dpi=dpi)

    if close:
        plt.close(fig)
    return fig