Skip to content

Commit

Permalink
Added tests for new functions
Browse files Browse the repository at this point in the history
  • Loading branch information
liammegill committed Nov 15, 2024
1 parent 73b8255 commit f8a2b9f
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 12 deletions.
41 changes: 30 additions & 11 deletions openairclim/calc_cont.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def calc_cont_grid_areas(lat: np.ndarray, lon: np.ndarray) -> np.ndarray:
return areas


def interpolate_base_inv_dict(inv_dict, base_inv_dict, intrp_vars):
def interp_base_inv_dict(inv_dict, base_inv_dict, intrp_vars):
"""Create base emission inventories for years in `inv_dict` that do not
exist in `base_inv_dict`.
Expand All @@ -132,26 +132,39 @@ def interpolate_base_inv_dict(inv_dict, base_inv_dict, intrp_vars):
keys are inventory years.
base_inv_dict (dict): Dictionary of base emission inventory
xarrays, keys are inventory years.
intrp_vars (array-like): List of strings of data variables in
intrp_vars (list): List of strings of data variables in
base_inv_dict that are to be included in the missing base
inventories, e.g. ["distance", "fuel"].
Returns:
dict: Dictionary of base emission inventory xarrays including any
missing years compared to inv_dict, keys are inventory years.
"""

# TODO give user the option to select different regridding (currently only nearest)
# and interpolation (currently only linear) methods

# pre-conditions
Note:
A custom nearest neighbour method is used for regridding and a linear
interpolation method for calculating data in missing years. In future
versions, the user will be able to select methods for both.
"""

# if base_inv_dict is empty, then return the empty dictionary
# otherwise, continue with the calculations
# if base_inv_dict is empty, then return the empty dictionary.
if not base_inv_dict:
return {}

# pre-conditions
assert inv_dict, "inv_dict cannot be empty."
assert intrp_vars, "intrp_vars cannot be empty."
if base_inv_dict:
assert min(base_inv_dict.keys()) <= min(inv_dict.keys()), "The " \
f"inv_dict key {min(inv_dict.keys())} is less than the earliest " \
f"base_inv_dict key {min(base_inv_dict.keys())}."
assert max(base_inv_dict.keys()) >= max(inv_dict.keys()), "The " \
f"inv_dict key {max(inv_dict.keys())} is larger than the largest "\
f"base_inv_dict key {max(base_inv_dict.keys())}."
for intrp_var in intrp_vars:
for yr in base_inv_dict.keys():
assert intrp_var in base_inv_dict[yr], "Variable " \
f"'{intrp_var}' not present in base_inv_dict."

# get years that need to be calculated
inv_yrs = list(inv_dict.keys())
base_yrs = list(base_inv_dict.keys())
Expand Down Expand Up @@ -215,7 +228,7 @@ def interpolate_base_inv_dict(inv_dict, base_inv_dict, intrp_vars):
# linear weighting
w = (yr - yrs_lb[i]) / (yrs_ub[i] - yrs_lb[i])
ds_i = regrid_base_inv_dict[yrs_lb[i]] * (1 - w) + \
regrid_base_inv_dict[yrs_ub[i]] * w
regrid_base_inv_dict[yrs_ub[i]] * w

# reset index to match input inventories
ds_i_flat = ds_i.stack(index=["lon", "lat", "plev"])
Expand All @@ -225,6 +238,12 @@ def interpolate_base_inv_dict(inv_dict, base_inv_dict, intrp_vars):
# sort full_base_inv_dict
full_base_inv_dict = dict(sorted(full_base_inv_dict.items()))

# post-conditions
if intrp_yrs:
for yr in intrp_yrs:
assert yr in full_base_inv_dict, "Missing years not included in " \
"output dictionary."

return full_base_inv_dict


Expand Down
2 changes: 1 addition & 1 deletion openairclim/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def run(file_name):

# if necessary, augment base_inv_dict with years in inv_dict not
# present in base_inv_dict
base_inv_dict = oac.interpolate_base_inv_dict(
base_inv_dict = oac.interp_base_inv_dict(
inv_dict, base_inv_dict, ["distance"]
)

Expand Down
124 changes: 124 additions & 0 deletions tests/calc_cont_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,55 @@ def test_plev_vals(self):
"be at altitudes above ground level - defined as 1014 hPa."


class TestCheckContInput:
"""Tests function check_cont_input(ds_cont, inv_dict, base_inv_dict)"""

@pytest.fixture(scope="class")
def inv_dict(self):
"""Fixture to create an example inv_dict."""
return {2020: create_test_inv(year=2020),
2030: create_test_inv(year=2030),
2040: create_test_inv(year=2040),
2050: create_test_inv(year=2050)}

@pytest.fixture(scope="class")
def ds_cont(self):
"""Fixture to load an example ds_cont file."""
return create_test_resp_cont()

def test_year_out_of_range(self, ds_cont, inv_dict):
"""Tests behaviour when inv_dict includes a year that is out of range
of the years in base_inv_dict."""
# test year too low
base_inv_dict = {2030: create_test_inv(year=2030),
2050: create_test_inv(year=2050)}
with pytest.raises(AssertionError):
oac.check_cont_input(ds_cont, inv_dict, base_inv_dict)
# test year too high
base_inv_dict = {2020: create_test_inv(year=2020),
2040: create_test_inv(year=2040)}
with pytest.raises(AssertionError):
oac.check_cont_input(ds_cont, inv_dict, base_inv_dict)

def test_missing_ds_cont_vars(self, ds_cont, inv_dict):
"""Tests ds_cont with missing data variable."""
base_inv_dict = inv_dict
ds_cont_incorrect = ds_cont.drop_vars(["ISS"])
with pytest.raises(AssertionError):
oac.check_cont_input(ds_cont_incorrect, inv_dict, base_inv_dict)

def test_incorrect_ds_cont_coord_unit(self, ds_cont, inv_dict):
"""Tests ds_cont with incorrect coordinates and units."""
base_inv_dict = inv_dict
ds_cont_incorrect1 = ds_cont.copy()
ds_cont_incorrect1.lat.attrs["units"] = "deg"
ds_cont_incorrect2 = ds_cont.copy()
ds_cont_incorrect2 = ds_cont_incorrect2.rename({"lat": "latitude"})
for ds_cont_incorrect in [ds_cont_incorrect1, ds_cont_incorrect2]:
with pytest.raises(AssertionError):
oac.check_cont_input(ds_cont_incorrect, inv_dict, base_inv_dict)


class TestCalcContGridAreas:
"""Tests function calc_cont_grid_areas(lat, lon)"""

Expand All @@ -73,6 +122,81 @@ def test_unsorted_longitudes(self):
"unsuccessful."


class TestInterpBaseInvDict:
"""Tests function interp_base_inv_dict(inv_dict, base_inv_dict,
intrp_vars)"""

@pytest.fixture(scope="class")
def inv_dict(self):
"""Fixture to create an example inv_dict."""
return {2020: create_test_inv(year=2020),
2030: create_test_inv(year=2030),
2040: create_test_inv(year=2040),
2050: create_test_inv(year=2050)}

def test_empty_base_inv_dict(self, inv_dict):
"""Tests an empty base_inv_dict."""
base_inv_dict = {}
intrp_vars = ["distance"]
result = oac.interp_base_inv_dict(inv_dict, base_inv_dict, intrp_vars)
assert not result, "Expected empty output when base_inv_dict is empty."

def test_empty_inv_dict(self):
"""Tests an empty inv_dict."""
base_inv_dict = {2020: create_test_inv(year=2020),
2050: create_test_inv(year=2050)}
intrp_vars = ["distance"]
with pytest.raises(AssertionError):
oac.interp_base_inv_dict({}, base_inv_dict, intrp_vars)

def test_no_missing_years(self, inv_dict):
"""Tests behaviour when all keys in inv_dict are in base_inv_dict."""
base_inv_dict = inv_dict.copy()
intrp_vars = ["distance"]
result = oac.interp_base_inv_dict(inv_dict, base_inv_dict, intrp_vars)
assert result == base_inv_dict, "Expected no change to base_inv_dict."

def test_missing_years(self, inv_dict):
"""Tests behaviour when there is a key in inv_dict that is not in
base_inv_dict."""
base_inv_dict = {2020: create_test_inv(year=2020),
2050: create_test_inv(year=2050)}
intrp_vars = ["distance"]
result = oac.interp_base_inv_dict(inv_dict, base_inv_dict, intrp_vars)
assert 2030 in result, "Missing year 2030 should have been calculated."

# compare the sum of the distances
tot_dist_2020 = base_inv_dict[2020]["distance"].data.sum()
tot_dist_2050 = base_inv_dict[2050]["distance"].data.sum()
exp_tot_dist_2030 = tot_dist_2020 + (tot_dist_2050 - tot_dist_2020) / 3
act_tot_dist_2030 = result[2030]["distance"].data.sum()
np.testing.assert_allclose(act_tot_dist_2030, exp_tot_dist_2030)

def test_incorrect_intrp_vars(self, inv_dict):
"""Tests behaviour when the list of values to be interpolated includes
a value not in inv_dict or base_inv_dict."""
base_inv_dict = {2020: create_test_inv(year=2020),
2050: create_test_inv(year=2050)}
intrp_vars = ["wrong-value"]
with pytest.raises(AssertionError):
oac.interp_base_inv_dict(inv_dict, base_inv_dict, intrp_vars)

def test_year_out_of_range(self, inv_dict):
"""Tests behaviour when inv_dict includes a year that is out of range
of the years in base_inv_dict."""
# test year too low
base_inv_dict = {2030: create_test_inv(year=2030),
2050: create_test_inv(year=2050)}
intrp_vars = ["distance"]
with pytest.raises(AssertionError):
oac.interp_base_inv_dict(inv_dict, base_inv_dict, intrp_vars)
# test year too high
base_inv_dict = {2020: create_test_inv(year=2020),
2040: create_test_inv(year=2040)}
with pytest.raises(AssertionError):
oac.interp_base_inv_dict(inv_dict, base_inv_dict, intrp_vars)


class TestCalcContWeighting:
"""Tests function calc_cont_weighting(config, val)"""

Expand Down

0 comments on commit f8a2b9f

Please sign in to comment.