From f3c5f5adef6e7b4421484adb9849af4e94b3bf2e Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Mon, 23 Dec 2024 11:23:50 -0800 Subject: [PATCH 01/12] Switch to `mosaic` for planar plotting. - Interface to the `plot_horiz_field` has been simplified slightly. - The mask array now must be the same shape as the input field array. - Added a framework function to convert a cell mask to and edge mask. --- polaris/mpas/__init__.py | 1 + polaris/mpas/mask.py | 35 +++ .../ocean/tasks/baroclinic_channel/init.py | 10 +- .../tasks/baroclinic_channel/rpe/analysis.py | 4 +- polaris/ocean/tasks/baroclinic_channel/viz.py | 12 +- polaris/ocean/tasks/barotropic_gyre/init.py | 4 +- polaris/ocean/tasks/ice_shelf_2d/viz.py | 42 +-- .../ocean/tasks/inertial_gravity_wave/viz.py | 14 +- .../ocean/tasks/manufactured_solution/viz.py | 12 +- polaris/viz/planar.py | 282 ++++-------------- 10 files changed, 148 insertions(+), 268 deletions(-) create mode 100644 polaris/mpas/mask.py diff --git a/polaris/mpas/__init__.py b/polaris/mpas/__init__.py index 1a4ab173a..8e9e32956 100644 --- a/polaris/mpas/__init__.py +++ b/polaris/mpas/__init__.py @@ -1,2 +1,3 @@ from polaris.mpas.area import area_for_field +from polaris.mpas.mask import cell_mask_2_edge_mask from polaris.mpas.time import time_index_from_xtime, time_since_start diff --git a/polaris/mpas/mask.py b/polaris/mpas/mask.py new file mode 100644 index 000000000..c2ff4690d --- /dev/null +++ b/polaris/mpas/mask.py @@ -0,0 +1,35 @@ + +def cell_mask_2_edge_mask(ds_mesh, cell_mask): + """Convert a cell mask to edge mask using mesh connectivity information + + True corresponds to valid cells and False are invalid cells + + Parameters + ---------- + ds_mesh : xarray.Dataset + The MPAS mesh + + cell_mask : xarray.DataArray + The cell mask we want to convert to an edge mask + + + Returns + ------- + edge_mask : xarray.DataArray + The edge mask corresponding to the input cell mask + """ + + # test if any are False + if ~cell_mask.any(): + return ds_mesh.nEdges > -1 + + # zero index the connectivity array + cellsOnEdge = (ds_mesh.cellsOnEdge - 1) + + # using nCells (dim) instead of indexToCellID since it's already 0 indexed + masked_cells = ds_mesh.nCells.where(~cell_mask, drop=True).astype(int) + + # use inverse so True/False convention matches input cell_mask + edge_mask = ~cellsOnEdge.isin(masked_cells).any("TWO") + + return edge_mask diff --git a/polaris/ocean/tasks/baroclinic_channel/init.py b/polaris/ocean/tasks/baroclinic_channel/init.py index 5093857ae..24c97d7f7 100644 --- a/polaris/ocean/tasks/baroclinic_channel/init.py +++ b/polaris/ocean/tasks/baroclinic_channel/init.py @@ -7,6 +7,7 @@ from polaris import Step from polaris.mesh.planar import compute_planar_hex_nx_ny +from polaris.mpas import cell_mask_2_edge_mask from polaris.ocean.vertical import init_vertical_coord from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -163,10 +164,11 @@ def run(self): write_netcdf(ds, 'initial_state.nc') cell_mask = ds.maxLevelCell >= 1 + edge_mask = cell_mask_2_edge_mask(ds, cell_mask) - plot_horiz_field(ds, ds_mesh, 'normalVelocity', + plot_horiz_field(ds_mesh, ds['normalVelocity'], 'initial_normal_velocity.png', cmap='cmo.balance', - show_patch_edges=True, cell_mask=cell_mask) + show_patch_edges=True, field_mask=edge_mask) y_min = ds_mesh.yVertex.min().values y_max = ds_mesh.yVertex.max().values @@ -191,6 +193,6 @@ def run(self): vmin=vmin, vmax=vmax, cmap='cmo.thermal', colorbar_label=r'$^\circ$C', color_start_and_end=True) - plot_horiz_field(ds, ds_mesh, 'temperature', 'initial_temperature.png', + plot_horiz_field(ds_mesh, ds['temperature'], 'initial_temperature.png', vmin=vmin, vmax=vmax, cmap='cmo.thermal', - cell_mask=cell_mask, transect_x=x, transect_y=y) + field_mask=cell_mask, transect_x=x, transect_y=y) diff --git a/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py b/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py index 3d8c3ee98..e99e23124 100644 --- a/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py +++ b/polaris/ocean/tasks/baroclinic_channel/rpe/analysis.py @@ -104,10 +104,10 @@ def run(self): time_index = np.argmin(np.abs(times - time)) cell_mask = ds_init.maxLevelCell >= 1 - plot_horiz_field(ds, ds_mesh, 'temperature', ax=ax, + plot_horiz_field(ds_mesh, ds['temperature'], ax=ax, cmap='cmo.thermal', t_index=time_index, vmin=min_temp, vmax=max_temp, - cmap_title='SST (C)', cell_mask=cell_mask) + cmap_title='SST (C)', field_mask=cell_mask) ax.set_title(f'day {times[time_index]:g}, $\\nu_h=${nu:g}') plt.savefig(output_filename) diff --git a/polaris/ocean/tasks/baroclinic_channel/viz.py b/polaris/ocean/tasks/baroclinic_channel/viz.py index 617253e71..eb79c324a 100644 --- a/polaris/ocean/tasks/baroclinic_channel/viz.py +++ b/polaris/ocean/tasks/baroclinic_channel/viz.py @@ -3,6 +3,7 @@ import xarray as xr from polaris import Step +from polaris.mpas import cell_mask_2_edge_mask from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -43,13 +44,14 @@ def run(self): ds = xr.load_dataset('output.nc') t_index = ds.sizes['Time'] - 1 cell_mask = ds_init.maxLevelCell >= 1 + edge_mask = cell_mask_2_edge_mask(ds_init, cell_mask) max_velocity = np.max(np.abs(ds.normalVelocity.values)) - plot_horiz_field(ds, ds_mesh, 'normalVelocity', + plot_horiz_field(ds_mesh, ds['normalVelocity'], 'final_normalVelocity.png', t_index=t_index, vmin=-max_velocity, vmax=max_velocity, cmap='cmo.balance', show_patch_edges=True, - cell_mask=cell_mask) + field_mask=edge_mask) y_min = ds_mesh.yVertex.min().values y_max = ds_mesh.yVertex.max().values @@ -76,7 +78,7 @@ def run(self): vmin=vmin, vmax=vmax, cmap='cmo.thermal', colorbar_label=r'$^\circ$C', color_start_and_end=True) - plot_horiz_field(ds, ds_mesh, 'temperature', 'final_temperature.png', + plot_horiz_field(ds_mesh, ds['temperature'], 'final_temperature.png', t_index=t_index, vmin=vmin, vmax=vmax, - cmap='cmo.thermal', cell_mask=cell_mask, transect_x=x, - transect_y=y) + cmap='cmo.thermal', field_mask=cell_mask, + transect_x=x, transect_y=y) diff --git a/polaris/ocean/tasks/barotropic_gyre/init.py b/polaris/ocean/tasks/barotropic_gyre/init.py index e36448b14..fa79f7604 100644 --- a/polaris/ocean/tasks/barotropic_gyre/init.py +++ b/polaris/ocean/tasks/barotropic_gyre/init.py @@ -140,7 +140,7 @@ def run(self): cell_mask = ds.maxLevelCell >= 1 - plot_horiz_field(ds_forcing, ds_mesh, 'windStressZonal', + plot_horiz_field(ds_mesh, ds_forcing['windStressZonal'], 'forcing_wind_stress_zonal.png', cmap='cmo.balance', - show_patch_edges=True, cell_mask=cell_mask, + show_patch_edges=True, field_mask=cell_mask, vmin=-0.1, vmax=0.1) diff --git a/polaris/ocean/tasks/ice_shelf_2d/viz.py b/polaris/ocean/tasks/ice_shelf_2d/viz.py index 84d334935..a32ece4a6 100644 --- a/polaris/ocean/tasks/ice_shelf_2d/viz.py +++ b/polaris/ocean/tasks/ice_shelf_2d/viz.py @@ -5,6 +5,7 @@ import xarray as xr from polaris import Step +from polaris.mpas import cell_mask_2_edge_mask from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -118,9 +119,10 @@ def run(self): # Plot water column thickness horizontal ds_init cell_mask = ds_init.maxLevelCell >= 1 - plot_horiz_field(ds_horiz, ds_mesh, 'columnThickness', + edge_mask = cell_mask_2_edge_mask(ds_init, cell_mask) + plot_horiz_field(ds_mesh, ds_horiz['columnThickness'], 'H_horiz_init.png', t_index=None, - cell_mask=cell_mask) + field_mask=cell_mask) time_index = -1 # Plot the final state ds_horiz = self._process_ds(ds, ds_ice, ds_init, @@ -130,27 +132,27 @@ def run(self): vmax_del_ssh = np.max(ds_horiz.delSsh.values) vmax_del_p = np.amax(ds_horiz.delLandIcePressure.values) # Plot water column thickness horizontal - plot_horiz_field(ds_horiz, ds_mesh, 'columnThickness', + plot_horiz_field(ds_mesh, ds_horiz['columnThickness'], f'H_horiz_t{time_index}.png', t_index=None, - cell_mask=cell_mask) - plot_horiz_field(ds_horiz, ds_mesh, 'landIceFreshwaterFlux', + field_mask=cell_mask) + plot_horiz_field(ds_mesh, ds_horiz['landIceFreshwaterFlux'], f'melt_horiz_t{time_index}.png', t_index=None, - cell_mask=cell_mask) + field_mask=cell_mask) if 'wettingVelocityFactor' in ds_horiz.keys(): - plot_horiz_field(ds_horiz, ds_mesh, 'wettingVelocityFactor', + plot_horiz_field(ds_mesh, ds_horiz['wettingVelocityFactor'], f'wet_horiz_t{time_index}.png', t_index=None, - z_index=None, cell_mask=cell_mask, + z_index=None, field_mask=edge_mask, vmin=0, vmax=1, cmap='cmo.ice') # Plot difference in ssh - plot_horiz_field(ds_horiz, ds_mesh, 'delSsh', + plot_horiz_field(ds_mesh, ds_horiz['delSsh'], f'del_ssh_horiz_t{time_index}.png', t_index=None, - cell_mask=cell_mask, + field_mask=cell_mask, vmin=vmin_del_ssh, vmax=vmax_del_ssh) # Plot difference in land ice pressure - plot_horiz_field(ds_horiz, ds_mesh, 'delLandIcePressure', + plot_horiz_field(ds_mesh, ds_horiz['delLandIcePressure'], f'del_land_ice_pressure_horiz_t{time_index}.png', - t_index=None, cell_mask=cell_mask, + t_index=None, field_mask=cell_mask, vmin=-vmax_del_p, vmax=vmax_del_p, cmap='cmo.balance') @@ -163,24 +165,24 @@ def run(self): max_level_cell=ds_init.maxLevelCell - 1, spherical=False) - plot_horiz_field(ds, ds_mesh, 'velocityX', + plot_horiz_field(ds_mesh, ds['velocityX'], f'u_surf_horiz_t{time_index}.png', t_index=time_index, - z_index=0, cell_mask=cell_mask, + z_index=0, field_mask=cell_mask, vmin=-vmax_uv, vmax=vmax_uv, cmap_title=r'm/s', cmap='cmo.balance') - plot_horiz_field(ds, ds_mesh, 'velocityX', + plot_horiz_field(ds_mesh, ds['velocityX'], f'u_bot_horiz_t{time_index}.png', t_index=time_index, - z_index=-1, cell_mask=cell_mask, + z_index=-1, field_mask=cell_mask, vmin=-vmax_uv, vmax=vmax_uv, cmap_title=r'm/s', cmap='cmo.balance') - plot_horiz_field(ds, ds_mesh, 'velocityY', + plot_horiz_field(ds_mesh, ds['velocityY'], f'v_surf_horiz_t{time_index}.png', t_index=time_index, - z_index=0, cell_mask=cell_mask, + z_index=0, field_mask=cell_mask, vmin=-vmax_uv, vmax=vmax_uv, cmap_title=r'm/s', cmap='cmo.balance') - plot_horiz_field(ds, ds_mesh, 'velocityY', + plot_horiz_field(ds_mesh, ds['velocityY'], f'v_bot_horiz_t{time_index}.png', t_index=time_index, - z_index=-1, cell_mask=cell_mask, + z_index=-1, field_mask=cell_mask, vmin=-vmax_uv, vmax=vmax_uv, cmap_title=r'm/s', cmap='cmo.balance') plot_transect(ds_transect, diff --git a/polaris/ocean/tasks/inertial_gravity_wave/viz.py b/polaris/ocean/tasks/inertial_gravity_wave/viz.py index 673d96521..18e6932e0 100644 --- a/polaris/ocean/tasks/inertial_gravity_wave/viz.py +++ b/polaris/ocean/tasks/inertial_gravity_wave/viz.py @@ -162,19 +162,19 @@ def run(self): error_range = np.max(np.abs(ds.ssh_error.values)) cell_mask = ds_init.maxLevelCell >= 1 - patches, patch_mask = plot_horiz_field( - ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance', + descriptor = plot_horiz_field( + ds_mesh, ds['ssh'], ax=axes[i, 0], cmap='cmo.balance', t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0, - cmap_title="SSH (m)", cell_mask=cell_mask) - plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1], + cmap_title="SSH (m)", field_mask=cell_mask) + plot_horiz_field(ds_mesh, ds['ssh_exact'], ax=axes[i, 1], cmap='cmo.balance', vmin=-eta0, vmax=eta0, cmap_title="SSH (m)", - patches=patches, patch_mask=patch_mask) - plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2], + descriptor=descriptor) + plot_horiz_field(ds_mesh, ds['ssh_error'], ax=axes[i, 2], cmap='cmo.balance', cmap_title=r"$\Delta$ SSH (m)", vmin=-error_range, vmax=error_range, - patches=patches, patch_mask=patch_mask) + descriptor=descriptor) axes[0, 0].set_title('Numerical solution') axes[0, 1].set_title('Analytical solution') diff --git a/polaris/ocean/tasks/manufactured_solution/viz.py b/polaris/ocean/tasks/manufactured_solution/viz.py index a87306da9..928dc9ad3 100644 --- a/polaris/ocean/tasks/manufactured_solution/viz.py +++ b/polaris/ocean/tasks/manufactured_solution/viz.py @@ -175,18 +175,18 @@ def run(self): error_range = np.max(np.abs(ds.ssh_error.values)) cell_mask = ds_init.maxLevelCell >= 1 - patches, patch_mask = plot_horiz_field( + descriptor = plot_horiz_field( ds, ds_mesh, 'ssh', ax=axes[i, 0], cmap='cmo.balance', t_index=ds.sizes["Time"] - 1, vmin=-eta0, vmax=eta0, - cmap_title="SSH", cell_mask=cell_mask) - plot_horiz_field(ds, ds_mesh, 'ssh_exact', ax=axes[i, 1], + cmap_title="SSH", field_mask=cell_mask) + plot_horiz_field(ds_mesh, ds['ssh_exact'], ax=axes[i, 1], cmap='cmo.balance', vmin=-eta0, vmax=eta0, cmap_title="SSH", - patches=patches, patch_mask=patch_mask) - plot_horiz_field(ds, ds_mesh, 'ssh_error', ax=axes[i, 2], + descriptor=descriptor) + plot_horiz_field(ds_mesh, ds['ssh_error'], ax=axes[i, 2], cmap='cmo.balance', cmap_title="dSSH", vmin=-error_range, vmax=error_range, - patches=patches, patch_mask=patch_mask) + descriptor=descriptor) axes[0, 0].set_title('Numerical solution') axes[0, 1].set_title('Analytical solution') diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index 28b41de52..86af021b6 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -1,39 +1,35 @@ import os import cmocean # noqa: F401 +import matplotlib import matplotlib.pyplot as plt +import mosaic import numpy as np -from matplotlib.collections import PatchCollection from matplotlib.colors import LogNorm -from matplotlib.patches import Polygon from polaris.viz.style import use_mplstyle -def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 +def plot_horiz_field(ds_mesh, field, out_file_name=None, # noqa: C901 ax=None, title=None, t_index=None, z_index=None, vmin=None, vmax=None, show_patch_edges=False, cmap=None, cmap_set_under=None, cmap_set_over=None, cmap_scale='linear', cmap_title=None, figsize=None, - vert_dim='nVertLevels', cell_mask=None, patches=None, - patch_mask=None, transect_x=None, transect_y=None, - transect_color='black', transect_start='red', - transect_end='green', transect_linewidth=2., - transect_markersize=12.): + vert_dim='nVertLevels', field_mask=None, descriptor=None, + transect_x=None, transect_y=None, transect_color='black', + transect_start='red', transect_end='green', + transect_linewidth=2., transect_markersize=12.): """ Plot a horizontal field from a planar domain using x,y coordinates at a single time and depth slice. Parameters ---------- - ds : xarray.Dataset - A data set containing fieldName - ds_mesh : xarray.Dataset A data set containing horizontal mesh variables - field_name : str - The name of the variable to plot, which must be present in ds + data_array : xarray.DataArray + The data array to plot out_file_name : str, optional The path to which the plot image should be written @@ -83,17 +79,11 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 vert_dim : str, optional Name of the vertical dimension - cell_mask : numpy.ndarray, optional - A ``bool`` mask indicating where cells are valid, used to mask fields - on both cells and edges. Not used if ``patches`` and ``patch_mask`` - are supplied - - patches : list of numpy.ndarray, optional - Patches from a previous call to ``plot_horiz_field()`` + field_mask : xarray.DataArray, optional + A ``bool`` mask indicating where the `data_array` is valid. - patch_mask : numpy.ndarray, optional - A mask of where the field has patches from a previous call to - ``plot_horiz_field()`` + descriptor : mosaic.Descriptor, optional + Descriptor from a previous call to ``plot_horiz_field()`` transect_x : numpy.ndarray or xarray.DataArray, optional The x coordinates of a transect to plot on the @@ -118,22 +108,10 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 Returns ------- - patches : list of numpy.ndarray - Patches to reuse for future plots. Patches for cells can only be - reused for other plots on cells and similarly for edges. - - patch_mask : numpy.ndarray - A mask used to select entries in the field that have patches + descriptor : mosaic.Descriptor + For reuse with future plots. Patches are cached, so the Descriptor only + needs to be created once per mesh file. """ - if field_name not in ds: - raise ValueError( - f'{field_name} must be present in ds before plotting.') - - if patches is not None: - if patch_mask is None: - raise ValueError('You must supply both patches and patch_mask ' - 'from a previous call to plot_horiz_field()') - if (transect_x is None) != (transect_y is None): raise ValueError('You must supply both transect_x and transect_y or ' 'neither') @@ -146,16 +124,14 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if create_fig: if out_file_name is None: - out_file_name = f'{field_name}.png' + out_file_name = f'{field.name}.png' try: os.makedirs(os.path.dirname(out_file_name)) except OSError: pass if title is None: - title = field_name - - field = ds[field_name] + title = field.name if 'Time' in field.dims and t_index is None: t_index = 0 @@ -166,38 +142,31 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if z_index is not None: field = field.isel({vert_dim: z_index}) - if patches is None: - if cell_mask is None: - cell_mask = np.ones_like(field, dtype='bool') - if 'nCells' in field.dims: - patch_mask = cell_mask - patches, patch_mask = _compute_cell_patches(ds_mesh, patch_mask) - elif 'nEdges' in field.dims: - patch_mask = _edge_mask_from_cell_mask(ds_mesh, cell_mask) - patches, patch_mask = _compute_edge_patches(ds_mesh, patch_mask) - else: - raise ValueError('Cannot plot a field without dim nCells or ' - 'nEdges') - local_patches = PatchCollection(patches, alpha=1.) - local_patches.set_array(field[patch_mask]) + if descriptor is None: + descriptor = mosaic.Descriptor(ds_mesh) + + pcolor_kwargs = dict( + cmap=None, edgecolor='face', norm=None, vmin=vmin, vmax=vmax + ) + if cmap is not None: - local_patches.set_cmap(cmap) - if cmap_set_under is not None: - current_cmap = local_patches.get_cmap() - current_cmap.set_under(cmap_set_under) - if cmap_set_over is not None: - current_cmap = local_patches.get_cmap() - current_cmap.set_over(cmap_set_over) + if isinstance(cmap, str): + cmap = matplotlib.colormaps[cmap] + if cmap_set_under is not None: + cmap.set_under(cmap_set_under) + if cmap_set_over is not None: + cmap.set_over(cmap_set_over) + + pcolor_kwargs['cmap'] = cmap if show_patch_edges: - local_patches.set_edgecolor('black') - else: - local_patches.set_edgecolor('face') - local_patches.set_clim(vmin=vmin, vmax=vmax) + pcolor_kwargs['edgecolor'] = 'black' + pcolor_kwargs['linewidth'] = 0.25 if cmap_scale == 'log': - local_patches.set_norm(LogNorm(vmin=max(1e-10, vmin), - vmax=vmax, clip=False)) + pcolor_kwargs['norm'] = LogNorm( + vmin=max(1e-10, vmin), vmax=vmax, clip=False + ) if figsize is None: width = ds_mesh.xCell.max() - ds_mesh.xCell.min() @@ -210,18 +179,35 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 if create_fig: plt.figure(figsize=figsize) ax = plt.subplot(111) - ax.add_collection(local_patches) + + if field_mask is not None: + + if field_mask.shape != field.shape: + raise ValueError(f"The shape of `field_mask`: {field_mask.shape} " + f"does match shape of `field array`: " + f"{field.shape} make sure both arrays are defined" + f" at the same location") + + if np.any(~field_mask): + field = field.where(field_mask) + + collection = mosaic.polypcolor(ax, descriptor, field, **pcolor_kwargs) + ax.set_xlabel('x (km)') ax.set_ylabel('y (km)') ax.set_aspect('equal') ax.autoscale(tight=True) - cbar = plt.colorbar(local_patches, extend='both', shrink=0.7, ax=ax) + # scale ticks to be in kilometers + ax.xaxis.set_major_formatter(lambda x, pos: f'{x / 1e3:g}') + ax.yaxis.set_major_formatter(lambda x, pos: f'{x / 1e3:g}') + + cbar = plt.colorbar(collection, extend='both', shrink=0.7, ax=ax) if cmap_title is not None: cbar.set_label(cmap_title) if transect_x is not None: - transect_x = 1e-3 * transect_x - transect_y = 1e-3 * transect_y + transect_x = transect_x + transect_y = transect_y ax.plot(transect_x, transect_y, color=transect_color, linewidth=transect_linewidth) if transect_start is not None: @@ -235,152 +221,4 @@ def plot_horiz_field(ds, ds_mesh, field_name, out_file_name=None, # noqa: C901 plt.savefig(out_file_name, bbox_inches='tight', pad_inches=0.2) plt.close() - return patches, patch_mask - - -def _edge_mask_from_cell_mask(ds, cell_mask): - cells_on_edge = ds.cellsOnEdge - 1 - valid = cells_on_edge >= 0 - # the edge mask is True if either adjacent cell is valid and its mask is - # True - edge_mask = np.logical_or( - np.logical_and(valid[:, 0], cell_mask[cells_on_edge[:, 0]]), - np.logical_and(valid[:, 1], cell_mask[cells_on_edge[:, 1]])) - return edge_mask - - -def _compute_cell_patches(ds, mask): - patches = [] - num_vertices_on_cell = ds.nEdgesOnCell.values - vertices_on_cell = ds.verticesOnCell.values - 1 - x_cell = ds.xCell.values - y_cell = ds.yCell.values - x_vertex = ds.xVertex.values - y_vertex = ds.yVertex.values - - is_periodic = ds.attrs['is_periodic'].strip() == 'YES' - is_x_periodic = False - is_y_periodic = False - if is_periodic: - x_period = ds.attrs['x_period'] - if x_period > 0.: - is_x_periodic = True - y_period = ds.attrs['y_period'] - if y_period > 0.: - is_y_periodic = True - - for cell_index in range(ds.sizes['nCells']): - if not mask[cell_index]: - continue - num_vertices = num_vertices_on_cell[cell_index] - vertex_indices = vertices_on_cell[cell_index, :num_vertices] - vertices = np.zeros((num_vertices, 2)) - vertices[:, 0] = 1e-3 * x_vertex[vertex_indices] - vertices[:, 1] = 1e-3 * y_vertex[vertex_indices] - - if is_x_periodic: - # Fix cells that span the periodic boundaries - for count, vertex_index in enumerate(vertex_indices): - vertices = _fix_vertices(vertices, - loc_center=x_cell[cell_index] * 1e-3, - index=count, - period=x_period * 1e-3, - period_index=0) - if is_y_periodic: - # Fix cells that span the periodic boundaries - for count, vertex_index in enumerate(vertex_indices): - vertices = _fix_vertices(vertices, - loc_center=y_cell[cell_index] * 1e-3, - index=count, - period=y_period * 1e-3, - period_index=1) - polygon = Polygon(vertices, closed=True) - patches.append(polygon) - - return patches, mask - - -def _compute_edge_patches(ds, mask): - patches = [] - cells_on_edge = ds.cellsOnEdge.values - 1 - vertices_on_edge = ds.verticesOnEdge.values - 1 - x_cell = ds.xCell.values - y_cell = ds.yCell.values - x_edge = ds.xEdge.values - y_edge = ds.yEdge.values - x_vertex = ds.xVertex.values - y_vertex = ds.yVertex.values - boundary_vertex = ds.boundaryVertex.values - - is_periodic = ds.attrs['is_periodic'].strip() == 'YES' - is_x_periodic = False - is_y_periodic = False - if is_periodic: - x_period = ds.attrs['x_period'] - if x_period > 0.: - is_x_periodic = True - y_period = ds.attrs['y_period'] - if y_period > 0.: - is_y_periodic = True - - for edge_index in range(ds.sizes['nEdges']): - if not mask[edge_index]: - continue - cell_indices = cells_on_edge[edge_index] - vertex_indices = vertices_on_edge[edge_index, :] - # Remove edges on boundaries because they are always invalid - if any(boundary_vertex[vertex_indices]): - mask[edge_index] = 0 - continue - vertices = np.zeros((4, 2)) - vertices[0, 0] = 1e-3 * x_vertex[vertex_indices[0]] - vertices[0, 1] = 1e-3 * y_vertex[vertex_indices[0]] - vertices[1, 0] = 1e-3 * x_cell[cell_indices[0]] - vertices[1, 1] = 1e-3 * y_cell[cell_indices[0]] - vertices[2, 0] = 1e-3 * x_vertex[vertex_indices[1]] - vertices[2, 1] = 1e-3 * y_vertex[vertex_indices[1]] - vertices[3, 0] = 1e-3 * x_cell[cell_indices[1]] - vertices[3, 1] = 1e-3 * y_cell[cell_indices[1]] - if is_x_periodic: - # Fix cells that span the periodic boundaries - for count, vertex_index in enumerate(vertex_indices): - new_index = np.where(count == 0, 0, 2) - vertices = _fix_kite(vertices, - loc_center=x_edge[edge_index] * 1e-3, - index=new_index, - period=x_period * 1e-3, - period_index=0) - if is_y_periodic: - # Fix cells that span the periodic boundaries - for count, vertex_index in enumerate(vertex_indices): - new_index = np.where(count == 0, 0, 2) - vertices = _fix_kite(vertices, - loc_center=y_edge[edge_index] * 1e-3, - index=new_index, - period=y_period * 1e-3, - period_index=1) - polygon = Polygon(vertices, closed=True) - patches.append(polygon) - - return patches, mask - - -def _fix_vertices(vertices, loc_center, index, period, period_index): - if vertices[index, period_index] - loc_center > 0.5 * period: - vertices[index, period_index] += -period - elif vertices[index, period_index] - loc_center < -0.5 * period: - vertices[index, period_index] += period - return vertices - - -def _fix_kite(vertices, loc_center, index, period, period_index): - if vertices[index, period_index] - loc_center > 0.5 * period: - vertices[index, period_index] += -period - elif vertices[index, period_index] - loc_center < -0.5 * period: - vertices[index, period_index] += period - # We need to check the cell node of the kite as well - if vertices[index + 1, period_index] - loc_center > 0.5 * period: - vertices[index + 1, period_index] += -period - elif vertices[index + 1, period_index] - loc_center < -0.5 * period: - vertices[index + 1, period_index] += period - return vertices + return descriptor From ebbb9662e7fff947913f2d0fb4b31f2e621b85ab Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Tue, 24 Dec 2024 13:51:30 -0800 Subject: [PATCH 02/12] Swith to `mosaic` for spherical mesh plotting. Also simplified the config options for the colorbar options, which makes the spherical and lat/lon plots have the same options --- .../ocean/tasks/cosine_bell/cosine_bell.cfg | 4 +- .../ocean/tasks/geostrophic/geostrophic.cfg | 16 ++-- .../sphere_transport/sphere_transport.cfg | 16 ++-- polaris/viz/spherical.py | 85 ++++++++----------- 4 files changed, 52 insertions(+), 69 deletions(-) diff --git a/polaris/ocean/tasks/cosine_bell/cosine_bell.cfg b/polaris/ocean/tasks/cosine_bell/cosine_bell.cfg index 06884ce13..281a6d633 100644 --- a/polaris/ocean/tasks/cosine_bell/cosine_bell.cfg +++ b/polaris/ocean/tasks/cosine_bell/cosine_bell.cfg @@ -90,5 +90,5 @@ colormap_name = viridis # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = 0.0, 1.0 +# A dictionary with keywords for the norm +norm_args = {'vmin': 0., 'vmax': 1.} diff --git a/polaris/ocean/tasks/geostrophic/geostrophic.cfg b/polaris/ocean/tasks/geostrophic/geostrophic.cfg index f48a573af..028a78f66 100644 --- a/polaris/ocean/tasks/geostrophic/geostrophic.cfg +++ b/polaris/ocean/tasks/geostrophic/geostrophic.cfg @@ -66,8 +66,8 @@ colormap_name = cmo.deep # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = 1000.0, 3000.0 +# A dictionary with keywords for the norm +norm_args = {'vmin': 1000.0, 'vmax': 3000.0} # colorbar label label = water-column thickness (m) @@ -82,8 +82,8 @@ colormap_name = cmo.delta # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = -40.0, 40.0 +# A dictionary with keywords for the norm +norm_args = {'vmin': -40.0, 'vmax': 40.0} # colorbar label label = velocity (m/s) @@ -98,8 +98,8 @@ colormap_name = cmo.balance # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = -10.0, 10.0 +# A dictionary with keywords for the norm +norm_args = {'vmin': -10.0, 'vmax': 10.0} # colorbar label label = water-column thickness (m) @@ -114,8 +114,8 @@ colormap_name = cmo.balance # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = -0.3, 0.3 +# A dictionary with keywords for the norm +norm_args = {'vmin': -0.3, 'vmax': 0.3} # colorbar label label = velocity (m/s) diff --git a/polaris/ocean/tasks/sphere_transport/sphere_transport.cfg b/polaris/ocean/tasks/sphere_transport/sphere_transport.cfg index 0e1d4d4c7..b5a4da2a9 100644 --- a/polaris/ocean/tasks/sphere_transport/sphere_transport.cfg +++ b/polaris/ocean/tasks/sphere_transport/sphere_transport.cfg @@ -94,8 +94,8 @@ over_color = orange # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = 0., 1. +# A dictionary with keywords for the norm +norm_args = {'vmin': 0., 'vmax': 1.} # options for plotting tracer differences from sphere transport tests @@ -108,8 +108,8 @@ colormap_name = cmo.balance # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = -0.25, 0.25 +# A dictionary with keywords for the norm +norm_args = {'vmin': -0.25, 'vmax': 0.25} # options for thickness visualization for the sphere transport test case @@ -122,8 +122,8 @@ colormap_name = viridis # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = 99., 101. +# A dictionary with keywords for the norm +norm_args = {'vmin': 99., 'vmax': 101.} # options for plotting tracer differences from sphere transport tests @@ -136,5 +136,5 @@ colormap_name = cmo.balance # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = -0.25, 0.25 +# A dictionary with keywords for the norm +norm_args = {'vmin': -0.25, 'vmax': 0.25} diff --git a/polaris/viz/spherical.py b/polaris/viz/spherical.py index d65d3ed21..1d61d62f4 100644 --- a/polaris/viz/spherical.py +++ b/polaris/viz/spherical.py @@ -2,14 +2,10 @@ import cartopy import cmocean # noqa: F401 -import geoviews.feature as gf -import holoviews as hv -import hvplot.pandas # noqa: F401 -import matplotlib import matplotlib.colors as cols import matplotlib.pyplot as plt -import uxarray as ux -from matplotlib import cm +import mosaic +import xarray as xr from mpl_toolkits.axes_grid1.inset_locator import inset_axes from pyremap.descriptor.utility import interp_extrap_corner @@ -47,9 +43,6 @@ def plot_global_mpas_field(mesh_filename, da, out_filename, config, norm_type The norm: {'linear', 'log'} - colorbar_limits - The minimum and maximum value of the colorbar - title : str, optional The subtitle of the plot @@ -71,59 +64,49 @@ def plot_global_mpas_field(mesh_filename, da, out_filename, config, patch_edge_color : str, optional The color of patch edges (if not the same as the face) """ - matplotlib.use('agg') - uxgrid = ux.open_grid(mesh_filename) - uxda = ux.UxDataArray(da, uxgrid=uxgrid) - - gdf_data = uxda.to_geodataframe() - hv.extension('matplotlib') + use_mplstyle() + transform = cartopy.crs.Geodetic() projection = cartopy.crs.PlateCarree(central_longitude=central_longitude) - colormap = config.get(colormap_section, 'colormap_name') - cmap = cm.get_cmap(colormap) - if config.has_option(colormap_section, 'under_color'): - under_color = config.get(colormap_section, 'under_color') - cmap.set_under(under_color) - if config.has_option(colormap_section, 'over_color'): - over_color = config.get(colormap_section, 'over_color') - cmap.set_over(over_color) + mesh_ds = xr.open_dataset(mesh_filename) + descriptor = mosaic.Descriptor( + mesh_ds, projection=projection, transform=transform, use_latlon=True + ) - norm_type = config.get(colormap_section, 'norm_type') - if norm_type == 'linear': - logz = False - elif norm_type == 'log': - logz = True - else: - raise ValueError(f'Unsupported norm_type: {norm_type}') + fig, ax = plt.subplots(figsize=figsize, + constrained_layout=True, + subplot_kw=dict(projection=projection)) + + colormap, norm, ticks = _setup_colormap(config, colormap_section) + + pcolor_kwargs = dict( + cmap=colormap, norm=norm, zorder=1, edgecolors='face', linewidths=0.2 + ) - colorbar_limits = config.getlist(colormap_section, 'colorbar_limits', - dtype=float) + if patch_edge_color is not None: + pcolor_kwargs['edgecolors'] = patch_edge_color - plot = gdf_data.hvplot.polygons( - c=da.name, cmap=cmap, logz=logz, - clim=tuple(colorbar_limits), - clabel=colorbar_label, - width=1600, height=800, title=title, - xlabel="Longitude", ylabel="Latitude", - projection=projection, - rasterize=True) + gl = ax.gridlines( + color='gray', linestyle=':', zorder=5, draw_labels=True, linewidth=0.5 + ) + gl.right_labels = False + gl.top_labels = False + + pc = mosaic.polypcolor(ax, descriptor, da, **pcolor_kwargs) if plot_land: - plot = plot * gf.land * gf.coastline + ax._add_land_lakes_coastline(ax) - plot.opts(cbar_extend='both', cbar_width=0.03) + cbar = fig.colorbar( + pc, ax=ax, label=colorbar_label, extend='both', shrink=0.6 + ) + + if ticks is not None: + cbar.set_ticks(ticks) + cbar.set_ticklabels([f'{tick}' for tick in ticks]) - if patch_edge_color is not None: - gdf_grid = uxgrid.to_geodataframe() - # color parameter seems to be ignored -- always plots blue - edge_plot = gdf_grid.hvplot.paths( - linewidth=0.2, color=patch_edge_color, projection=projection) - plot = plot * edge_plot - - fig = hv.render(plot) - fig.set_size_inches(figsize) fig.savefig(out_filename, dpi=dpi, bbox_inches='tight', pad_inches=0.1) From b46ec7ec0fa1aaca45d0a09a071113045ac21dd5 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Wed, 22 Jan 2025 09:51:09 -0800 Subject: [PATCH 03/12] Update dependencies for `mosaic` and bump alpha version. Dependencies for `xarray` based visualization have been removed --- deploy/conda-dev-spec.template | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/deploy/conda-dev-spec.template b/deploy/conda-dev-spec.template index def310e45..010337739 100644 --- a/deploy/conda-dev-spec.template +++ b/deploy/conda-dev-spec.template @@ -1,6 +1,5 @@ # Base python>=3.9,<3.13 -antimeridian cartopy cartopy_offlinedata cmocean @@ -9,9 +8,6 @@ dask <2025.1.0 esmf={{ esmf }}={{ mpi_prefix }}_* ffmpeg geometric_features={{ geometric_features }} -geoviews -holoviews -hvplot importlib_resources ipython jupyter @@ -22,6 +18,7 @@ mache={{ mache }} matplotlib-base>=3.9.0 metis={{ metis }} moab={{ moab }}=*_tempest_* +mosaic>=1.0.0,<2.0.0 mpas_tools={{ mpas_tools }} nco netcdf4=*=nompi_* @@ -37,8 +34,6 @@ ruamel.yaml requests scipy>=1.8.0 shapely>=2.0,<3.0 -spatialpandas -uxarray <2025.01.0 xarray # Static typing From 049035f6696c1b5085d52deee6743bb8352362c5 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Wed, 22 Jan 2025 14:30:05 -0800 Subject: [PATCH 04/12] Use mosaic for vertex plotting instead of tricontourf. --- .../ocean/tasks/barotropic_gyre/analysis.py | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/polaris/ocean/tasks/barotropic_gyre/analysis.py b/polaris/ocean/tasks/barotropic_gyre/analysis.py index 430638c77..e757061ad 100644 --- a/polaris/ocean/tasks/barotropic_gyre/analysis.py +++ b/polaris/ocean/tasks/barotropic_gyre/analysis.py @@ -2,6 +2,7 @@ import cmocean # noqa: F401 import matplotlib.pyplot as plt +import mosaic import numpy as np import xarray as xr from mpas_tools.ocean import compute_barotropic_streamfunction @@ -61,26 +62,35 @@ def run(self): boundary_condition=self.boundary_condition) print(f'L2 error norm for {self.boundary_condition} bsf: {error:1.2e}') + descriptor = mosaic.Descriptor(ds_mesh) + use_mplstyle() pad = 20 fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 2)) - x0 = ds_mesh.xEdge.min() - y0 = ds_mesh.yEdge.min() - x_vertex = (ds_mesh['xVertex'] - x0) * 1.e-3 - y_vertex = (ds_mesh['yVertex'] - y0) * 1.e-3 + x0 = float(ds_mesh.xEdge.min()) + y0 = float(ds_mesh.yEdge.min()) + + # offset coordinates + descriptor.vertex_patches[..., 0] -= x0 + descriptor.vertex_patches[..., 1] -= y0 + # convert to km + descriptor.vertex_patches *= 1.e-3 + eta0 = max(np.max(np.abs(field_exact.values)), np.max(np.abs(field_mpas.values))) - s = axes[0].tricontourf(x_vertex, y_vertex, field_mpas, 10, - vmin=-eta0, vmax=eta0, cmap='cmo.balance') + + s = mosaic.polypcolor(axes[0], descriptor, field_mpas, vmin=-eta0, + vmax=eta0, cmap='cmo.balance', antialiased=False) cbar = fig.colorbar(s, ax=axes[0]) cbar.ax.set_title(r'$\psi$') - s = axes[1].tricontourf(x_vertex, y_vertex, field_exact, 10, - vmin=-eta0, vmax=eta0, cmap='cmo.balance') + s = mosaic.polypcolor(axes[1], descriptor, field_exact, vmin=-eta0, + vmax=eta0, cmap='cmo.balance', antialiased=False) cbar = fig.colorbar(s, ax=axes[1]) cbar.ax.set_title(r'$\psi$') eta0 = np.max(np.abs(field_mpas.values - field_exact.values)) - s = axes[2].tricontourf(x_vertex, y_vertex, field_mpas - field_exact, - 10, vmin=-eta0, vmax=eta0, cmap='cmo.balance') + s = mosaic.polypcolor(axes[2], descriptor, field_mpas - field_exact, + vmin=-eta0, vmax=eta0, cmap='cmo.balance', + antialiased=False) cbar = fig.colorbar(s, ax=axes[2]) cbar.ax.set_title(r'$d\psi$') axes[0].set_title('Numerical solution', pad=pad) @@ -90,7 +100,14 @@ def run(self): axes[1].set_xlabel('x (km)') axes[2].set_title('Error (Numerical - Analytical)', pad=pad) axes[2].set_xlabel('x (km)') + + xmin = descriptor.vertex_patches[..., 0].min() + xmax = descriptor.vertex_patches[..., 0].max() + ymin = descriptor.vertex_patches[..., 1].min() + ymax = descriptor.vertex_patches[..., 1].max() for ax in axes: + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) ax.set_aspect('equal') fig.savefig('comparison.png', bbox_inches='tight', pad_inches=0.1) From 19c6c599202dddc36f01fdf44a44138cc41fadaf Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Wed, 22 Jan 2025 16:44:10 -0800 Subject: [PATCH 05/12] Add title to spherical plots if provided. --- polaris/viz/spherical.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/polaris/viz/spherical.py b/polaris/viz/spherical.py index 1d61d62f4..91db58601 100644 --- a/polaris/viz/spherical.py +++ b/polaris/viz/spherical.py @@ -79,6 +79,9 @@ def plot_global_mpas_field(mesh_filename, da, out_filename, config, constrained_layout=True, subplot_kw=dict(projection=projection)) + if title is not None: + fig.suptitle(title, y=0.935) + colormap, norm, ticks = _setup_colormap(config, colormap_section) pcolor_kwargs = dict( From 65c0354ee609147f7fff6aa74618935f72d36770 Mon Sep 17 00:00:00 2001 From: Andrew Nolan <32367657+andrewdnolan@users.noreply.github.com> Date: Thu, 23 Jan 2025 07:32:37 -0700 Subject: [PATCH 06/12] Apply suggestions from code review Co-authored-by: Xylar Asay-Davis --- polaris/mpas/mask.py | 4 ++-- polaris/ocean/tasks/barotropic_gyre/analysis.py | 4 ++-- polaris/viz/planar.py | 6 ++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/polaris/mpas/mask.py b/polaris/mpas/mask.py index c2ff4690d..bcfa96f61 100644 --- a/polaris/mpas/mask.py +++ b/polaris/mpas/mask.py @@ -1,5 +1,5 @@ -def cell_mask_2_edge_mask(ds_mesh, cell_mask): +def cell_mask_to_edge_mask(ds_mesh, cell_mask): """Convert a cell mask to edge mask using mesh connectivity information True corresponds to valid cells and False are invalid cells @@ -24,7 +24,7 @@ def cell_mask_2_edge_mask(ds_mesh, cell_mask): return ds_mesh.nEdges > -1 # zero index the connectivity array - cellsOnEdge = (ds_mesh.cellsOnEdge - 1) + cells_on_edge = (ds_mesh.cellsOnEdge - 1) # using nCells (dim) instead of indexToCellID since it's already 0 indexed masked_cells = ds_mesh.nCells.where(~cell_mask, drop=True).astype(int) diff --git a/polaris/ocean/tasks/barotropic_gyre/analysis.py b/polaris/ocean/tasks/barotropic_gyre/analysis.py index e757061ad..f6dfce21f 100644 --- a/polaris/ocean/tasks/barotropic_gyre/analysis.py +++ b/polaris/ocean/tasks/barotropic_gyre/analysis.py @@ -67,8 +67,8 @@ def run(self): use_mplstyle() pad = 20 fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 2)) - x0 = float(ds_mesh.xEdge.min()) - y0 = float(ds_mesh.yEdge.min()) + x0 = ds_mesh.xEdge.min().values + y0 = ds_mesh.yEdge.min().values # offset coordinates descriptor.vertex_patches[..., 0] -= x0 diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index 86af021b6..04c255184 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -28,7 +28,7 @@ def plot_horiz_field(ds_mesh, field, out_file_name=None, # noqa: C901 ds_mesh : xarray.Dataset A data set containing horizontal mesh variables - data_array : xarray.DataArray + field : xarray.DataArray The data array to plot out_file_name : str, optional @@ -80,7 +80,7 @@ def plot_horiz_field(ds_mesh, field, out_file_name=None, # noqa: C901 Name of the vertical dimension field_mask : xarray.DataArray, optional - A ``bool`` mask indicating where the `data_array` is valid. + A ``bool`` mask indicating where the ``field`` is valid. descriptor : mosaic.Descriptor, optional Descriptor from a previous call to ``plot_horiz_field()`` @@ -206,8 +206,6 @@ def plot_horiz_field(ds_mesh, field, out_file_name=None, # noqa: C901 cbar.set_label(cmap_title) if transect_x is not None: - transect_x = transect_x - transect_y = transect_y ax.plot(transect_x, transect_y, color=transect_color, linewidth=transect_linewidth) if transect_start is not None: From 49f414f1f73dd4793742bf2452a2abef82e80842 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Thu, 23 Jan 2025 08:07:08 -0800 Subject: [PATCH 07/12] More suggestions from code review. - Renamed all instances of cell_mask_to_edge_mask and added to API - Fixed typo with adding cartopy features to spherical plots - Added mosaic to intersphinx mapping --- docs/conf.py | 1 + docs/developers_guide/api.md | 1 + polaris/mpas/__init__.py | 2 +- polaris/mpas/mask.py | 2 +- polaris/ocean/tasks/baroclinic_channel/init.py | 4 ++-- polaris/ocean/tasks/baroclinic_channel/viz.py | 4 ++-- polaris/ocean/tasks/ice_shelf_2d/viz.py | 4 ++-- polaris/viz/spherical.py | 2 +- 8 files changed, 11 insertions(+), 9 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e39eb23ae..01ade9a1c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -71,6 +71,7 @@ ('https://mpas-dev.github.io/geometric_features/main', None), 'matplotlib': ('https://matplotlib.org/stable', None), 'mpas_tools': ('https://mpas-dev.github.io/MPAS-Tools/master', None), + 'mosaic': ('https://docs.e3sm.org/mosaic/index.html', None), 'numpy': ('https://numpy.org/doc/stable', None), 'python': ('https://docs.python.org', None), 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), diff --git a/docs/developers_guide/api.md b/docs/developers_guide/api.md index 05d440112..8354e17ad 100644 --- a/docs/developers_guide/api.md +++ b/docs/developers_guide/api.md @@ -297,6 +297,7 @@ seaice/api :toctree: generated/ area_for_field + cell_mask_to_edge_mask time_index_from_xtime ``` diff --git a/polaris/mpas/__init__.py b/polaris/mpas/__init__.py index 8e9e32956..b6e24871a 100644 --- a/polaris/mpas/__init__.py +++ b/polaris/mpas/__init__.py @@ -1,3 +1,3 @@ from polaris.mpas.area import area_for_field -from polaris.mpas.mask import cell_mask_2_edge_mask +from polaris.mpas.mask import cell_mask_to_edge_mask from polaris.mpas.time import time_index_from_xtime, time_since_start diff --git a/polaris/mpas/mask.py b/polaris/mpas/mask.py index bcfa96f61..744ac6a15 100644 --- a/polaris/mpas/mask.py +++ b/polaris/mpas/mask.py @@ -30,6 +30,6 @@ def cell_mask_to_edge_mask(ds_mesh, cell_mask): masked_cells = ds_mesh.nCells.where(~cell_mask, drop=True).astype(int) # use inverse so True/False convention matches input cell_mask - edge_mask = ~cellsOnEdge.isin(masked_cells).any("TWO") + edge_mask = ~cells_on_edge.isin(masked_cells).any("TWO") return edge_mask diff --git a/polaris/ocean/tasks/baroclinic_channel/init.py b/polaris/ocean/tasks/baroclinic_channel/init.py index 24c97d7f7..2ae5a6fae 100644 --- a/polaris/ocean/tasks/baroclinic_channel/init.py +++ b/polaris/ocean/tasks/baroclinic_channel/init.py @@ -7,7 +7,7 @@ from polaris import Step from polaris.mesh.planar import compute_planar_hex_nx_ny -from polaris.mpas import cell_mask_2_edge_mask +from polaris.mpas import cell_mask_to_edge_mask from polaris.ocean.vertical import init_vertical_coord from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -164,7 +164,7 @@ def run(self): write_netcdf(ds, 'initial_state.nc') cell_mask = ds.maxLevelCell >= 1 - edge_mask = cell_mask_2_edge_mask(ds, cell_mask) + edge_mask = cell_mask_to_edge_mask(ds, cell_mask) plot_horiz_field(ds_mesh, ds['normalVelocity'], 'initial_normal_velocity.png', cmap='cmo.balance', diff --git a/polaris/ocean/tasks/baroclinic_channel/viz.py b/polaris/ocean/tasks/baroclinic_channel/viz.py index eb79c324a..80741cdb1 100644 --- a/polaris/ocean/tasks/baroclinic_channel/viz.py +++ b/polaris/ocean/tasks/baroclinic_channel/viz.py @@ -3,7 +3,7 @@ import xarray as xr from polaris import Step -from polaris.mpas import cell_mask_2_edge_mask +from polaris.mpas import cell_mask_to_edge_mask from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -44,7 +44,7 @@ def run(self): ds = xr.load_dataset('output.nc') t_index = ds.sizes['Time'] - 1 cell_mask = ds_init.maxLevelCell >= 1 - edge_mask = cell_mask_2_edge_mask(ds_init, cell_mask) + edge_mask = cell_mask_to_edge_mask(ds_init, cell_mask) max_velocity = np.max(np.abs(ds.normalVelocity.values)) plot_horiz_field(ds_mesh, ds['normalVelocity'], 'final_normalVelocity.png', diff --git a/polaris/ocean/tasks/ice_shelf_2d/viz.py b/polaris/ocean/tasks/ice_shelf_2d/viz.py index a32ece4a6..69660beb5 100644 --- a/polaris/ocean/tasks/ice_shelf_2d/viz.py +++ b/polaris/ocean/tasks/ice_shelf_2d/viz.py @@ -5,7 +5,7 @@ import xarray as xr from polaris import Step -from polaris.mpas import cell_mask_2_edge_mask +from polaris.mpas import cell_mask_to_edge_mask from polaris.ocean.viz import compute_transect, plot_transect from polaris.viz import plot_horiz_field @@ -119,7 +119,7 @@ def run(self): # Plot water column thickness horizontal ds_init cell_mask = ds_init.maxLevelCell >= 1 - edge_mask = cell_mask_2_edge_mask(ds_init, cell_mask) + edge_mask = cell_mask_to_edge_mask(ds_init, cell_mask) plot_horiz_field(ds_mesh, ds_horiz['columnThickness'], 'H_horiz_init.png', t_index=None, field_mask=cell_mask) diff --git a/polaris/viz/spherical.py b/polaris/viz/spherical.py index 91db58601..d5f7c095d 100644 --- a/polaris/viz/spherical.py +++ b/polaris/viz/spherical.py @@ -100,7 +100,7 @@ def plot_global_mpas_field(mesh_filename, da, out_filename, config, pc = mosaic.polypcolor(ax, descriptor, da, **pcolor_kwargs) if plot_land: - ax._add_land_lakes_coastline(ax) + _add_land_lakes_coastline(ax) cbar = fig.colorbar( pc, ax=ax, label=colorbar_label, extend='both', shrink=0.6 From d9382afadfe5126b78bb8248b3969e17015aad54 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Thu, 23 Jan 2025 12:34:27 -0800 Subject: [PATCH 08/12] Update documentation for mosaic based visualization. --- docs/conf.py | 2 +- .../framework/visualization.md | 143 ++++++++---------- 2 files changed, 60 insertions(+), 85 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 01ade9a1c..dda3cf086 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -71,7 +71,7 @@ ('https://mpas-dev.github.io/geometric_features/main', None), 'matplotlib': ('https://matplotlib.org/stable', None), 'mpas_tools': ('https://mpas-dev.github.io/MPAS-Tools/master', None), - 'mosaic': ('https://docs.e3sm.org/mosaic/index.html', None), + 'mosaic': ('https://docs.e3sm.org/mosaic', None), 'numpy': ('https://numpy.org/doc/stable', None), 'python': ('https://docs.python.org', None), 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), diff --git a/docs/developers_guide/framework/visualization.md b/docs/developers_guide/framework/visualization.md index ce15e0932..c978ef8c9 100644 --- a/docs/developers_guide/framework/visualization.md +++ b/docs/developers_guide/framework/visualization.md @@ -6,9 +6,10 @@ Visualization is an optional, but desirable aspect of tasks. Often, visualization is an optional step of a task but can also be included as part of other steps such as `init` or `analysis`. -While developers can write their own visualization scripts associated with -individual tasks, the following shared visualization routines are -provided in `polaris.viz`: +Horizontal visualization of MPAS fields is enabled through the use of +[`mosaic`](https://docs.e3sm.org/mosaic/). While developers can write their +own visualization scripts associated with individual tasks, the following +shared visualization routines are provided in `polaris.viz`: (dev-visualization-style)= @@ -25,13 +26,15 @@ before creating a `matplotlib` figure. ## horizontal fields from planar meshes -{py:func}`polaris.viz.plot_horiz_field()` produces a patches-style -visualization of x-y fields across a single vertical level at a single time -step. The image file (png) is saved to the directory from which -{py:func}`polaris.viz.plot_horiz_field()` is called. The function -automatically detects whether the field specified by its variable name is -a cell-centered variable or an edge-variable and generates the patches, the -polygons characterized by the field values, accordingly. +{py:func}`polaris.viz.plot_horiz_field()` produces a visualization of +horizontal fields at their native mesh location (i.e. cells, edges, or +vertices) at a single vertical level and a single time step. The image file +(png) is saved to the directory from which +{py:func}`polaris.viz.plot_horiz_field()` is called. +{py:func}`polaris.viz.plot_horiz_field()` is jut a wrapper for +{py:func}`mosaic.polypcolor()`, which automatically detects whether the field +to be plotted is defined at cells, edges, or vertices and generates the patches +(i.e. the polygons characterized by the field values) accordingly. ```{image} images/baroclinic_channel_cell_patches.png :align: center @@ -47,39 +50,38 @@ An example function call that uses the default vertical level (top) is: ```python cell_mask = ds_init.maxLevelCell >= 1 -plot_horiz_field(config, ds, ds_mesh, 'normalVelocity', - 'final_normalVelocity.png', - t_index=t_index, - vmin=-max_velocity, vmax=max_velocity, - cmap='cmo.balance', show_patch_edges=True, - cell_mask=cell_mask) -``` +edge_mask = cell_mask_to_edge_mask(ds_init, cell_mask) -The `cell_mask` argument can be any field indicating which horizontal cells -are valid and which are not. A typical value for ocean plots is as shown -above: whether there are any active cells in the water column. +plot_horiz_field(ds_mesh, ds['normalVelocity'], 'final_normalVelocity.png', + t_index=t_index, vmin=-max_velocity, vmax=max_velocity, + cmap='cmo.balance', show_patch_edges=True, + field_mask=edge_mask) +``` -For increased efficiency, you can store the `patches` and `patch_mask` from -one call to `plot_horiz_field()` and reuse them in subsequent calls. The -`patches` and `patch_mask` are specific to the dimension (`nCell` or `nEdges`) -of the field to plot and the `cell_mask`. So separate `patches` and -`patch_mask` variables should be stored for as needed: +The `field_mask` argument can be any field indicating which horizontal mesh +locations are valid and which are not, but it must be the same shape as data +array being plotted. A typical value for ocean plots is as shown +above: whether there are any active cells in the water column and then the cell +mask is converted to an edges mask using the +{py:func}`polaris.mpas.cell_mask_to_edge_mask()` function. +For increased efficiency, you can store the instance of +{py:class}`mosaic.Descriptor` returned by `plot_horiz_field()` and reuse it in +subsequent calls; assuming you are plotting with the same mesh. ```python cell_mask = ds_init.maxLevelCell >= 1 -cell_patches, cell_patch_mask = plot_horiz_field( - ds=ds, ds_mesh=ds_mesh, field_name='ssh', out_file_name='plots/ssh.png', - vmin=-720, vmax=0, figsize=figsize, cell_mask=cell_mask) +descriptor = plot_horiz_field(ds_mesh, ds['ssh'], 'plots/ssh.png', + vmin=-720, vmax=0, figsize=figsize, + field_mask=cell_mask) -plot_horiz_field(ds=ds, ds_mesh=ds_mesh, field_name='bottomDepth', - out_file_name='plots/bottomDepth.png', vmin=0, vmax=720, - figsize=figsize, patches=cell_patches, - patch_mask=cell_patch_mask) +plot_horiz_field(ds_mesh, ds['bottomDepth'], 'plots/bottomDepth.png', + vmin=0, vmax=720, figsize=figsize, field_mask=cell_mask, + descriptor=descriptor) -edge_patches, edge_patch_mask = plot_horiz_field( - ds=ds, ds_mesh=ds_mesh, field_name='normalVelocity', - out_file_name='plots/normalVelocity.png', t_index=t_index, - vmin=-0.1, vmax=0.1, cmap='cmo.balance', cell_mask=cell_mask) +edge_mask = cell_mask_to_edge_mask(ds_mesh, cell_mask) +plot_horiz_field(ds_mesh, ds['normalVelocity'], 'plots/normalVelocity.png', + t_index=t_index, vmin=-0.1, vmax=0.1, cmap='cmo.balance', + field_mask=edge_mask, descriptor=descriptor) ... ``` @@ -91,7 +93,14 @@ edge_patches, edge_patch_mask = plot_horiz_field( ### plotting from spherical MPAS meshes You can use {py:func}`polaris.viz.plot_global_mpas_field()` to plot a field on -a spherical MPAS mesh. +a spherical MPAS mesh. Like the planar visualization function, this is also +just a wrapper to {py:func}`mosaic.polypcolor()`. Thanks to `mosaic` variables +defined at cells, edges, and vertices are all support as well as meshes with +culled land boundaries are also supported. While `mosaic` +[supports](https://docs.e3sm.org/mosaic/user_guide/wrapping.html) a variety +of map projection for spherical meshes, +{py:func}`polaris.viz.plot_global_mpas_field()` currently only supports +[`cartopy.crs.PlateCarree`](https://scitools.org.uk/cartopy/docs/latest/reference/projections.html#cartopy.crs.PlateCarree). ```{image} images/cosine_bell_final_mpas.png :align: center @@ -127,7 +136,7 @@ The `central_longitude` defaults to `0.0` and can be set to another value (typically 180 degrees) for visualizing quantities that would otherwise be divided across the antimeridian. -The `colormap_section` of the config file must contain config options for +The `_viz` section of the config file must contain config options for specifying the colormap: ```cfg @@ -141,26 +150,25 @@ colormap_name = viridis # the type of norm used in the colormap norm_type = linear -# colorbar limits -colorbar_limits = 0.0, 1.0 +# A dictionary with keywords for the norm +norm_args = {'vmin': 0., 'vmax': 1.} ``` `colormap_name` can be any available matplotlib colormap. For ocean test cases, we recommend importing [cmocean](https://matplotlib.org/cmocean/) so the standard ocean colormaps are available. -The `norm_type` is one of `linear` (a linear colormap) or `log` (a logarithmic -colormap). +The `norm_type` is one of `linear` (a linear colormap), `symlog` (a +[symmetric log](https://matplotlib.org/stable/gallery/images_contours_and_fields/colormap_normalizations_symlognorm.html) +colormap with a central linear region), or `log` (a logarithmic colormap). -The `colorbar_limits` are the lower and upper bound of the colorbar range. +The `norm_args` depend on the `norm_typ` and are the arguments to +{py:class}`matplotlib.colors.Normalize`, {py:class}`matplotlib.colors.SymLogNorm`, +and {py:class}`matplotlib.colors.LogNorm`, respectively. -There are also two optional config options used to set the colors on either end of the colormap: +The config option `colorbar_ticks` (if it is defined) specifies tick locations +along the colorbar. If it is not specified, they are determined automatically. -```cfg -# [optional] colormap set_under and set_over options -under_color = k -over_color = orange -``` ### plotting from lat/lon grids You can use {py:func}`polaris.viz.plot_global_lat_lon_field()` to plot a field @@ -200,38 +208,5 @@ class Viz(Step): title='Tracer at init', plot_land=False) ``` -The `colormap_section` of the config file must contain config options for -specifying the colormap: - -```cfg -# options for visualization for the cosine bell convergence task -[cosine_bell_viz] - -# colormap options -# colormap -colormap_name = viridis - -# the type of norm used in the colormap -norm_type = linear - -# A dictionary with keywords for the norm -norm_args = {'vmin': 0., 'vmax': 1.} - -# We could provide colorbar tick marks but we'll leave the defaults -# colorbar_ticks = np.linspace(0., 1., 9) -``` - -`colormap_name` can be any available matplotlib colormap. For ocean test -cases, we recommend importing [cmocean](https://matplotlib.org/cmocean/) so -the standard ocean colormaps are available. - -The `norm_type` is one of `linear` (a linear colormap), `symlog` (a -[symmetric log](https://matplotlib.org/stable/gallery/images_contours_and_fields/colormap_normalizations_symlognorm.html) -colormap with a central linear region), or `log` (a logarithmic colormap). - -The `norm_args` depend on the `norm_typ` and are the arguments to -{py:class}`matplotlib.colors.Normalize`, {py:class}`matplotlib.colors.SymLogNorm`, -and {py:class}`matplotlib.colors.LogNorm`, respectively. - -The config option `colorbar_ticks` (if it is defined) specifies tick locations -along the colorbar. If it is not specified, they are determined automatically. +The `_viz` of the config file is the same as what's used by +{py:func}`polaris.viz.plot_global_mpas_field()`. From 11c9a5d8a4f9215a400f19b17f775236e4ea66fb Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Thu, 23 Jan 2025 14:19:18 -0800 Subject: [PATCH 09/12] Use discrete colorbar for plotting barotropic stream function --- polaris/ocean/tasks/barotropic_gyre/analysis.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/polaris/ocean/tasks/barotropic_gyre/analysis.py b/polaris/ocean/tasks/barotropic_gyre/analysis.py index f6dfce21f..21a07a0d1 100644 --- a/polaris/ocean/tasks/barotropic_gyre/analysis.py +++ b/polaris/ocean/tasks/barotropic_gyre/analysis.py @@ -5,6 +5,7 @@ import mosaic import numpy as np import xarray as xr +from matplotlib import colors as mcolors from mpas_tools.ocean import compute_barotropic_streamfunction from polaris.mpas import area_for_field @@ -79,18 +80,22 @@ def run(self): eta0 = max(np.max(np.abs(field_exact.values)), np.max(np.abs(field_mpas.values))) - s = mosaic.polypcolor(axes[0], descriptor, field_mpas, vmin=-eta0, - vmax=eta0, cmap='cmo.balance', antialiased=False) + bounds = np.linspace(-eta0, eta0, 21) + norm = mcolors.BoundaryNorm(bounds, cmocean.cm.amp.N) + s = mosaic.polypcolor(axes[0], descriptor, field_mpas, + cmap='cmo.balance', norm=norm, antialiased=False) cbar = fig.colorbar(s, ax=axes[0]) cbar.ax.set_title(r'$\psi$') - s = mosaic.polypcolor(axes[1], descriptor, field_exact, vmin=-eta0, - vmax=eta0, cmap='cmo.balance', antialiased=False) + s = mosaic.polypcolor(axes[1], descriptor, field_exact, + cmap='cmo.balance', norm=norm, antialiased=False) cbar = fig.colorbar(s, ax=axes[1]) cbar.ax.set_title(r'$\psi$') + eta0 = np.max(np.abs(field_mpas.values - field_exact.values)) + bounds = np.linspace(-eta0, eta0, 11) + norm = mcolors.BoundaryNorm(bounds, cmocean.cm.balance.N) s = mosaic.polypcolor(axes[2], descriptor, field_mpas - field_exact, - vmin=-eta0, vmax=eta0, cmap='cmo.balance', - antialiased=False) + cmap='cmo.balance', norm=norm, antialiased=False) cbar = fig.colorbar(s, ax=axes[2]) cbar.ax.set_title(r'$d\psi$') axes[0].set_title('Numerical solution', pad=pad) From 394fa4223f7c7a6a009e1a134976e4e96111ed75 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Wed, 5 Feb 2025 07:44:15 -0800 Subject: [PATCH 10/12] Bump mosaic version and let mosaic set axis limits for periodic axes --- deploy/conda-dev-spec.template | 2 +- polaris/viz/planar.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/deploy/conda-dev-spec.template b/deploy/conda-dev-spec.template index 010337739..a2ee1f391 100644 --- a/deploy/conda-dev-spec.template +++ b/deploy/conda-dev-spec.template @@ -18,7 +18,7 @@ mache={{ mache }} matplotlib-base>=3.9.0 metis={{ metis }} moab={{ moab }}=*_tempest_* -mosaic>=1.0.0,<2.0.0 +mosaic>=1.1.0,<2.0.0 mpas_tools={{ mpas_tools }} nco netcdf4=*=nompi_* diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index 04c255184..efb5a7fe2 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -196,7 +196,15 @@ def plot_horiz_field(ds_mesh, field, out_file_name=None, # noqa: C901 ax.set_xlabel('x (km)') ax.set_ylabel('y (km)') ax.set_aspect('equal') - ax.autoscale(tight=True) + + if descriptor.is_periodic: + if descriptor.x_period and not descriptor.y_period: + ax.autoscale(axis='y', tight=True) + elif not descriptor.x_period and descriptor.y_period: + ax.autoscale(axis='x', tight=True) + else: + ax.autoscale(axis='both', tight=True) + # scale ticks to be in kilometers ax.xaxis.set_major_formatter(lambda x, pos: f'{x / 1e3:g}') ax.yaxis.set_major_formatter(lambda x, pos: f'{x / 1e3:g}') From f4ffbb74b10699a1e553c98877d098a124526ea8 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Wed, 5 Feb 2025 07:53:52 -0800 Subject: [PATCH 11/12] Bump alpha version following rebase --- polaris/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/version.py b/polaris/version.py index 83da76a51..0922efb42 100644 --- a/polaris/version.py +++ b/polaris/version.py @@ -1 +1 @@ -__version__ = '0.5.0-alpha.2' +__version__ = '0.5.0-alpha.3' From 9684bea4540b33c781033090c1bbfd3d930f2404 Mon Sep 17 00:00:00 2001 From: Andrew Nolan Date: Thu, 6 Feb 2025 10:22:01 -0800 Subject: [PATCH 12/12] Undo tight axis limits for periodic plots (for now). --- polaris/viz/planar.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/polaris/viz/planar.py b/polaris/viz/planar.py index efb5a7fe2..29d386112 100644 --- a/polaris/viz/planar.py +++ b/polaris/viz/planar.py @@ -196,14 +196,17 @@ def plot_horiz_field(ds_mesh, field, out_file_name=None, # noqa: C901 ax.set_xlabel('x (km)') ax.set_ylabel('y (km)') ax.set_aspect('equal') - - if descriptor.is_periodic: - if descriptor.x_period and not descriptor.y_period: - ax.autoscale(axis='y', tight=True) - elif not descriptor.x_period and descriptor.y_period: - ax.autoscale(axis='x', tight=True) - else: - ax.autoscale(axis='both', tight=True) + ax.autoscale(tight=True) + + # uncomment below once mosaic can mirror patches across periodic axis: + # -------------------------------------------------------------------- + # if descriptor.is_periodic: + # if descriptor.x_period and not descriptor.y_period: + # ax.autoscale(axis='y', tight=True) + # elif not descriptor.x_period and descriptor.y_period: + # ax.autoscale(axis='x', tight=True) + # else: + # ax.autoscale(axis='both', tight=True) # scale ticks to be in kilometers ax.xaxis.set_major_formatter(lambda x, pos: f'{x / 1e3:g}')