diff --git a/gpm/visualization/facetgrid.py b/gpm/visualization/facetgrid.py index d9a593a..99a5778 100644 --- a/gpm/visualization/facetgrid.py +++ b/gpm/visualization/facetgrid.py @@ -357,18 +357,33 @@ def map_dataarray( @abstractmethod def _remove_bottom_ticks_and_labels(self, ax): """Method removing axis ticks and labels on the bottom of the subplots.""" - + raise NotImplementedError + @abstractmethod def _remove_left_ticks_and_labels(self, ax): """Method removing axis ticks and labels on the left of the subplots.""" - + raise NotImplementedError + + def map_to_axes(self, func, **kwargs): + """Map a function to each axes.""" + n_rows, n_cols = self.axs.shape + missing_bottom_plots = [not ax.has_data() for ax in self.axs[n_rows - 1]] + idx_bottom_plots = np.where(missing_bottom_plots)[0] + has_missing_bottom_plots = len(idx_bottom_plots) > 0 + for i in range(0, n_rows): + for j in range(0, n_cols): + if has_missing_bottom_plots and i == n_rows and j in idx_bottom_plots: + continue + # Otherwise apply function + func(ax=self.axs[i, j], **kwargs) + def remove_bottom_ticks_and_labels(self): """Remove the bottom ticks and labels from each subplot.""" - self.map(lambda: self._remove_bottom_ticks_and_labels(plt.gca())) - + self.map_to_axes(func=self._remove_bottom_ticks_and_labels) + def remove_left_ticks_and_labels(self): """Remove the left ticks and labels from each subplot.""" - self.map(lambda: self._remove_left_ticks_and_labels(plt.gca())) + self.map_to_axes(func=self._remove_left_ticks_and_labels) def remove_duplicated_axis_labels(self): """Remove axis labels which are not located on the left or bottom of the figure.""" @@ -414,7 +429,16 @@ def add_colorbar(self, **cbar_kwargs) -> None: ) # Add ticklabel if ticklabels is not None: - self.cbar.ax.set_yticklabels(ticklabels) + # Retrieve ticks + ticks = cbar_kwargs.get("ticks", None) + if ticks is None: + ticks = self.cbar.get_ticks() + # Remove existing ticklabels + self.cbar.set_ticklabels([]) + self.cbar.set_ticklabels([], minor=True) + # Add custom ticklabels + self.cbar.set_ticks(ticks, labels=ticklabels) + # self.cbar.ax.set_yticklabels(ticklabels) def remove_title_dimension_prefix(self, row=True, col=True): """Remove the dimension prefix from the subplot labels.""" diff --git a/gpm/visualization/plot.py b/gpm/visualization/plot.py index 4324d08..05a3aad 100644 --- a/gpm/visualization/plot.py +++ b/gpm/visualization/plot.py @@ -545,7 +545,16 @@ def plot_colorbar(p, ax, cbar_kwargs=None): # Add colorbar cbar = plt.colorbar(p, cax=cax, ax=ax, **cbar_kwargs) if ticklabels is not None: - _ = cbar.ax.set_yticklabels(ticklabels) if orientation == "vertical" else cbar.ax.set_xticklabels(ticklabels) + # Retrieve ticks + ticks = cbar_kwargs.get("ticks", None) + if ticks is None: + ticks = cbar.get_ticks() + # Remove existing ticklabels + cbar.set_ticklabels([]) + cbar.set_ticklabels([], minor=True) + # Add custom ticklabels + p.colorbar.set_ticks(ticks, labels=ticklabels) + # _ = cbar.ax.set_yticklabels(ticklabels) if orientation == "vertical" else cbar.ax.set_xticklabels(ticklabels) return cbar