diff --git a/ldcpy/plot.py b/ldcpy/plot.py index 9bcb5fb7..b9793bfb 100644 --- a/ldcpy/plot.py +++ b/ldcpy/plot.py @@ -45,6 +45,7 @@ def __init__( quantile=None, calc_ssim=False, contour_levs=24, + vert_plot=False, ): self._ds = ds @@ -71,6 +72,7 @@ def __init__( self._quantile = None self._calc_ssim = calc_ssim self._contour_levs = contour_levs + self.vert_plot = vert_plot def verify_plot_parameters(self): if len(self._sets) < 2 and self._metric_type in [ @@ -223,7 +225,10 @@ def update_label(event_axes): return def spatial_plot(self, da_sets, titles): - nrows = int((da_sets.sets.size + 1) / 2) + if self.vert_plot: + nrows = int((da_sets.sets.size)) + else: + nrows = int((da_sets.sets.size + 1) / 2) if len(da_sets) == 1: ncols = 1 else: @@ -238,7 +243,10 @@ def spatial_plot(self, da_sets, titles): for i in range(da_sets.sets.size): cy_datas[i], lon_sets[i] = add_cyclic_point(da_sets[i], coord=da_sets[i]['lon']) - fig = plt.figure(dpi=300, figsize=(9, 2.5 * nrows)) + if self.vert_plot: + fig = plt.figure(dpi=300, figsize=(4.5, 2.5 * nrows)) + else: + fig = plt.figure(dpi=300, figsize=(9, 2.5 * nrows)) mymap = copy.copy(mpl.cm.get_cmap(f'{self._color}')) mymap.set_under(color='black') @@ -248,9 +256,14 @@ def spatial_plot(self, da_sets, titles): axs = {} psets = {} for i in range(da_sets.sets.size): - axs[i] = plt.subplot( - nrows, ncols, i + 1, projection=ccrs.Robinson(central_longitude=0.0) - ) + if self.vert_plot: + axs[i] = plt.subplot( + nrows, 1, i + 1, projection=ccrs.Robinson(central_longitude=0.0) + ) + else: + axs[i] = plt.subplot( + nrows, ncols, i + 1, projection=ccrs.Robinson(central_longitude=0.0) + ) axs[i].set_facecolor('#39ff14') @@ -292,7 +305,8 @@ def spatial_plot(self, da_sets, titles): axs[i].set_title(titles[i]) # add colorbar - fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.95) + if self.vert_plot is False: + fig.subplots_adjust(left=0.1, right=0.9, bottom=0.2, top=0.95) cbs = [] if not all([np.isnan(cy_datas[i]).all() for i in range(len(cy_datas))]): @@ -346,7 +360,12 @@ def hist_plot(self, plot_data, title): else: mpl.pyplot.xlabel(f'{self._metric}') mpl.pyplot.title(f'time-series histogram: {title[0]}') - plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0) + if self.vert_plot: + plt.legend(loc='upper right', borderaxespad=1.0) + plt.rcParams.update({'font.size': 16}) + else: + plt.rcParams.update({'font.size': 10}) + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0) def periodogram_plot(self, plot_data, title): plt.figure() @@ -361,7 +380,13 @@ def periodogram_plot(self, plot_data, title): freqs = np.array(range(1, int(dat.size / 2))) / dat.size mpl.pyplot.plot(freqs, i, label=plot_data[j].sets.data) - plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0) + if self.vert_plot: + plt.legend(loc='upper right', borderaxespad=1.0) + plt.rcParams.update({'font.size': 16}) + else: + plt.rcParams.update({'font.size': 10}) + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0) + mpl.pyplot.title(f'periodogram: {title[0]}') mpl.pyplot.ylabel('Spectrum') mpl.pyplot.xlabel('Frequency') @@ -415,6 +440,11 @@ def time_series_plot( mpl.style.use('default') plt.figure() + if self.vert_plot: + plt.rcParams.update({'font.size': 16}) + else: + plt.rcParams.update({'font.size': 10}) + for i in range(da_sets.sets.size): if self._group_by is not None: plt.plot( @@ -438,7 +468,11 @@ def time_series_plot( mpl.pyplot.plot(c_d_time, da_sets[i], f'C{i}', label=f'{da_sets.sets.data[i]}') ax = plt.gca() - plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0) + if self.vert_plot: + plt.legend(loc='upper right', borderaxespad=1.0) + else: + plt.rcParams.update({'font.size': 10}) + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0) mpl.pyplot.ylabel(plot_ylabel) mpl.pyplot.yscale(self._scale) self._label_offset(ax) @@ -455,7 +489,7 @@ def time_series_plot( ] unique_month_labels = list(dict.fromkeys(month_labels)) plt.gca().set_xticklabels(unique_month_labels) - plt.xticks(rotation=45) + plt.xticks(rotation=90) # else: # mpl.pyplot.xticks( # pd.date_range( @@ -510,6 +544,7 @@ def plot( start=None, end=None, calc_ssim=False, + vert_plot=False, ): """ Plots the data given an xarray dataset @@ -639,6 +674,7 @@ class in ldcpy.plot, for more information about the available metrics see ldcpy. standardized_err, quantile, calc_ssim, + vert_plot=vert_plot, ) mp.verify_plot_parameters()