diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 0c190c0..1e73585 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -90,6 +90,19 @@ def compare_datasets(a, b): assert max_delta == 0.0, (date, param, a_, b_, a_ - b_, max_delta) +def compare_statistics(ds1, ds2): + vars1 = ds1.variables + vars2 = ds2.variables + assert len(vars1) == len(vars2) + for v1, v2 in zip(vars1, vars2): + idx1 = ds1.name_to_index[v1] + idx2 = ds2.name_to_index[v2] + assert (ds1.statistics["mean"][idx1] == ds2.statistics["mean"][idx2]).all() + assert (ds1.statistics["stdev"][idx1] == ds2.statistics["stdev"][idx2]).all() + assert (ds1.statistics["maximum"][idx1] == ds2.statistics["maximum"][idx2]).all() + assert (ds1.statistics["minimum"][idx1] == ds2.statistics["minimum"][idx2]).all() + + class Comparer: def __init__(self, name, output_path=None, reference_path=None): self.name = name @@ -106,8 +119,7 @@ def __init__(self, name, output_path=None, reference_path=None): def compare(self): compare_dot_zattrs(self.z_output.attrs, self.z_reference.attrs) compare_datasets(self.ds_output, self.ds_reference) - # not implemented : - # compare_statistics(self.z_output, self.z_reference) + compare_statistics(self.ds_output, self.ds_reference) @pytest.mark.parametrize("name", NAMES)