From 79f2952ca8d5961620996837e8e05c4e3d4fb469 Mon Sep 17 00:00:00 2001 From: Anton Steketee <79179784+anton-seaice@users.noreply.github.com> Date: Fri, 26 Apr 2024 08:59:38 +1000 Subject: [PATCH] Fix formatting of cice grid file (#18) * Add anglet variant (for cice6) and change x/y var names to match those uses in cice output --- esmgrids/cice_grid.py | 36 ++++++--- esmgrids/cli.py | 8 +- test/test_cice_grid.py | 168 +++++++++++++++++++++++++---------------- test/test_grids.py | 3 - 4 files changed, 136 insertions(+), 79 deletions(-) diff --git a/esmgrids/cice_grid.py b/esmgrids/cice_grid.py index 6445cf8..6639b2d 100644 --- a/esmgrids/cice_grid.py +++ b/esmgrids/cice_grid.py @@ -31,7 +31,11 @@ def fromfile(cls, h_grid_def, mask_file=None, description="CICE tripolar"): area_t = f.variables["tarea"][:] area_u = f.variables["uarea"][:] - angle_t = np.rad2deg(f.variables["angleT"][:]) + try: + angle_t = np.rad2deg(f.variables["anglet"][:]) + except KeyError: + angle_t = np.rad2deg(f.variables["angleT"][:]) + angle_u = np.rad2deg(f.variables["angle"][:]) if "clon_t" in f.variables: @@ -69,12 +73,12 @@ def _create_2d_nc_var(self, f, name): return f.createVariable( name, "f8", - dimensions=("ny", "nx"), + dimensions=("nj", "ni"), compression="zlib", complevel=1, ) - def write(self, grid_filename, mask_filename, metadata=None): + def write(self, grid_filename, mask_filename, metadata=None, variant=None): """ Write out CICE grid to netcdf @@ -86,16 +90,23 @@ def write(self, grid_filename, mask_filename, metadata=None): The name of the mask file to write metadata: dict Any global or variable metadata attributes to add to the files being written + variant: str + Use variant='cice5-auscom' for access-om2/cice5-auscom builds, otherwise use None """ + if variant is not None and variant != "cice5-auscom": + raise NotImplementedError(f"{variant} not recognised") + # Grid file f = nc.Dataset(grid_filename, "w") # Create dimensions. - f.createDimension("nx", self.num_lon_points) - # nx is the grid_longitude but doesn't have a value other than its index - f.createDimension("ny", self.num_lat_points) - # ny is the grid_latitude but doesn't have a value other than its index + f.createDimension( + "ni", self.num_lon_points + ) # ni is the grid_longitude but doesn't have a value other than its index + f.createDimension( + "nj", self.num_lat_points + ) # nj is the grid_latitude but doesn't have a value other than its index # Make all CICE grid variables. # names are based on https://cfconventions.org/Data/cf-standard-names/current/build/cf-standard-name-table.html @@ -135,7 +146,12 @@ def write(self, grid_filename, mask_filename, metadata=None): angle.standard_name = "angle_of_rotation_from_east_to_x" angle.coordinates = "ulat ulon" angle.grid_mapping = "crs" - angleT = self._create_2d_nc_var(f, "angleT") + + if variant == "cice5-auscom": + angleT = self._create_2d_nc_var(f, "angleT") + elif variant is None: + angleT = self._create_2d_nc_var(f, "anglet") + angleT.units = "radians" angleT.long_name = "Rotation angle of T cells." angleT.standard_name = "angle_of_rotation_from_east_to_x" @@ -185,8 +201,8 @@ def write(self, grid_filename, mask_filename, metadata=None): # Mask file f = nc.Dataset(mask_filename, "w") - f.createDimension("nx", self.num_lon_points) - f.createDimension("ny", self.num_lat_points) + f.createDimension("ni", self.num_lon_points) + f.createDimension("nj", self.num_lat_points) mask = self._create_2d_nc_var(f, "kmt") mask.grid_mapping = "crs" mask.standard_name = "sea_binary_mask" diff --git a/esmgrids/cli.py b/esmgrids/cli.py index 08c853d..15adfc5 100644 --- a/esmgrids/cli.py +++ b/esmgrids/cli.py @@ -15,18 +15,22 @@ def cice_from_mom(): parser.add_argument("--ocean_mask", type=str, help="Input MOM ocean_mask.nc mask file") parser.add_argument("--cice_grid", type=str, default="grid.nc", help="Output CICE grid file") parser.add_argument("--cice_kmt", type=str, default="kmt.nc", help="Output CICE kmt file") + parser.add_argument( + "--cice_variant", type=str, default=None, help="Cice variant, valid options = [None, 'cice5-auscom'] " + ) args = parser.parse_args() ocean_hgrid = os.path.abspath(args.ocean_hgrid) ocean_mask = os.path.abspath(args.ocean_mask) cice_grid = os.path.abspath(args.cice_grid) cice_kmt = os.path.abspath(args.cice_kmt) + cice_variant = args.cice_variant version = safe_version() runcmd = ( f"Created using https://github.com/COSIMA/esmgrids {version}: " f"cice_from_mom --ocean_hgrid={ocean_hgrid} --ocean_mask={ocean_mask} " - f"--cice_grid={cice_grid} --cice_kmt={cice_kmt}" + f"--cice_grid={cice_grid} --cice_kmt={cice_kmt} --cice_variant={cice_variant}" ) provenance_metadata = { "inputfile": ( @@ -37,4 +41,4 @@ def cice_from_mom(): mom = MomGrid.fromfile(ocean_hgrid, mask_file=ocean_mask) cice = CiceGrid.fromgrid(mom) - cice.write(cice_grid, cice_kmt, metadata=provenance_metadata) + cice.write(cice_grid, cice_kmt, metadata=provenance_metadata, variant=cice_variant) diff --git a/test/test_cice_grid.py b/test/test_cice_grid.py index 77e2f45..654f815 100644 --- a/test/test_cice_grid.py +++ b/test/test_cice_grid.py @@ -5,15 +5,19 @@ from subprocess import run from pathlib import Path -# from esmgrids.cli import cice_from_mom +from esmgrids.mom_grid import MomGrid +from esmgrids.cice_grid import CiceGrid # create test grids at 4 degrees and 0.1 degrees # 4 degress is the lowest tested in ocean_model_grid_generator # going higher resolution than 0.1 has too much computational cost _test_resolutions = [4, 0.1] +# run test using the valid cice variants +_variants = ["cice5-auscom", None] -# so that our fixtures are only create once in this pytest module, we need this special version of 'tmp_path' + +# so that our fixtures are only created once in this pytest module, we need this special version of 'tmp_path' @pytest.fixture(scope="module") def tmp_path(tmp_path_factory: pytest.TempdirFactory) -> Path: return tmp_path_factory.mktemp("temp") @@ -53,39 +57,31 @@ def __init__(self, res, tmp_path): class CiceGridFixture: """Make the CICE grid, using script under test""" - def __init__(self, mom_grid, tmp_path): + def __init__(self, mom_grid, tmp_path, variant): self.path = str(tmp_path) + "/grid.nc" self.kmt_path = str(tmp_path) + "/kmt.nc" - run( - [ - "cice_from_mom", - "--ocean_hgrid", - mom_grid.path, - "--ocean_mask", - mom_grid.mask_path, - "--cice_grid", - self.path, - "--cice_kmt", - self.kmt_path, - ] - ) - self.ds = xr.open_dataset(self.path, decode_cf=False) - self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False) + run_cmd = [ + "cice_from_mom", + "--ocean_hgrid", + mom_grid.path, + "--ocean_mask", + mom_grid.mask_path, + "--cice_grid", + self.path, + "--cice_kmt", + self.kmt_path, + ] + if variant is not None: + run_cmd.append("--cice_variant") + run_cmd.append(variant) + run(run_cmd) -# pytest doesn't support class fixtures, so we need these two constructor funcs -@pytest.fixture(scope="module", params=_test_resolutions) -def mom_grid(request, tmp_path): - return MomGridFixture(request.param, tmp_path) - - -@pytest.fixture(scope="module") -def cice_grid(mom_grid, tmp_path): - return CiceGridFixture(mom_grid, tmp_path) + self.ds = xr.open_dataset(self.path, decode_cf=False) + self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False) -@pytest.fixture(scope="module") -def test_grid_ds(mom_grid): +def gen_grid_ds(mom_grid, variant): # this generates the expected answers # In simple terms the MOM supergrid has four cells for each model grid cell. The MOM supergrid includes all edges (left and right) but CICE only uses right/east edges. (e.g. For points/edges of first cell: 0,0 is SW corner, 1,1 is middle of cell, 2,2, is NE corner/edges) @@ -105,7 +101,10 @@ def test_grid_ds(mom_grid): test_grid["tlon"] = deg2rad(t_points.x) test_grid["angle"] = deg2rad(u_points.angle_dx) # angle at u point - test_grid["angleT"] = deg2rad(t_points.angle_dx) + if variant == "cice5-auscom": + test_grid["angleT"] = deg2rad(t_points.angle_dx) + else: # cice6 + test_grid["anglet"] = deg2rad(t_points.angle_dx) # length of top (northern) edge of cells test_grid["htn"] = ds.dx.isel(nyp=slice(2, None, 2)).coarsen(nx=2).sum() * 100 @@ -128,33 +127,57 @@ def test_grid_ds(mom_grid): return test_grid +# pytest doesn't support class fixtures, so we need these two constructor funcs +@pytest.fixture(scope="module", params=_test_resolutions) +def mom_grid(request, tmp_path): + return MomGridFixture(request.param, tmp_path) + + +# the variant neews to be the same for both the cice_grid and the test_grid, so bundle them +@pytest.fixture(scope="module", params=_variants) +def grids(request, mom_grid, tmp_path): + return {"cice": CiceGridFixture(mom_grid, tmp_path, request.param), "test_ds": gen_grid_ds(mom_grid, request.param)} + + # ---------------- # the tests in earnest: @pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_cice_var_list(cice_grid, test_grid_ds): +def test_cice_var_list(grids): # Test : Are there missing vars in cice_grid? - assert set(test_grid_ds.variables).difference(cice_grid.ds.variables) == set() + assert set(grids["test_ds"].variables).difference(grids["cice"].ds.variables) == set() + + +def test_cice_dims(grids): + # Test : Are the dim names consistent with cice history output? + assert set(grids["cice"].ds.dims) == set( + ["ni", "nj"] + ), "cice dimension names should be 'ni','nj' to be consistent with history output" + assert grids["cice"].ds.sizes["ni"] == len(grids["test_ds"].nx) + assert grids["cice"].ds.sizes["nj"] == len(grids["test_ds"].ny) @pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_cice_grid(cice_grid, test_grid_ds): +def test_cice_grid(grids): # Test : Is the data the same as the test_grid - for jVar in test_grid_ds.variables: - assert_allclose(cice_grid.ds[jVar], test_grid_ds[jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch") + for jVar in grids["test_ds"].variables: + assert_allclose( + grids["cice"].ds[jVar], grids["test_ds"][jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch" + ) -def test_cice_kmt(mom_grid, cice_grid): +def test_cice_kmt(mom_grid, grids): # Test : does the mask match mask = mom_grid.mask_ds.mask - kmt = cice_grid.kmt_ds.kmt + kmt = grids["cice"].kmt_ds.kmt assert_allclose(mask, kmt, rtol=1e-13, verbose=True, err_msg="mask mismatch") -def test_cice_grid_attributes(cice_grid): +def test_cice_grid_attributes(grids): # Test: do the expected attributes to exist in the cice ds + # To-do: rewrite test using the CF-checker (or similar) cf_attributes = { "ulat": {"standard_name": "latitude", "units": "radians"}, "ulon": {"standard_name": "longitude", "units": "radians"}, @@ -184,48 +207,65 @@ def test_cice_grid_attributes(cice_grid): "grid_mapping": "crs", "coordinates": "tlat tlon", }, + "anglet": { + "standard_name": "angle_of_rotation_from_east_to_x", + "units": "radians", + "grid_mapping": "crs", + "coordinates": "tlat tlon", + }, "htn": {"units": "cm", "coordinates": "ulat tlon", "grid_mapping": "crs"}, "hte": {"units": "cm", "coordinates": "tlat ulon", "grid_mapping": "crs"}, } - for iVar in cf_attributes.keys(): - print(cice_grid.ds[iVar]) - - for jAttr in cf_attributes[iVar].keys(): - assert cice_grid.ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr] + for iVar in grids["cice"].ds.keys(): + if iVar != "crs": # test seperately + for jAttr in cf_attributes[iVar].keys(): + assert grids["cice"].ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr] -def test_crs_exist(cice_grid): +def test_crs_exist(grids): # Test: has the crs been added ? # todo: open with GDAL and rioxarray and confirm they find the crs? - assert hasattr(cice_grid.ds, "crs") - assert hasattr(cice_grid.kmt_ds, "crs") + assert hasattr(grids["cice"].ds, "crs") + assert hasattr(grids["cice"].kmt_ds, "crs") -def test_inputs_logged(cice_grid, mom_grid): +def test_inputs_logged(grids, mom_grid): # Test: have the source data been logged ? - input_md5 = run(["md5sum", cice_grid.ds.inputfile], capture_output=True, text=True) + input_md5 = run(["md5sum", mom_grid.path], capture_output=True, text=True) input_md5 = input_md5.stdout.split(" ")[0] - mask_md5 = run(["md5sum", cice_grid.kmt_ds.inputfile], capture_output=True, text=True) + mask_md5 = run(["md5sum", mom_grid.mask_path], capture_output=True, text=True) mask_md5 = mask_md5.stdout.split(" ")[0] - for ds in [cice_grid.ds, cice_grid.kmt_ds]: - assert ( - ds.inputfile - == ( - mom_grid.path - + " (md5 hash: " - + input_md5 - + "), " - + mom_grid.mask_path - + " (md5 hash: " - + mask_md5 - + ")" - ), - "inputfile attribute incorrect ({ds.inputfile} != {mom_grid.path})", - ) + for ds in [grids["cice"].ds, grids["cice"].kmt_ds]: + assert ds.inputfile == ( + mom_grid.path + " (md5 hash: " + input_md5 + "), " + mom_grid.mask_path + " (md5 hash: " + mask_md5 + ")" + ), "inputfile attribute incorrect ({ds.inputfile} != {mom_grid.path})" assert hasattr(ds, "inputfile"), "inputfile attribute missing" assert hasattr(ds, "history"), "history attribute missing" + + +def test_variant(mom_grid, tmp_path): + # Is a error given for variant not equal to None or 'cice5-auscom' + + mom = MomGrid.fromfile(mom_grid.path, mask_file=mom_grid.mask_path) + cice = CiceGrid.fromgrid(mom) + + # invalid variant (="andrew") + with pytest.raises(NotImplementedError, match="andrew not recognised"): + cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc", variant="andrew") + + # valid variant (="cice5-auscom") + try: + cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc", variant="cice5-auscom") + except: + assert False, "Failed to write cice grid with valid input arguments provided" + + # valid variant (default = None) + try: + cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc") + except: + assert False, "Failed to write cice grid with 'None' variant" diff --git a/test/test_grids.py b/test/test_grids.py index 410b034..26114eb 100644 --- a/test/test_grids.py +++ b/test/test_grids.py @@ -3,14 +3,11 @@ import os import sys import numpy as np -import netCDF4 as nc my_dir = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.join(my_dir, "../")) -from esmgrids.mom_grid import MomGrid # noqa from esmgrids.core2_grid import Core2Grid # noqa -from esmgrids.cice_grid import CiceGrid # noqa from esmgrids.util import calc_area_of_polygons # noqa data_tarball = "test_data.tar.gz"