Skip to content

Commit

Permalink
Add anglet variant and change x/y var names
Browse files Browse the repository at this point in the history
  • Loading branch information
anton-seaice committed Apr 22, 2024
1 parent 5191fb5 commit 5072279
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 61 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@ repos:
- id: black
name: black
entry: black
args: ['esmgrids', 'test']
language: system
types: [python]
28 changes: 18 additions & 10 deletions esmgrids/cice_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def fromfile(cls, h_grid_def, mask_file=None, description="CICE tripolar"):
area_t = f.variables["tarea"][:]
area_u = f.variables["uarea"][:]

angle_t = np.rad2deg(f.variables["angleT"][:])
try:
angle_t = np.rad2deg(f.variables["anglet"][:])
except KeyError:
angle_t = np.rad2deg(f.variables["angleT"][:])

angle_u = np.rad2deg(f.variables["angle"][:])

if "clon_t" in f.variables:
Expand Down Expand Up @@ -69,12 +73,12 @@ def _create_2d_nc_var(self, f, name):
return f.createVariable(
name,
"f8",
dimensions=("ny", "nx"),
dimensions=("nj", "ni"),
compression="zlib",
complevel=1,
)

def write(self, grid_filename, mask_filename, metadata=None):
def write(self, grid_filename, mask_filename, metadata=None, variant="cice5-auscom"):
"""
Write out CICE grid to netcdf
Expand All @@ -92,10 +96,10 @@ def write(self, grid_filename, mask_filename, metadata=None):
f = nc.Dataset(grid_filename, "w")

# Create dimensions.
f.createDimension("nx", self.num_lon_points)
# nx is the grid_longitude but doesn't have a value other than its index
f.createDimension("ny", self.num_lat_points)
# ny is the grid_latitude but doesn't have a value other than its index
f.createDimension("ni", self.num_lon_points)
# ni is the grid_longitude but doesn't have a value other than its index
f.createDimension("nj", self.num_lat_points)
# nj is the grid_latitude but doesn't have a value other than its index

# Make all CICE grid variables.
# names are based on https://cfconventions.org/Data/cf-standard-names/current/build/cf-standard-name-table.html
Expand Down Expand Up @@ -135,7 +139,11 @@ def write(self, grid_filename, mask_filename, metadata=None):
angle.standard_name = "angle_of_rotation_from_east_to_x"
angle.coordinates = "ulat ulon"
angle.grid_mapping = "crs"
angleT = self._create_2d_nc_var(f, "angleT")

if variant == "cice5-auscom":
angleT = self._create_2d_nc_var(f, "angleT")
else: # variant==cice6
angleT = self._create_2d_nc_var(f, "anglet")
angleT.units = "radians"
angleT.long_name = "Rotation angle of T cells."
angleT.standard_name = "angle_of_rotation_from_east_to_x"
Expand Down Expand Up @@ -185,8 +193,8 @@ def write(self, grid_filename, mask_filename, metadata=None):
# Mask file
f = nc.Dataset(mask_filename, "w")

f.createDimension("nx", self.num_lon_points)
f.createDimension("ny", self.num_lat_points)
f.createDimension("ni", self.num_lon_points)
f.createDimension("nj", self.num_lat_points)
mask = self._create_2d_nc_var(f, "kmt")
mask.grid_mapping = "crs"
mask.standard_name = "sea_binary_mask"
Expand Down
6 changes: 4 additions & 2 deletions esmgrids/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@ def cice_from_mom():
parser.add_argument("--ocean_mask", type=str, help="Input MOM ocean_mask.nc mask file")
parser.add_argument("--cice_grid", type=str, default="grid.nc", help="Output CICE grid file")
parser.add_argument("--cice_kmt", type=str, default="kmt.nc", help="Output CICE kmt file")
parser.add_argument("--cice_variant", type=str, default="cice5-auscom", help="cice variant")

args = parser.parse_args()
ocean_hgrid = os.path.abspath(args.ocean_hgrid)
ocean_mask = os.path.abspath(args.ocean_mask)
cice_grid = os.path.abspath(args.cice_grid)
cice_kmt = os.path.abspath(args.cice_kmt)
cice_variant = args.cice_variant

version = safe_version()
runcmd = (
f"Created using https://github.com/COSIMA/esmgrids {version}: "
f"cice_from_mom --ocean_hgrid={ocean_hgrid} --ocean_mask={ocean_mask} "
f"--cice_grid={cice_grid} --cice_kmt={cice_kmt}"
f"--cice_grid={cice_grid} --cice_kmt={cice_kmt} --cice_variant={cice_variant}"
)
provenance_metadata = {
"inputfile": (
Expand All @@ -37,4 +39,4 @@ def cice_from_mom():

mom = MomGrid.fromfile(ocean_hgrid, mask_file=ocean_mask)
cice = CiceGrid.fromgrid(mom)
cice.write(cice_grid, cice_kmt, metadata=provenance_metadata)
cice.write(cice_grid, cice_kmt, metadata=provenance_metadata, variant=cice_variant)
113 changes: 65 additions & 48 deletions test/test_cice_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# going higher resolution than 0.1 has too much computational cost
_test_resolutions = [4, 0.1]

_variants = ["cice5-auscom", "cice6", None]


# so that our fixtures are only create once in this pytest module, we need this special version of 'tmp_path'
@pytest.fixture(scope="module")
Expand Down Expand Up @@ -53,39 +55,31 @@ def __init__(self, res, tmp_path):
class CiceGridFixture:
"""Make the CICE grid, using script under test"""

def __init__(self, mom_grid, tmp_path):
def __init__(self, mom_grid, tmp_path, variant):
self.path = str(tmp_path) + "/grid.nc"
self.kmt_path = str(tmp_path) + "/kmt.nc"
run(
[
"cice_from_mom",
"--ocean_hgrid",
mom_grid.path,
"--ocean_mask",
mom_grid.mask_path,
"--cice_grid",
self.path,
"--cice_kmt",
self.kmt_path,
]
)
self.ds = xr.open_dataset(self.path, decode_cf=False)
self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False)

run_cmd = [
"cice_from_mom",
"--ocean_hgrid",
mom_grid.path,
"--ocean_mask",
mom_grid.mask_path,
"--cice_grid",
self.path,
"--cice_kmt",
self.kmt_path,
]
if variant is not None:
run_cmd.append("--cice_variant")
run_cmd.append(variant)
run(run_cmd)

# pytest doesn't support class fixtures, so we need these two constructor funcs
@pytest.fixture(scope="module", params=_test_resolutions)
def mom_grid(request, tmp_path):
return MomGridFixture(request.param, tmp_path)


@pytest.fixture(scope="module")
def cice_grid(mom_grid, tmp_path):
return CiceGridFixture(mom_grid, tmp_path)
self.ds = xr.open_dataset(self.path, decode_cf=False)
self.kmt_ds = xr.open_dataset(self.kmt_path, decode_cf=False)


@pytest.fixture(scope="module")
def test_grid_ds(mom_grid):
def gen_grid_ds(mom_grid, variant):
# this generates the expected answers
# In simple terms the MOM supergrid has four cells for each model grid cell. The MOM supergrid includes all edges (left and right) but CICE only uses right/east edges. (e.g. For points/edges of first cell: 0,0 is SW corner, 1,1 is middle of cell, 2,2, is NE corner/edges)

Expand All @@ -105,7 +99,10 @@ def test_grid_ds(mom_grid):
test_grid["tlon"] = deg2rad(t_points.x)

test_grid["angle"] = deg2rad(u_points.angle_dx) # angle at u point
test_grid["angleT"] = deg2rad(t_points.angle_dx)
if variant == "cice5-auscom" or variant is None:
test_grid["angleT"] = deg2rad(t_points.angle_dx)
else: # cice6
test_grid["anglet"] = deg2rad(t_points.angle_dx)

# length of top (northern) edge of cells
test_grid["htn"] = ds.dx.isel(nyp=slice(2, None, 2)).coarsen(nx=2).sum() * 100
Expand All @@ -128,33 +125,48 @@ def test_grid_ds(mom_grid):
return test_grid


# pytest doesn't support class fixtures, so we need these two constructor funcs
@pytest.fixture(scope="module", params=_test_resolutions)
def mom_grid(request, tmp_path):
return MomGridFixture(request.param, tmp_path)


# the variant neews to be the same for both the cice_grid and the test_grid, so bundle them
@pytest.fixture(scope="module", params=_variants)
def grids(request, mom_grid, tmp_path):
return {"cice": CiceGridFixture(mom_grid, tmp_path, request.param), "test_ds": gen_grid_ds(mom_grid, request.param)}


# ----------------
# the tests in earnest:


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_cice_var_list(cice_grid, test_grid_ds):
def test_cice_var_list(grids):
# Test : Are there missing vars in cice_grid?
assert set(test_grid_ds.variables).difference(cice_grid.ds.variables) == set()
assert set(grids["test_ds"].variables).difference(grids["cice"].ds.variables) == set()


@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_cice_grid(cice_grid, test_grid_ds):
def test_cice_grid(grids):
# Test : Is the data the same as the test_grid
for jVar in test_grid_ds.variables:
assert_allclose(cice_grid.ds[jVar], test_grid_ds[jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch")
for jVar in grids["test_ds"].variables:
assert_allclose(
grids["cice"].ds[jVar], grids["test_ds"][jVar], rtol=1e-13, verbose=True, err_msg=f"{jVar} mismatch"
)


def test_cice_kmt(mom_grid, cice_grid):
def test_cice_kmt(mom_grid, grids):
# Test : does the mask match
mask = mom_grid.mask_ds.mask
kmt = cice_grid.kmt_ds.kmt
kmt = grids["cice"].kmt_ds.kmt

assert_allclose(mask, kmt, rtol=1e-13, verbose=True, err_msg="mask mismatch")


def test_cice_grid_attributes(cice_grid):
def test_cice_grid_attributes(grids):
# Test: do the expected attributes to exist in the cice ds

cf_attributes = {
"ulat": {"standard_name": "latitude", "units": "radians"},
"ulon": {"standard_name": "longitude", "units": "radians"},
Expand Down Expand Up @@ -184,33 +196,38 @@ def test_cice_grid_attributes(cice_grid):
"grid_mapping": "crs",
"coordinates": "tlat tlon",
},
"anglet": {
"standard_name": "angle_of_rotation_from_east_to_x",
"units": "radians",
"grid_mapping": "crs",
"coordinates": "tlat tlon",
},
"htn": {"units": "cm", "coordinates": "ulat tlon", "grid_mapping": "crs"},
"hte": {"units": "cm", "coordinates": "tlat ulon", "grid_mapping": "crs"},
}

for iVar in cf_attributes.keys():
print(cice_grid.ds[iVar])

for jAttr in cf_attributes[iVar].keys():
assert cice_grid.ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr]
for iVar in grids["cice"].ds.keys():
if iVar != "crs": # test seperately
for jAttr in cf_attributes[iVar].keys():
assert grids["cice"].ds[iVar].attrs[jAttr] == cf_attributes[iVar][jAttr]


def test_crs_exist(cice_grid):
def test_crs_exist(grids):
# Test: has the crs been added ?
# todo: open with GDAL and rioxarray and confirm they find the crs?
assert hasattr(cice_grid.ds, "crs")
assert hasattr(cice_grid.kmt_ds, "crs")
assert hasattr(grids["cice"].ds, "crs")
assert hasattr(grids["cice"].kmt_ds, "crs")


def test_inputs_logged(cice_grid, mom_grid):
def test_inputs_logged(grids, mom_grid):
# Test: have the source data been logged ?

input_md5 = run(["md5sum", cice_grid.ds.inputfile], capture_output=True, text=True)
input_md5 = run(["md5sum", grids["cice"].ds.inputfile], capture_output=True, text=True)
input_md5 = input_md5.stdout.split(" ")[0]
mask_md5 = run(["md5sum", cice_grid.kmt_ds.inputfile], capture_output=True, text=True)
mask_md5 = run(["md5sum", grids["cice"].kmt_ds.inputfile], capture_output=True, text=True)
mask_md5 = mask_md5.stdout.split(" ")[0]

for ds in [cice_grid.ds, cice_grid.kmt_ds]:
for ds in [grids["cice"].ds, grids["cice"].kmt_ds]:
assert (
ds.inputfile
== (
Expand Down

0 comments on commit 5072279

Please sign in to comment.