From 8f4b24ce84bce6fb375f1bc77d816ccf402b9383 Mon Sep 17 00:00:00 2001 From: "stephen.worsley" Date: Mon, 22 Jul 2024 15:34:10 +0100 Subject: [PATCH] fix unmasked connectivity bug --- esmf_regrid/experimental/unstructured_regrid.py | 12 +++++++++--- .../unstructured_regrid/test_MeshInfo.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/esmf_regrid/experimental/unstructured_regrid.py b/esmf_regrid/experimental/unstructured_regrid.py index 39960be9..1dc1bb24 100644 --- a/esmf_regrid/experimental/unstructured_regrid.py +++ b/esmf_regrid/experimental/unstructured_regrid.py @@ -99,9 +99,15 @@ def _as_esmf_info(self): nodeCoord = self.node_coords.flatten() nodeOwner = np.zeros([num_node]) # regridding currently serial elemId = np.arange(1, num_elem + 1) - elemType = self.fnc.count(axis=1) - # Experiments seem to indicate that ESMF is using 0 indexing here - elemConn = self.fnc.compressed() - self.nsi + if np.ma.isMaskedArray(self.fnc): + elemType = self.fnc.count(axis=1) + # Experiments seem to indicate that ESMF is using 0 indexing here + elemConn = self.fnc.compressed() - self.nsi + else: + elemType = self.fnc.shape[1] * np.ones(self.fnc.shape[0]) + # Experiments seem to indicate that ESMF is using 0 indexing here + elemConn = self.fnc.flatten() - self.nsi + elemCoord = self.elem_coords result = ( num_node, diff --git a/esmf_regrid/tests/unit/experimental/unstructured_regrid/test_MeshInfo.py b/esmf_regrid/tests/unit/experimental/unstructured_regrid/test_MeshInfo.py index fe2ced0b..bea8a6f1 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_regrid/test_MeshInfo.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_regrid/test_MeshInfo.py @@ -46,6 +46,20 @@ def test_make_mesh(): assert esmf_mesh_0.__repr__() == esmf_mesh_1.__repr__() == expected_repr +def test_connectivity_mask_equivalence(): + """Test for handling connectivity masks :meth:`~esmf_regrid.esmf_regridder.GridInfo.make_esmf_field`.""" + coords, nodes, _ = _make_small_mesh_args() + coords = coords[:-1] + nodes = nodes[:, :-1] + unmasked_nodes = nodes.filled() + mesh = MeshInfo(coords, unmasked_nodes, 0) + esmf_mesh_unmasked = mesh.make_esmf_field() + + mesh = MeshInfo(coords, nodes, 0) + esmf_mesh_masked = mesh.make_esmf_field() + assert esmf_mesh_unmasked.__repr__() == esmf_mesh_masked.__repr__() + + def test_regrid_with_mesh(): """Basic test for regridding with :meth:`~esmf_regrid.esmf_regridder.GridInfo.make_esmf_field`.""" mesh_args = _make_small_mesh_args()