Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/NCAR/ldcpy
Browse files Browse the repository at this point in the history
  • Loading branch information
allibco committed Sep 23, 2020
2 parents 88194dc + 2882382 commit ba068e0
Showing 1 changed file with 46 additions and 10 deletions.
56 changes: 46 additions & 10 deletions ldcpy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
quantile=None,
calc_ssim=False,
contour_levs=24,
vert_plot=False,
):

self._ds = ds
Expand All @@ -71,6 +72,7 @@ def __init__(
self._quantile = None
self._calc_ssim = calc_ssim
self._contour_levs = contour_levs
self.vert_plot = vert_plot

def verify_plot_parameters(self):
if len(self._sets) < 2 and self._metric_type in [
Expand Down Expand Up @@ -223,7 +225,10 @@ def update_label(event_axes):
return

def spatial_plot(self, da_sets, titles):
nrows = int((da_sets.sets.size + 1) / 2)
if self.vert_plot:
nrows = int((da_sets.sets.size))
else:
nrows = int((da_sets.sets.size + 1) / 2)
if len(da_sets) == 1:
ncols = 1
else:
Expand All @@ -238,7 +243,10 @@ def spatial_plot(self, da_sets, titles):
for i in range(da_sets.sets.size):
cy_datas[i], lon_sets[i] = add_cyclic_point(da_sets[i], coord=da_sets[i]['lon'])

fig = plt.figure(dpi=300, figsize=(9, 2.5 * nrows))
if self.vert_plot:
fig = plt.figure(dpi=300, figsize=(4.5, 2.5 * nrows))
else:
fig = plt.figure(dpi=300, figsize=(9, 2.5 * nrows))

mymap = copy.copy(mpl.cm.get_cmap(f'{self._color}'))
mymap.set_under(color='black')
Expand All @@ -248,9 +256,14 @@ def spatial_plot(self, da_sets, titles):
axs = {}
psets = {}
for i in range(da_sets.sets.size):
axs[i] = plt.subplot(
nrows, ncols, i + 1, projection=ccrs.Robinson(central_longitude=0.0)
)
if self.vert_plot:
axs[i] = plt.subplot(
nrows, 1, i + 1, projection=ccrs.Robinson(central_longitude=0.0)
)
else:
axs[i] = plt.subplot(
nrows, ncols, i + 1, projection=ccrs.Robinson(central_longitude=0.0)
)

axs[i].set_facecolor('#39ff14')

Expand Down Expand Up @@ -292,7 +305,8 @@ def spatial_plot(self, da_sets, titles):
axs[i].set_title(titles[i])

# add colorbar
fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.95)
if self.vert_plot is False:
fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.95)

cbs = []
if not all([np.isnan(cy_datas[i]).all() for i in range(len(cy_datas))]):
Expand Down Expand Up @@ -346,7 +360,12 @@ def hist_plot(self, plot_data, title):
else:
mpl.pyplot.xlabel(f'{self._metric}')
mpl.pyplot.title(f'time-series histogram: {title[0]}')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)
if self.vert_plot:
plt.legend(loc='upper right', borderaxespad=1.0)
plt.rcParams.update({'font.size': 16})
else:
plt.rcParams.update({'font.size': 10})
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)

def periodogram_plot(self, plot_data, title):
plt.figure()
Expand All @@ -361,7 +380,13 @@ def periodogram_plot(self, plot_data, title):
freqs = np.array(range(1, int(dat.size / 2))) / dat.size

mpl.pyplot.plot(freqs, i, label=plot_data[j].sets.data)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)
if self.vert_plot:
plt.legend(loc='upper right', borderaxespad=1.0)
plt.rcParams.update({'font.size': 16})
else:
plt.rcParams.update({'font.size': 10})
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)

mpl.pyplot.title(f'periodogram: {title[0]}')
mpl.pyplot.ylabel('Spectrum')
mpl.pyplot.xlabel('Frequency')
Expand Down Expand Up @@ -415,6 +440,11 @@ def time_series_plot(
mpl.style.use('default')

plt.figure()
if self.vert_plot:
plt.rcParams.update({'font.size': 16})
else:
plt.rcParams.update({'font.size': 10})

for i in range(da_sets.sets.size):
if self._group_by is not None:
plt.plot(
Expand All @@ -438,7 +468,11 @@ def time_series_plot(
mpl.pyplot.plot(c_d_time, da_sets[i], f'C{i}', label=f'{da_sets.sets.data[i]}')
ax = plt.gca()

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)
if self.vert_plot:
plt.legend(loc='upper right', borderaxespad=1.0)
else:
plt.rcParams.update({'font.size': 10})
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)
mpl.pyplot.ylabel(plot_ylabel)
mpl.pyplot.yscale(self._scale)
self._label_offset(ax)
Expand All @@ -455,7 +489,7 @@ def time_series_plot(
]
unique_month_labels = list(dict.fromkeys(month_labels))
plt.gca().set_xticklabels(unique_month_labels)
plt.xticks(rotation=45)
plt.xticks(rotation=90)
# else:
# mpl.pyplot.xticks(
# pd.date_range(
Expand Down Expand Up @@ -510,6 +544,7 @@ def plot(
start=None,
end=None,
calc_ssim=False,
vert_plot=False,
):
"""
Plots the data given an xarray dataset
Expand Down Expand Up @@ -639,6 +674,7 @@ class in ldcpy.plot, for more information about the available metrics see ldcpy.
standardized_err,
quantile,
calc_ssim,
vert_plot=vert_plot,
)

mp.verify_plot_parameters()
Expand Down

0 comments on commit ba068e0

Please sign in to comment.