Skip to content

Commit

Permalink
debug and updates
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed Jan 18, 2024
1 parent 348b859 commit ccbc08d
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 238 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
args: ["--honor-noqa"]
Expand All @@ -34,7 +34,7 @@ repos:
# Python linting
# =======================
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
args: ["--config=setup.cfg"]
Expand Down
3 changes: 2 additions & 1 deletion pcmdi_metrics/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
da_to_ds,
get_axis_list,
get_data_list,
get_grid,
get_latitude_bounds_key,
get_latitude_key,
get_latitude,
Expand All @@ -21,4 +22,4 @@
get_time_key,
select_subset,
)
from .default_regions_define import load_regions_specs, region_subset # noqa
from .regions import load_regions_specs, region_subset # noqa
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import xarray as xr
import xcdat as xc

from pcmdi_metrics.io import da_to_ds
from pcmdi_metrics.io import da_to_ds, get_longitude, select_subset


def load_regions_specs() -> dict:
Expand Down Expand Up @@ -76,63 +76,67 @@ def load_regions_specs() -> dict:


def region_subset(
ds: Union[xr.Dataset, xr.DataArray], region: str, regions_specs: dict = None
ds: Union[xr.Dataset, xr.DataArray],
region: str,
data_var: str = "variable",
regions_specs: dict = None,
) -> Union[xr.Dataset, xr.DataArray]:
"""
ds: xarray.Dataset
regions_specs: dict
region: string
"""_summary_
Parameters
----------
ds : Union[xr.Dataset, xr.DataArray]
_description_
region : str
_description_
data_var : str, optional
_description_, by default None
regions_specs : dict, optional
_description_, by default None
Returns
-------
Union[xr.Dataset, xr.DataArray]
_description_
"""
if isinstance(ds, xr.DataArray):
is_dataArray = True
varname = "variable"
ds = da_to_ds(ds, varname)
ds = da_to_ds(ds, data_var)
else:
is_dataArray = False

if regions_specs is None:
regions_specs = load_regions_specs()

if "domain" in list(regions_specs[region].keys()):
if "latitude" in list(regions_specs[region]["domain"].keys()):
if "domain" in regions_specs[region]:
if "latitude" in regions_specs[region]["domain"]:
lat0 = regions_specs[region]["domain"]["latitude"][0]
lat1 = regions_specs[region]["domain"]["latitude"][1]
# proceed subset
if "latitude" in (ds.coords.dims):
ds = ds.sel(latitude=slice(lat0, lat1))
elif "lat" in (ds.coords.dims):
ds = ds.sel(lat=slice(lat0, lat1))
ds = select_subset(ds, lat=(lat0, lat1))

if "longitude" in list(regions_specs[region]["domain"].keys()):
if "longitude" in regions_specs[region]["domain"]:
lon0 = regions_specs[region]["domain"]["longitude"][0]
lon1 = regions_specs[region]["domain"]["longitude"][1]

# check original dataset longitude range
if "longitude" in (ds.coords.dims):
lon_min = ds.longitude.min()
lon_max = ds.longitude.max()
elif "lon" in (ds.coords.dims):
lon_min = ds.lon.min()
lon_max = ds.lon.max()

# longitude range swap if needed
if (
min(lon0, lon1) < 0
): # when subset region lon is defined in (-180, 180) range
if (
min(lon_min, lon_max) < 0
): # if original data lon range is (-180, 180) no treatment needed
lon_min = get_longitude(ds).min().values.item()
lon_max = get_longitude(ds).max().values.item()

# Check if longitude range swap is needed
if min(lon0, lon1) < 0:
# when subset region lon is defined in (-180, 180) range
if min(lon_min, lon_max) < 0:
# if original data lon range is (-180, 180), no treatment needed
pass
else: # if original data lon range is (0, 360), convert swap lon
else:
# if original data lon range is (0, 360), convert and swap lon
ds = xc.swap_lon_axis(ds, to=(-180, 180))

# proceed subset
if "longitude" in (ds.coords.dims):
ds = ds.sel(longitude=slice(lon0, lon1))
elif "lon" in (ds.coords.dims):
ds = ds.sel(lon=slice(lon0, lon1))
ds = select_subset(ds, lon=(lon0, lon1))

if is_dataArray:
return ds["variable"]
return ds[data_var]
else:
return ds
22 changes: 22 additions & 0 deletions pcmdi_metrics/io/xcdat_dataset_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,25 @@ def da_to_ds(d: Union[xr.Dataset, xr.DataArray], var: str = "variable") -> xr.Da
raise TypeError(
"Input must be an instance of either xarrary.DataArray or xarrary.Dataset"
)


def get_grid(
    ds: xr.Dataset,
) -> xr.Dataset:
    """Return a dataset containing only the horizontal grid definition.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset that includes latitude, longitude, and their bounds
        variables.

    Returns
    -------
    xr.Dataset
        Dataset holding just the latitude/longitude coordinates and their
        associated bounds variables (all other variables dropped).
    """
    # Resolve the dataset-specific names of the coordinate and bounds
    # variables, then select exactly those four variables.
    grid_keys = [
        get_latitude_key(ds),
        get_longitude_key(ds),
        get_latitude_bounds_key(ds),
        get_longitude_bounds_key(ds),
    ]
    return ds[grid_keys]
52 changes: 26 additions & 26 deletions pcmdi_metrics/stats/compute_statistics_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import xcdat as xc


def _check_data_convert_to_ds_if_needed(
def da_to_ds(
d: Union[xr.Dataset, xr.DataArray], var: str = "variable"
):
if isinstance(d, xr.Dataset):
Expand All @@ -29,8 +29,8 @@ def annual_mean(dm, do, var="variable"):
"Comments": "Assumes input are 12 months climatology",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

dm_am = dm.temporal.average(var)
do_am = do.temporal.average(var)
Expand Down Expand Up @@ -84,8 +84,8 @@ def bias_xy(dm, do, var="variable", weights=None):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

dif = dm[var] - do[var]
if weights is None:
Expand All @@ -104,8 +104,8 @@ def bias_xyt(dm, do, var="variable"):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

ds = dm.copy(deep=True)
ds["dif"] = dm[var] - do[var]
Expand All @@ -124,8 +124,8 @@ def cor_xy(dm, do, var="variable", weights=None):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

if weights is None:
weights = dm.spatial.get_weights(axis=["X", "Y"])
Expand Down Expand Up @@ -155,7 +155,7 @@ def mean_xy(d, var="variable", weights=None):
"Contact": "pcmdi-metrics@llnl.gov",
}

d = _check_data_convert_to_ds_if_needed(d, var)
d = da_to_ds(d, var)

lat_key = xc.axis.get_dim_keys(d, axis="Y")
lon_key = xc.axis.get_dim_keys(d, axis="X")
Expand All @@ -176,8 +176,8 @@ def meanabs_xy(dm, do, var="variable", weights=None):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

if weights is None:
weights = dm.spatial.get_weights(axis=["X", "Y"])
Expand All @@ -197,8 +197,8 @@ def meanabs_xyt(dm, do, var="variable"):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

ds = dm.copy(deep=True)
ds["absdif"] = abs(dm[var] - do[var])
Expand All @@ -219,8 +219,8 @@ def rms_0(dm, do, var="variable", weighted=True):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

dif_square = (dm[var] - do[var]) ** 2
if weighted:
Expand All @@ -240,8 +240,8 @@ def rms_xy(dm, do, var="variable", weights=None):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

dif_square = (dm[var] - do[var]) ** 2
if weights is None:
Expand All @@ -259,8 +259,8 @@ def rms_xyt(dm, do, var="variable"):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

ds = dm.copy(deep=True)
ds["diff_square"] = (dm[var] - do[var]) ** 2
Expand All @@ -280,8 +280,8 @@ def rmsc_xy(dm, do, var="variable", weights=None, NormalizeByOwnSTDV=False):
"Contact": "pcmdi-metrics@llnl.gov",
}

dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

if weights is None:
weights = dm.spatial.get_weights(axis=["X", "Y"])
Expand Down Expand Up @@ -310,7 +310,7 @@ def std_xy(ds, var="variable", weights=None):
"Contact": "pcmdi-metrics@llnl.gov",
}

ds = _check_data_convert_to_ds_if_needed(ds, var)
ds = da_to_ds(ds, var)

if weights is None:
weights = ds.spatial.get_weights(axis=["X", "Y"])
Expand All @@ -334,7 +334,7 @@ def std_xyt(d, var="variable"):
"Contact": "pcmdi-metrics@llnl.gov",
}
ds = d.copy(deep=True)
ds = _check_data_convert_to_ds_if_needed(ds, var)
ds = da_to_ds(ds, var)
average = d.spatial.average(var, axis=["X", "Y"]).temporal.average(var)[var]
ds["anomaly"] = (d[var] - average) ** 2
variance = (
Expand All @@ -353,8 +353,8 @@ def zonal_mean(dm, do, var="variable"):
"Contact": "pcmdi-metrics@llnl.gov",
"Comments": "",
}
dm = _check_data_convert_to_ds_if_needed(dm, var)
do = _check_data_convert_to_ds_if_needed(do, var)
dm = da_to_ds(dm, var)
do = da_to_ds(do, var)

dm_zm = dm.spatial.average(var, axis=["X"])
do_zm = do.spatial.average(var, axis=["X"])
Expand Down
1 change: 1 addition & 0 deletions pcmdi_metrics/variability_mode/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from .landmask import data_land_mask_out, estimate_landmask # noqa
from .lib_variability_mode import ( # noqa
check_start_end_year,
debug_print,
get_domain_range,
read_data_in,
Expand Down
12 changes: 7 additions & 5 deletions pcmdi_metrics/variability_mode/lib/calc_stat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from time import gmtime, strftime

from pcmdi_metrics.io import region_subset
from pcmdi_metrics.io import get_grid, region_subset
from pcmdi_metrics.stats import bias_xy as calcBias
from pcmdi_metrics.stats import cor_xy as calcSCOR
from pcmdi_metrics.stats import mean_xy
from pcmdi_metrics.stats import rms_xy as calcRMS
from pcmdi_metrics.stats import rmsc_xy as calcRMSc
from pcmdi_metrics.utils import regrid


def calc_stats_save_dict(
Expand Down Expand Up @@ -60,12 +61,13 @@ def calc_stats_save_dict(
# . . . . . . . . . . . . . . . . . . . . . . . . .
if obs_compare:
if method in ["eof", "cbf"]:
ref_grid_global = eof_lr_obs.getGrid()
ref_grid_global = get_grid(eof_lr_obs)
# Regrid (interpolation, model grid to ref grid)
debug_print("regrid (global) start", debug)
eof_model_global = eof_lr.regrid(
ref_grid_global, regridTool="regrid2", mkCyclic=True
)
# eof_model_global = eof_lr.regrid(eof_lr,
# ref_grid_global, regridTool="regrid2", mkCyclic=True
# )
eof_model_global = regrid(eof_lr, ref_grid_global)
debug_print("regrid end", debug)
# Extract subdomain
# eof_model = eof_model_global(region_subdomain)
Expand Down
Loading

0 comments on commit ccbc08d

Please sign in to comment.