Skip to content

Commit

Permalink
Merge pull request #199 from openghg/Iss198-weighted-EASTASIA
Browse files Browse the repository at this point in the history
Iss198 weighted eastasia
  • Loading branch information
qq23840 authored Sep 2, 2024
2 parents a484a52 + 9341381 commit 84a05ec
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 24 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# OpenGHG Inversions Change Log


- Added option to use 'weighted' algorithm to derive basis functions for EASTASIA domain [#PR 199](https://github.com/openghg/openghg_inversions/pull/199)

# Version 0.2.0

- Added option to pass "mean" and "stdev" as parameters for lognormal BC prior [#PR 190](https://github.com/openghg/openghg_inversions/pull/190)
Expand Down
2 changes: 1 addition & 1 deletion openghg_inversions/basis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ._functions import bucketbasisfunction, quadtreebasisfunction
from ._functions import bucketbasisfunction, quadtreebasisfunction, fixed_outer_regions_basis
from ._wrapper import basis_functions_wrapper
17 changes: 14 additions & 3 deletions openghg_inversions/basis/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def _mean_fp_times_mean_flux(
def quadtreebasisfunction(
fp_all: dict,
start_date: str,
domain : str,
emissions_name: list[str] | None = None,
nbasis: int = 100,
abs_flux: bool = False,
Expand All @@ -230,6 +231,8 @@ def quadtreebasisfunction(
fp_all dictionary of datasets as produced from get_data functions
start_date (str):
Start date of period of inversion
domain (str):
Domain across which to calculate basis functions.
emissions_name (list):
List of keyword "source" args used for retrieving emissions files
from the Object store
Expand Down Expand Up @@ -263,17 +266,19 @@ def quadtreebasisfunction(

quad_basis.attrs["creator"] = getpass.getuser()
quad_basis.attrs["date created"] = str(pd.Timestamp.today())
quad_basis.attrs["domain"] = domain

return quad_basis


def bucketbasisfunction(
fp_all: dict,
start_date: str,
domain : str,
emissions_name: list[str] | None = None,
nbasis: int = 100,
abs_flux: bool = False,
mask: xr.DataArray | None = None,
mask: xr.DataArray | None = None
) -> xr.DataArray:
"""Basis functions calculated using a weighted region approach
where each basis function / scaling region contains approximately
Expand All @@ -284,6 +289,8 @@ def bucketbasisfunction(
fp_all dictionary of datasets as produced from get_data functions
start_date (str):
Start date of period of inversion
domain (str):
domain for the basis functions to be calculated over
emissions_name (list):
List of keyword "source" args used for retrieving emissions files
from the Object store
Expand All @@ -306,14 +313,15 @@ def bucketbasisfunction(
fps = _mean_fp_times_mean_flux(flux, footprints, abs_flux=abs_flux, mask=mask).as_numpy()

# use xr.apply_ufunc to keep xarray coords
func = partial(weighted_algorithm, nregion=nbasis, bucket=1)
func = partial(weighted_algorithm, nregion=nbasis, bucket=1, domain=domain)
bucket_basis = xr.apply_ufunc(func, fps)

bucket_basis = bucket_basis.expand_dims({"time": [pd.to_datetime(start_date)]}, axis=-1)
bucket_basis = bucket_basis.rename("basis") # this will be used in merges

bucket_basis.attrs["creator"] = getpass.getuser()
bucket_basis.attrs["date created"] = str(pd.Timestamp.today())
bucket_basis.attrs["domain"] = domain

return bucket_basis

Expand All @@ -330,6 +338,7 @@ def fixed_outer_regions_basis(
fp_all: dict,
start_date: str,
basis_algorithm: str,
domain: str,
emissions_name: list[str] | None = None,
nbasis: int = 100,
abs_flux: bool = False,
Expand All @@ -343,6 +352,8 @@ def fixed_outer_regions_basis(
Start date of period of inference
basis_algorithm (str):
Name of the basis algorithm used. Options are "quadtree", "weighted"
domain (str):
domain for the basis functions to be calculated over
emissions_name (list):
List of keyword "source" args used for retrieving emissions files
from the Object store.
Expand All @@ -367,7 +378,7 @@ def fixed_outer_regions_basis(
mask = intem_regions == 6

basis_function = basis_functions[basis_algorithm].algorithm
inner_region = basis_function(fp_all, start_date, emissions_name, nbasis, abs_flux, mask=mask)
inner_region = basis_function(fp_all, start_date, domain, emissions_name, nbasis, abs_flux, mask=mask)

basis = intem_regions.rename("basis")

Expand Down
4 changes: 2 additions & 2 deletions openghg_inversions/basis/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def basis_functions_wrapper(
elif fix_outer_regions is True:
try:
basis_data_array = fixed_outer_regions_basis(
fp_all, start_date, basis_algorithm, emissions_name, nbasis
fp_all, start_date, basis_algorithm, domain, emissions_name, nbasis
)
except KeyError as e:
raise ValueError(
Expand All @@ -113,7 +113,7 @@ def basis_functions_wrapper(
"Basis algorithm not recognised. Please use either 'quadtree' or 'weighted', or input a basis function file"
) from e
print(f"Using {basis_function.description} to derive basis functions.")
basis_data_array = basis_function.algorithm(fp_all, start_date, emissions_name, nbasis)
basis_data_array = basis_function.algorithm(fp_all, start_date, domain, emissions_name, nbasis)

fp_data = fp_sensitivity(fp_all, basis_func=basis_data_array)

Expand Down
58 changes: 41 additions & 17 deletions openghg_inversions/basis/algorithms/_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,30 @@
from pathlib import Path
import numpy as np
import xarray as xr
import logging

logger = logging.getLogger(__name__)


# BUCKET BASIS FUNCTIONS
def load_landsea_indices(domain: str) -> np.ndarray:
    """Load the array of indices separating land and sea for a domain.

    Args:
        domain (str): domain for which to load land-sea indices. Files are
            selected for "EASTASIA" and "EUROPE"; any other value logs a
            warning and falls back to the EUROPE file.

    Returns:
        Values of the "country" variable of the selected file: an array
        containing 0 (where there is sea) and 1 (where there is land).
    """
    europe_file = "country-EUROPE-UKMO-landsea-2023.nc"
    landsea_files = {
        "EASTASIA": "country-land-sea_EASTASIA.nc",
        "EUROPE": europe_file,
    }

    filename = landsea_files.get(domain)
    if filename is None:
        # Unknown domain: keep the historical behaviour of using the EUROPE
        # file, but make the fallback visible in the logs.
        logger.warning(
            "No land-sea file found for domain %s. Defaulting to EUROPE (%s)",
            domain,
            europe_file,
        )
        filename = europe_file

    # The context manager closes the underlying file handle; `.values` pulls
    # the data into memory before the dataset is closed.
    with xr.open_dataset(Path(__file__).parent / filename) as landsea_indices:
        return landsea_indices["country"].values


Expand Down Expand Up @@ -65,7 +77,7 @@ def bucket_value_split(
)


def get_nregions(bucket: float, grid: np.ndarray) -> int:
def get_nregions(bucket: float, grid: np.ndarray, domain: str) -> int:
"""Optimize bucket value to number of desired regions.
Args:
Expand All @@ -75,14 +87,17 @@ def get_nregions(bucket: float, grid: np.ndarray) -> int:
2D grid of footprints * flux, or whatever
grid you want to split. Could be: population
data, spatial distribution of bakeries, you choose!
domain:
Domain across which to calculate basis functions.
Currently limited to "EUROPE" or "EASTASIA"
Return :
number of basis functions for bucket value
"""
return np.max(bucket_split_landsea_basis(grid, bucket))
return np.max(bucket_split_landsea_basis(grid, bucket, domain))


def optimize_nregions(bucket: float, grid: np.ndarray, nregion: int, tol: int) -> float:
def optimize_nregions(bucket: float, grid: np.ndarray, nregion: int, tol: int, domain: str) -> float:
"""Optimize bucket value to obtain nregion basis functions
within +/- tol.
Expand All @@ -98,24 +113,27 @@ def optimize_nregions(bucket: float, grid: np.ndarray, nregion: int, tol: int) -
tol:
Tolerance to find number of basis function regions.
i.e. optimizes nregions to +/- tol
domain:
Domain across which to calculate basis functions.
Currently limited to "EUROPE" or "EASTASIA"
Return :
Optimized bucket value
"""
# print(bucket, get_nregions(bucket, grid))
if get_nregions(bucket, grid) <= nregion + tol and get_nregions(bucket, grid) >= nregion - tol:
if get_nregions(bucket, grid, domain) <= nregion + tol and get_nregions(bucket, grid, domain) >= nregion - tol:
return bucket

if get_nregions(bucket, grid) < nregion + tol:
if get_nregions(bucket, grid, domain) < nregion + tol:
bucket *= 0.995
return optimize_nregions(bucket, grid, nregion, tol)
return optimize_nregions(bucket, grid, nregion, tol, domain)

elif get_nregions(bucket, grid) > nregion - tol:
elif get_nregions(bucket, grid, domain) > nregion - tol:
bucket *= 1.005
return optimize_nregions(bucket, grid, nregion, tol)
return optimize_nregions(bucket, grid, nregion, tol, domain)


def bucket_split_landsea_basis(grid: np.ndarray, bucket: float) -> np.ndarray:
def bucket_split_landsea_basis(grid: np.ndarray, bucket: float, domain : str) -> np.ndarray:
"""Same as bucket_split_basis but includes
land-sea split. i.e. basis functions cannot overlap sea and land.
Expand All @@ -126,12 +144,15 @@ def bucket_split_landsea_basis(grid: np.ndarray, bucket: float) -> np.ndarray:
data, spatial distribution of bakeries, you choose!
bucket:
Maximum value for each basis function region
domain:
Domain across which to calculate basis functions.
Currently limited to "EUROPE" or "EASTASIA"
Returns:
2D array with basis function values
"""
landsea_indices = load_landsea_indices()
landsea_indices = load_landsea_indices(domain)
myregions = bucket_value_split(grid, bucket)

mybasis_function = np.zeros(shape=grid.shape)
Expand Down Expand Up @@ -159,7 +180,7 @@ def bucket_split_landsea_basis(grid: np.ndarray, bucket: float) -> np.ndarray:


def nregion_landsea_basis(
grid: np.ndarray, bucket: float = 1, nregion: int = 100, tol: int = 1
grid: np.ndarray, bucket: float = 1, nregion: int = 100, tol: int = 1, domain: str = 'EUROPE'
) -> np.ndarray:
"""Obtain basis function with nregions (for land-sea split).
Expand All @@ -178,10 +199,13 @@ def nregion_landsea_basis(
Tolerance to find number of basis function regions.
i.e. optimizes nregions to +/- tol
Defaults to 1
domain:
Domain across which to calculate basis functions.
Currently limited to "EUROPE" or "EASTASIA"
Returns:
basis_function: 2D basis function array
"""
bucket_opt = optimize_nregions(bucket, grid, nregion, tol)
basis_function = bucket_split_landsea_basis(grid, bucket_opt)
bucket_opt = optimize_nregions(bucket, grid, nregion, tol, domain)
basis_function = bucket_split_landsea_basis(grid, bucket_opt, domain)
return basis_function
Binary file not shown.
Binary file not shown.
30 changes: 29 additions & 1 deletion tests/test_basis_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pandas as pd
import xarray as xr
from openghg_inversions.basis._functions import basis, _flux_fp_from_fp_all, _mean_fp_times_mean_flux
from openghg_inversions.basis import bucketbasisfunction, quadtreebasisfunction
from openghg_inversions.basis import bucketbasisfunction, quadtreebasisfunction, fixed_outer_regions_basis
from openghg_inversions.get_data import data_processing_surface_notracer


Expand Down Expand Up @@ -46,6 +46,7 @@ def test_quadtree_basis_function(tac_ch4_data_args, raw_data_path):
fp_all=fp_all,
start_date="2019-01-01",
seed=42,
domain="EUROPE"
)

basis_func_reloaded = basis(
Expand All @@ -70,6 +71,7 @@ def test_bucket_basis_function(tac_ch4_data_args, raw_data_path):
emissions_name=[emissions_name],
fp_all=fp_all,
start_date="2019-01-01",
domain="EUROPE"
)


Expand All @@ -80,3 +82,29 @@ def test_bucket_basis_function(tac_ch4_data_args, raw_data_path):
# TODO: create new "fixed" basis function file, since we've switched basis functions from
# dataset to data array
xr.testing.assert_allclose(basis_func, basis_func_reloaded.basis)

def test_fixed_outer_region_basis_function(tac_ch4_data_args, raw_data_path):
    """Regression check for the fixed outer region basis.

    Builds a basis with the 'weighted' algorithm on the EUROPE domain from
    the TAC CH4 args and compares it with a basis created with the same
    arguments and saved to file. This guards against changes in the code
    from when this test was made (2 Sep 2024).
    """
    fp_all, *_ = data_processing_surface_notracer(**tac_ch4_data_args)

    # Use the first flux source present in the processed data.
    first_source = next(iter(fp_all[".flux"]))

    computed = fixed_outer_regions_basis(
        fp_all=fp_all,
        start_date="2019-01-01",
        basis_algorithm="weighted",
        domain="EUROPE",
        emissions_name=[first_source],
    )

    stored = basis(
        domain="EUROPE",
        basis_case="fixed_outer_region_ch4-test_basis",
        basis_directory=raw_data_path / "basis",
    )

    # TODO: create new "fixed" basis function file, since we've switched basis functions from
    # dataset to data array
    xr.testing.assert_allclose(computed, stored.basis)

0 comments on commit 84a05ec

Please sign in to comment.