Skip to content

Commit

Permalink
Fix formatting of cice grid file (#18)
Browse files Browse the repository at this point in the history
* Add anglet variant (for cice6) and change x/y var names to match those uses in cice output
  • Loading branch information
anton-seaice authored Apr 25, 2024
1 parent f5d137f commit 79f2952
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 79 deletions.
36 changes: 26 additions & 10 deletions esmgrids/cice_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def fromfile(cls, h_grid_def, mask_file=None, description="CICE tripolar"):
area_t = f.variables["tarea"][:]
area_u = f.variables["uarea"][:]

angle_t = np.rad2deg(f.variables["angleT"][:])
try:
angle_t = np.rad2deg(f.variables["anglet"][:])
except KeyError:
angle_t = np.rad2deg(f.variables["angleT"][:])

angle_u = np.rad2deg(f.variables["angle"][:])

if "clon_t" in f.variables:
Expand Down Expand Up @@ -69,12 +73,12 @@ def _create_2d_nc_var(self, f, name):
return f.createVariable(
name,
"f8",
dimensions=("ny", "nx"),
dimensions=("nj", "ni"),
compression="zlib",
complevel=1,
)

def write(self, grid_filename, mask_filename, metadata=None):
def write(self, grid_filename, mask_filename, metadata=None, variant=None):
"""
Write out CICE grid to netcdf
Expand All @@ -86,16 +90,23 @@ def write(self, grid_filename, mask_filename, metadata=None):
The name of the mask file to write
metadata: dict
Any global or variable metadata attributes to add to the files being written
variant: str
Use variant='cice5-auscom' for access-om2/cice5-auscom builds, otherwise use None
"""

if variant is not None and variant != "cice5-auscom":
raise NotImplementedError(f"{variant} not recognised")

# Grid file
f = nc.Dataset(grid_filename, "w")

# Create dimensions.
f.createDimension("nx", self.num_lon_points)
# nx is the grid_longitude but doesn't have a value other than its index
f.createDimension("ny", self.num_lat_points)
# ny is the grid_latitude but doesn't have a value other than its index
f.createDimension(
"ni", self.num_lon_points
) # ni is the grid_longitude but doesn't have a value other than its index
f.createDimension(
"nj", self.num_lat_points
) # nj is the grid_latitude but doesn't have a value other than its index

# Make all CICE grid variables.
# names are based on https://cfconventions.org/Data/cf-standard-names/current/build/cf-standard-name-table.html
Expand Down Expand Up @@ -135,7 +146,12 @@ def write(self, grid_filename, mask_filename, metadata=None):
angle.standard_name = "angle_of_rotation_from_east_to_x"
angle.coordinates = "ulat ulon"
angle.grid_mapping = "crs"
angleT = self._create_2d_nc_var(f, "angleT")

if variant == "cice5-auscom":
angleT = self._create_2d_nc_var(f, "angleT")
elif variant is None:
angleT = self._create_2d_nc_var(f, "anglet")

angleT.units = "radians"
angleT.long_name = "Rotation angle of T cells."
angleT.standard_name = "angle_of_rotation_from_east_to_x"
Expand Down Expand Up @@ -185,8 +201,8 @@ def write(self, grid_filename, mask_filename, metadata=None):
# Mask file
f = nc.Dataset(mask_filename, "w")

f.createDimension("nx", self.num_lon_points)
f.createDimension("ny", self.num_lat_points)
f.createDimension("ni", self.num_lon_points)
f.createDimension("nj", self.num_lat_points)
mask = self._create_2d_nc_var(f, "kmt")
mask.grid_mapping = "crs"
mask.standard_name = "sea_binary_mask"
Expand Down
8 changes: 6 additions & 2 deletions esmgrids/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@ def cice_from_mom():
parser.add_argument("--ocean_mask", type=str, help="Input MOM ocean_mask.nc mask file")
parser.add_argument("--cice_grid", type=str, default="grid.nc", help="Output CICE grid file")
parser.add_argument("--cice_kmt", type=str, default="kmt.nc", help="Output CICE kmt file")
parser.add_argument(
"--cice_variant", type=str, default=None, help="Cice variant, valid options = [None, 'cice5-auscom'] "
)

args = parser.parse_args()
ocean_hgrid = os.path.abspath(args.ocean_hgrid)
ocean_mask = os.path.abspath(args.ocean_mask)
cice_grid = os.path.abspath(args.cice_grid)
cice_kmt = os.path.abspath(args.cice_kmt)
cice_variant = args.cice_variant

version = safe_version()
runcmd = (
f"Created using https://github.com/COSIMA/esmgrids {version}: "
f"cice_from_mom --ocean_hgrid={ocean_hgrid} --ocean_mask={ocean_mask} "
f"--cice_grid={cice_grid} --cice_kmt={cice_kmt}"
f"--cice_grid={cice_grid} --cice_kmt={cice_kmt} --cice_variant={cice_variant}"
)
provenance_metadata = {
"inputfile": (
Expand All @@ -37,4 +41,4 @@ def cice_from_mom():

mom = MomGrid.fromfile(ocean_hgrid, mask_file=ocean_mask)
cice = CiceGrid.fromgrid(mom)
cice.write(cice_grid, cice_kmt, metadata=provenance_metadata)
cice.write(cice_grid, cice_kmt, metadata=provenance_metadata, variant=cice_variant)
168 changes: 104 additions & 64 deletions test/test_cice_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@
from subprocess import run
from pathlib import Path

# from esmgrids.cli import cice_from_mom
from esmgrids.mom_grid import MomGrid
from esmgrids.cice_grid import CiceGrid

# create test grids at 4 degrees and 0.1 degrees
# 4 degress is the lowest tested in ocean_model_grid_generator
# going higher resolution than 0.1 has too much computational cost
_test_resolutions = [4, 0.1]

# run test using the valid cice variants
_variants = ["cice5-auscom", None]

# so that our fixtures are only create once in this pytest module, we need this special version of 'tmp_path'

# so that our fixtures are only created once in this pytest module, we need this special version of 'tmp_path'
@pytest.fixture(scope="module")
def tmp_path(tmp_path_factory: pytest.TempdirFactory) -> Path:
return tmp_path_factory.mktemp("temp")
Expand Down Expand Up @@ -53,39 +57,31 @@ def __init__(self, res, tmp_path):
class CiceGridFixture:
"""Make the CICE grid, using script under test"""

def __init__(self, mom_grid, tmp_path):
def __init__(self, mom_grid, tmp_path, variant):
self.path = str(tmp_path) + "/grid.nc"
self.kmt_path = str(tmp_path) + "/kmt.nc"
run(
[
"cice_from_mom",
"--ocean_hgrid",
mom_grid.path,
"--ocean_mask",
mom_grid.mask_path,
"--cice_grid",
self.path,
"--cice_kmt",
self.kmt_path,
]
)
self.ds = xr.open_dataset(self.path, decode_cf=False)
self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False)

run_cmd = [
"cice_from_mom",
"--ocean_hgrid",
mom_grid.path,
"--ocean_mask",
mom_grid.mask_path,
"--cice_grid",
self.path,
"--cice_kmt",
self.kmt_path,
]
if variant is not None:
run_cmd.append("--cice_variant")
run_cmd.append(variant)
run(run_cmd)

# pytest doesn't support class fixtures, so we need these two constructor funcs
@pytest.fixture(scope="module", params=_test_resolutions)
def mom_grid(request, tmp_path):
return MomGridFixture(request.param, tmp_path)


@pytest.fixture(scope="module")
def cice_grid(mom_grid, tmp_path):
return CiceGridFixture(mom_grid, tmp_path)
self.ds = xr.open_dataset(self.path, decode_cf=False)
self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False)


@pytest.fixture(scope="module")
def test_grid_ds(mom_grid):
def gen_grid_ds(mom_grid, variant):
# this generates the expected answers
# In simple terms the MOM supergrid has four cells for each model grid cell. The MOM supergrid includes all edges (left and right) but CICE only uses right/east edges. (e.g. For points/edges of first cell: 0,0 is SW corner, 1,1 is middle of cell, 2,2, is NE corner/edges)

Expand All @@ -105,7 +101,10 @@ def test_grid_ds(mom_grid):
test_grid["tlon"] = deg2rad(t_points.x)

test_grid["angle"] = deg2rad(u_points.angle_dx) # angle at u point
test_grid["angleT"] = deg2rad(t_points.angle_dx)
if variant == "cice5-auscom":
test_grid["angleT"] = deg2rad(t_points.angle_dx)
else: # cice6
test_grid["anglet"] = deg2rad(t_points.angle_dx)

# length of top (northern) edge of cells
test_grid["htn"] = ds.dx.isel(nyp=slice(2, None, 2)).coarsen(nx=2).sum() * 100
Expand All @@ -128,33 +127,57 @@ def test_grid_ds(mom_grid):
return test_grid


# pytest doesn't support class fixtures, so we need these two constructor funcs
@pytest.fixture(scope="module", params=_test_resolutions)
def mom_grid(request, tmp_path):
return MomGridFixture(request.param, tmp_path)


# the variant neews to be the same for both the cice_grid and the test_grid, so bundle them
@pytest.fixture(scope="module", params=_variants)
def grids(request, mom_grid, tmp_path):
return {"cice": CiceGridFixture(mom_grid, tmp_path, request.param), "test_ds": gen_grid_ds(mom_grid, request.param)}


# ----------------
# the tests in earnest:


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_cice_var_list(cice_grid, test_grid_ds):
def test_cice_var_list(grids):
# Test : Are there missing vars in cice_grid?
assert set(test_grid_ds.variables).difference(cice_grid.ds.variables) == set()
assert set(grids["test_ds"].variables).difference(grids["cice"].ds.variables) == set()


def test_cice_dims(grids):
# Test : Are the dim names consistent with cice history output?
assert set(grids["cice"].ds.dims) == set(
["ni", "nj"]
), "cice dimension names should be 'ni','nj' to be consistent with history output"
assert grids["cice"].ds.sizes["ni"] == len(grids["test_ds"].nx)
assert grids["cice"].ds.sizes["nj"] == len(grids["test_ds"].ny)


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_cice_grid(cice_grid, test_grid_ds):
def test_cice_grid(grids):
# Test : Is the data the same as the test_grid
for jVar in test_grid_ds.variables:
assert_allclose(cice_grid.ds[jVar], test_grid_ds[jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch")
for jVar in grids["test_ds"].variables:
assert_allclose(
grids["cice"].ds[jVar], grids["test_ds"][jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch"
)


def test_cice_kmt(mom_grid, cice_grid):
def test_cice_kmt(mom_grid, grids):
# Test : does the mask match
mask = mom_grid.mask_ds.mask
kmt = cice_grid.kmt_ds.kmt
kmt = grids["cice"].kmt_ds.kmt

assert_allclose(mask, kmt, rtol=1e-13, verbose=True, err_msg="mask mismatch")


def test_cice_grid_attributes(cice_grid):
def test_cice_grid_attributes(grids):
# Test: do the expected attributes to exist in the cice ds
# To-do: rewrite test using the CF-checker (or similar)
cf_attributes = {
"ulat": {"standard_name": "latitude", "units": "radians"},
"ulon": {"standard_name": "longitude", "units": "radians"},
Expand Down Expand Up @@ -184,48 +207,65 @@ def test_cice_grid_attributes(cice_grid):
"grid_mapping": "crs",
"coordinates": "tlat tlon",
},
"anglet": {
"standard_name": "angle_of_rotation_from_east_to_x",
"units": "radians",
"grid_mapping": "crs",
"coordinates": "tlat tlon",
},
"htn": {"units": "cm", "coordinates": "ulat tlon", "grid_mapping": "crs"},
"hte": {"units": "cm", "coordinates": "tlat ulon", "grid_mapping": "crs"},
}

for iVar in cf_attributes.keys():
print(cice_grid.ds[iVar])

for jAttr in cf_attributes[iVar].keys():
assert cice_grid.ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr]
for iVar in grids["cice"].ds.keys():
if iVar != "crs": # test seperately
for jAttr in cf_attributes[iVar].keys():
assert grids["cice"].ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr]


def test_crs_exist(cice_grid):
def test_crs_exist(grids):
# Test: has the crs been added ?
# todo: open with GDAL and rioxarray and confirm they find the crs?
assert hasattr(cice_grid.ds, "crs")
assert hasattr(cice_grid.kmt_ds, "crs")
assert hasattr(grids["cice"].ds, "crs")
assert hasattr(grids["cice"].kmt_ds, "crs")


def test_inputs_logged(cice_grid, mom_grid):
def test_inputs_logged(grids, mom_grid):
# Test: have the source data been logged ?

input_md5 = run(["md5sum", cice_grid.ds.inputfile], capture_output=True, text=True)
input_md5 = run(["md5sum", mom_grid.path], capture_output=True, text=True)
input_md5 = input_md5.stdout.split(" ")[0]
mask_md5 = run(["md5sum", cice_grid.kmt_ds.inputfile], capture_output=True, text=True)
mask_md5 = run(["md5sum", mom_grid.mask_path], capture_output=True, text=True)
mask_md5 = mask_md5.stdout.split(" ")[0]

for ds in [cice_grid.ds, cice_grid.kmt_ds]:
assert (
ds.inputfile
== (
mom_grid.path
+ " (md5 hash: "
+ input_md5
+ "), "
+ mom_grid.mask_path
+ " (md5 hash: "
+ mask_md5
+ ")"
),
"inputfile attribute incorrect ({ds.inputfile} != {mom_grid.path})",
)
for ds in [grids["cice"].ds, grids["cice"].kmt_ds]:
assert ds.inputfile == (
mom_grid.path + " (md5 hash: " + input_md5 + "), " + mom_grid.mask_path + " (md5 hash: " + mask_md5 + ")"
), "inputfile attribute incorrect ({ds.inputfile} != {mom_grid.path})"

assert hasattr(ds, "inputfile"), "inputfile attribute missing"

assert hasattr(ds, "history"), "history attribute missing"


def test_variant(mom_grid, tmp_path):
# Is a error given for variant not equal to None or 'cice5-auscom'

mom = MomGrid.fromfile(mom_grid.path, mask_file=mom_grid.mask_path)
cice = CiceGrid.fromgrid(mom)

# invalid variant (="andrew")
with pytest.raises(NotImplementedError, match="andrew not recognised"):
cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc", variant="andrew")

# valid variant (="cice5-auscom")
try:
cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc", variant="cice5-auscom")
except:
assert False, "Failed to write cice grid with valid input arguments provided"

# valid variant (default = None)
try:
cice.write(str(tmp_path) + "/grid2.nc", str(tmp_path) + "/kmt2.nc")
except:
assert False, "Failed to write cice grid with 'None' variant"
3 changes: 0 additions & 3 deletions test/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
import os
import sys
import numpy as np
import netCDF4 as nc

my_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(my_dir, "../"))

from esmgrids.mom_grid import MomGrid # noqa
from esmgrids.core2_grid import Core2Grid # noqa
from esmgrids.cice_grid import CiceGrid # noqa
from esmgrids.util import calc_area_of_polygons # noqa

data_tarball = "test_data.tar.gz"
Expand Down

0 comments on commit 79f2952

Please sign in to comment.