diff --git a/.gitignore b/.gitignore
index f33a818cf..92c6c9253 100644
--- a/.gitignore
+++ b/.gitignore
@@ -90,4 +90,9 @@ ENV/
 # Rope project settings
 .ropeproject

+# debug_data
+pcmdi_metrics/monsoon_wang/*.nc
+pcmdi_metrics/monsoon_wang/debug_regions_specs.py
+pcmdi_metrics/monsoon_wang/test_param.py
+
 test/
diff --git a/pcmdi_metrics/io/__init__.py b/pcmdi_metrics/io/__init__.py
index a81b7b8b6..397b65316 100644
--- a/pcmdi_metrics/io/__init__.py
+++ b/pcmdi_metrics/io/__init__.py
@@ -20,4 +20,5 @@
     get_time_bounds_key,
     get_time_key,
     select_subset,
+    da_to_ds,
 )
diff --git a/pcmdi_metrics/io/default_regions_define.py b/pcmdi_metrics/io/default_regions_define.py
index 11a238994..b360bf090 100755
--- a/pcmdi_metrics/io/default_regions_define.py
+++ b/pcmdi_metrics/io/default_regions_define.py
@@ -1,5 +1,11 @@
+from typing import Union
+
+import xarray as xr
 import xcdat as xc

+# from pcmdi_metrics.io import da_to_ds, get_longitude, select_subset
+from .xcdat_dataset_io import da_to_ds, get_longitude, select_subset
+

 def load_regions_specs():
     regions_specs = {
@@ -45,7 +51,8 @@ def load_regions_specs():
         # South American Monsoon
         "SAMM": {"domain": {"latitude": (-45.0, 0.0), "longitude": (240.0, 330.0)}},
         # North African Monsoon
-        "NAFM": {"domain": {"latitude": (0.0, 45.0), "longitude": (310.0, 60.0)}},
+        # "NAFM": {"domain": {"latitude": (0.0, 45.0), "longitude": (310.0, 60.0)}},
+        "NAFM": {"domain": {"latitude": (0.0, 45.0), "longitude": (-50, 60.0)}},
         # South African Monsoon
         "SAFM": {"domain": {"latitude": (-45.0, 0.0), "longitude": (0.0, 90.0)}},
         # Asian Summer Monsoon
@@ -70,55 +77,81 @@ def load_regions_specs():
     return regions_specs


-def region_subset(ds, regions_specs, region=None):
-    """
-    d: xarray.Dataset
-    regions_specs: dict
-    region: string
-    """
+def region_subset(
+    ds: Union[xr.Dataset, xr.DataArray],
+    region: str,
+    data_var: str = "variable",
+    regions_specs: dict = None,
+    debug: bool = False,
+) -> Union[xr.Dataset, xr.DataArray]:
+    """Subset an xarray Dataset or DataArray over a named region.
+
+    Parameters
+    ----------
+    ds : Union[xr.Dataset, xr.DataArray]
+        Input dataset or dataArray to subset
+    region : str
+        Name of the region defined in regions_specs (e.g., "NAFM")
+    data_var : str, optional
+        Variable name used when a dataArray is converted internally, by default "variable"
+    regions_specs : dict, optional
+        Region definitions; loaded via load_regions_specs() when None, by default None
+    debug: bool, optional
+        Turn on debug print, by default False

-    if (region is None) or (
-        (region is not None) and (region not in list(regions_specs.keys()))
-    ):
-        print("Error: region not defined")
+    Returns
+    -------
+    Union[xr.Dataset, xr.DataArray]
+        Subsetted data, returned as the same type as the input
+    """
+    if isinstance(ds, xr.DataArray):
+        is_dataArray = True
+        ds = da_to_ds(ds, data_var)
     else:
-        if "domain" in list(regions_specs[region].keys()):
-            if "latitude" in list(regions_specs[region]["domain"].keys()):
-                lat0 = regions_specs[region]["domain"]["latitude"][0]
-                lat1 = regions_specs[region]["domain"]["latitude"][1]
-                # proceed subset
-                if "latitude" in (ds.coords.dims):
-                    ds = ds.sel(latitude=slice(lat0, lat1))
-                elif "lat" in (ds.coords.dims):
-                    ds = ds.sel(lat=slice(lat0, lat1))
-
-            if "longitude" in list(regions_specs[region]["domain"].keys()):
-                lon0 = regions_specs[region]["domain"]["longitude"][0]
-                lon1 = regions_specs[region]["domain"]["longitude"][1]
-
-                # check original dataset longitude range
-                if "longitude" in (ds.coords.dims):
-                    lon_min = ds.longitude.min()
-                    lon_max = ds.longitude.max()
-                elif "lon" in (ds.coords.dims):
-                    lon_min = ds.lon.min()
-                    lon_max = ds.lon.max()
-
-                # longitude range swap if needed
-                if (
-                    min(lon0, lon1) < 0
-                ):  # when subset region lon is defined in (-180, 180) range
-                    if (
-                        min(lon_min, lon_max) < 0
-                    ):  # if original data lon range is (-180, 180) no treatment needed
+        is_dataArray = False
+
+    if regions_specs is None:
+        regions_specs = load_regions_specs()
+
+    if "domain" in regions_specs[region]:
+        if "latitude" in regions_specs[region]["domain"]:
+            lat0 = regions_specs[region]["domain"]["latitude"][0]
+            lat1 = regions_specs[region]["domain"]["latitude"][1]
+            # proceed subset
+            ds = select_subset(ds, lat=(min(lat0, lat1), max(lat0, lat1)))
+            if debug:
+                print("region_subset, latitude subsetted, ds:", ds)
+
+        if "longitude" in regions_specs[region]["domain"]:
+            lon0 = regions_specs[region]["domain"]["longitude"][0]
+            lon1 = regions_specs[region]["domain"]["longitude"][1]
+
+            # check original dataset longitude range
+            lon_min = get_longitude(ds).min().values.item()
+            lon_max = get_longitude(ds).max().values.item()
+
+            # Check if longitude range swap is needed
+            if min(lon0, lon1) < 0:
+                # when subset region lon is defined in (-180, 180) range
+                if min(lon_min, lon_max) < 0:
+                    # if original data lon range is (-180, 180), no treatment needed
+                    pass
+                else:
+                    # if original data lon range is (0, 360), convert and swap lon
+                    try:
+                        ds = ds.drop_vars(["time"])
+                    except Exception:
                         pass
-                else:  # if original data lon range is (0, 360), convert swap lon
-                    ds = xc.swap_lon_axis(ds, to=(-180, 180))
+                    ds = xc.swap_lon_axis(ds, to=(-180, 180))

-                # proceed subset
-                if "longitude" in (ds.coords.dims):
-                    ds = ds.sel(longitude=slice(lon0, lon1))
-                elif "lon" in (ds.coords.dims):
-                    ds = ds.sel(lon=slice(lon0, lon1))
+            # proceed subset
+            # ds = select_subset(ds, lon=(min(lon0, lon1), max(lon0, lon1)))
+            ds = select_subset(ds, lon=(lon0, lon1))
+            if debug:
+                print("region_subset, longitude subsetted, ds:", ds)

-    return ds
+    # return the same type
+    if is_dataArray:
+        return ds[data_var]
+    else:
+        return ds
diff --git a/pcmdi_metrics/io/xcdat_dataset_io.py b/pcmdi_metrics/io/xcdat_dataset_io.py
index 72af41e4e..2cdc4a194 100644
--- a/pcmdi_metrics/io/xcdat_dataset_io.py
+++ b/pcmdi_metrics/io/xcdat_dataset_io.py
@@ -6,6 +6,36 @@
 # Internal function


+def da_to_ds(d: Union[xr.Dataset, xr.DataArray], var: str = "variable") -> xr.Dataset:
+    """Convert xarray DataArray to Dataset
+
+    Parameters
+    ----------
+    d : Union[xr.Dataset, xr.DataArray]
+        Input dataArray. If a dataset is given, no processing is done
+    var : str, optional
+        Name of dataArray, by default "variable"
+
+    Returns
+    -------
+    xr.Dataset
+        xarray Dataset
+
+    Raises
+    ------
+    TypeError
+        Raised when the given input is not an xarray object
+    """
+    if isinstance(d, xr.Dataset):
+        return d.copy()
+    elif isinstance(d, xr.DataArray):
+        return d.to_dataset(name=var).bounds.add_missing_bounds().copy()
+    else:
+        raise TypeError(
+            "Input must be an instance of either xarray.DataArray or xarray.Dataset"
+        )
+
+
 def _find_key(
     ds: Union[xr.Dataset, xr.DataArray], axis: str, potential_names: list
 ) -> str:
diff --git a/pcmdi_metrics/monsoon_sperber/driver_monsoon_sperber.py b/pcmdi_metrics/monsoon_sperber/driver_monsoon_sperber.py
deleted file mode 100644
index 17b3d6655..000000000
--- a/pcmdi_metrics/monsoon_sperber/driver_monsoon_sperber.py
+++ /dev/null
@@ -1,904 +0,0 @@
-#!/usr/bin/env python
-"""
-Calculate monsoon metrics
-
-Bo Dong (dong12@llnl.gov) and Jiwoo Lee (lee1043@llnl.gov)
-
-Reference:
-Sperber, K. and H. Annamalai, 2014:
-The use of fractional accumulated precipitation for the evaluation of the
-annual cycle of monsoons. Climate Dynamics, 43:3219-3244,
-doi: 10.1007/s00382-014-2099-3
-
-Auspices:
-This work was performed under the auspices of the U.S. Department of
-Energy by Lawrence Livermore National Laboratory under Contract
-DE-AC52-07NA27344. Lawrence Livermore National Laboratory is operated by
-Lawrence Livermore National Security, LLC, for the U.S. Department of Energy,
-National Nuclear Security Administration under Contract DE-AC52-07NA27344.
-
-Disclaimer:
-This document was prepared as an account of work sponsored by an
-agency of the United States government. Neither the United States government
-nor Lawrence Livermore National Security, LLC, nor any of their employees
-makes any warranty, expressed or implied, or assumes any legal liability or
-responsibility for the accuracy, completeness, or usefulness of any
-information, apparatus, product, or process disclosed, or represents that its
-use would not infringe privately owned rights. Reference herein to any specific
-commercial product, process, or service by trade name, trademark, manufacturer,
-or otherwise does not necessarily constitute or imply its endorsement,
-recommendation, or favoring by the United States government or Lawrence
-Livermore National Security, LLC. The views and opinions of authors expressed
-herein do not necessarily state or reflect those of the United States
-government or Lawrence Livermore National Security, LLC, and shall not be used
-for advertising or product endorsement purposes.
-""" - -import copy -import json -import math -import os -import re -import sys -from argparse import RawTextHelpFormatter -from collections import defaultdict -from glob import glob -from shutil import copyfile - -import matplotlib -matplotlib.use('Agg') -#import matplotlib.pyplot as plt -from matplotlib import pyplot as plt -import numpy as np -import pandas as pd -import xarray as xr -import xcdat as xc - -import pcmdi_metrics -from pcmdi_metrics import resources -from pcmdi_metrics.io import load_regions_specs, region_subset -from pcmdi_metrics.mean_climate.lib import pmp_parser -from pcmdi_metrics.monsoon_sperber.lib import ( - AddParserArgument, - YearCheck, - divide_chunks_advanced, - interp1d, - model_land_only, - sperber_metrics, -) -from pcmdi_metrics.utils import create_land_sea_mask, fill_template -from pcmdi_metrics.io import xcdat_open - - -def tree(): - return defaultdict(tree) - - -def pick_year_last_day(ds): - eday = 31 - try: - time_key = xc.axis.get_dim_keys(ds, axis="T") - if "calendar" in ds[time_key].attrs.keys(): - if "360" in ds[time_key]["calendar"]: - eday = 30 - else: - if "360" in ds[time_key][0].values.item().calendar: - eday = 30 - except Exception: - pass - return eday - - -# ================================================= -# Hard coded options... will be moved out later -# ------------------------------------------------- -list_monsoon_regions = ["AIR", "AUS", "Sahel", "GoG", "NAmo", "SAmo"] -# list_monsoon_regions = ["all"] - - -# How many elements each list should have -n = 5 # pentad - -# ================================================= -# Collect user defined options -# ------------------------------------------------- -P = pmp_parser.PMPParser( - description="Runs PCMDI Monsoon Sperber Computations", - formatter_class=RawTextHelpFormatter, -) -P = AddParserArgument(P) -P.add_argument( - "--cmec", - dest="cmec", - default=False, - action="store_true", - help="Use to save CMEC format metrics JSON", -) -P.add_argument( - "--no_cmec", - dest="cmec", - default=False, - action="store_false", - help="Do not save CMEC format metrics JSON", -) -P.set_defaults(cmec=False) -param = P.get_parameter() - -# Pre-defined options -mip = param.mip -exp = param.exp -fq = param.frequency -realm = param.realm - -# On/off switches -nc_out = param.nc_out # Record NetCDF output -plot = param.plot # Generate plots -includeOBS = param.includeOBS # Loop run for OBS or not -cmec = param.cmec # CMEC formatted JSON - -# Path to reference data -reference_data_name = param.reference_data_name -reference_data_path = param.reference_data_path -reference_data_lf_path = param.reference_data_lf_path - -# Path to model data as string template -modpath = param.process_templated_argument("modpath") -modpath_lf = param.process_templated_argument("modpath_lf") -print("modpath = ", modpath) -print("modpath_lf = ", modpath_lf) - -# Check given model option -models = param.modnames -print("models:", models) - -# list of regions -# list_monsoon_regions = param.list_monsoon_regions - -# Include all models if conditioned -if ("all" in [m.lower() for m in models]) or (models == "all"): - model_index_path = re.split(". |_", modpath.split("/")[-1]).index("%(model)") - models = [ - re.split(". 
|_", p.split("/")[-1])[model_index_path] - for p in glob.glob( - fill_template( - modpath, mip=mip, exp=exp, model="*", realization="*", variable="pr" - ) - ) - ] - # remove duplicates - models = sorted(list(dict.fromkeys(models)), key=lambda s: s.lower()) - -print("number of models:", len(models)) - -# Realizations -realization = param.realization -print("realization: ", realization) - -# Output -outdir = param.process_templated_argument("results_dir") - -# Create output directory -for output_type in ["graphics", "diagnostic_results", "metrics_results"]: - if not os.path.exists(outdir(output_type=output_type)): - os.makedirs(outdir(output_type=output_type)) - print(outdir(output_type=output_type)) - -# Debug -debug = param.debug -print("debug: ", debug) - -# Variables -varModel = param.varModel -varOBS = param.varOBS - -# Year -# model -msyear = param.msyear -meyear = param.meyear -YearCheck(msyear, meyear, P) -# obs -osyear = param.osyear -oeyear = param.oeyear -YearCheck(osyear, oeyear, P) - -# Units -units = param.units -# model -ModUnitsAdjust = param.ModUnitsAdjust -# obs -ObsUnitsAdjust = param.ObsUnitsAdjust - -# JSON update -update_json = param.update_json - -# ================================================= -# Declare dictionary for .json record -# ------------------------------------------------- -monsoon_stat_dic = tree() - -# Define output json file -json_filename = "_".join( - ["monsoon_sperber_stat", mip, exp, fq, realm, str(msyear) + "-" + str(meyear)] -) -json_file = os.path.join(outdir(output_type="metrics_results"), json_filename + ".json") -json_file_org = os.path.join( - outdir(output_type="metrics_results"), - "_".join([json_filename, "org", str(os.getpid())]) + ".json", -) - -# Save pre-existing json file against overwriting -if os.path.isfile(json_file) and os.stat(json_file).st_size > 0: - copyfile(json_file, json_file_org) - if update_json: - fj = open(json_file) - monsoon_stat_dic = json.loads(fj.read()) - fj.close() - -if "REF" not in list(monsoon_stat_dic.keys()): - monsoon_stat_dic["REF"] = {} -if "RESULTS" not in list(monsoon_stat_dic.keys()): - monsoon_stat_dic["RESULTS"] = {} - -# ================================================= -# Loop start for given models -# ------------------------------------------------- -regions_specs = {} -egg_pth = resources.resource_path() -exec( - compile( - open(os.path.join(egg_pth, "default_regions.py")).read(), - os.path.join(egg_pth, "default_regions.py"), - "exec", - ) -) - -# ================================================= -# Loop start for given models -# ------------------------------------------------- -if includeOBS: - models.insert(0, "obs") - -for model in models: - print( - "========== model = " - + model - + " ===============================================================================" - ) - print("\n") - try: - # Conditions depending obs or model - if model == "obs": - var = varOBS - UnitsAdjust = ObsUnitsAdjust - syear = osyear - eyear = oeyear - # variable data - model_path_list = [reference_data_path] - # land fraction - model_lf_path = reference_data_lf_path - # dict for output JSON - if reference_data_name not in list(monsoon_stat_dic["REF"].keys()): - monsoon_stat_dic["REF"][reference_data_name] = {} - # dict for plottng - dict_obs_composite = {} - dict_obs_composite[reference_data_name] = {} - else: # for rest of models - var = varModel - UnitsAdjust = ModUnitsAdjust - syear = msyear - eyear = meyear - # variable data - model_path_list = glob( - modpath(model=model, exp=exp, 
realization=realization, variable=var) - ) - if debug: - print("model: ", model, " exp: ", exp, " realization: ", realization, " variable: ", var) - print("debug: model_path_list: ", model_path_list) - # land fraction - model_lf_path = modpath_lf(model=model) - print("model_lf_path = ", model_lf_path) - if os.path.isfile(model_lf_path): - pass - else: - model_lf_path = modpath_lf(model=model.upper()) - # dict for output JSON - if model not in list(monsoon_stat_dic["RESULTS"].keys()): - monsoon_stat_dic["RESULTS"][model] = {} - - dict_obs_composite = {} - dict_obs_composite[reference_data_name] = {} - # Read land fraction - - if model_lf_path is not None: - if os.path.isfile(model_lf_path): - try: - ds_lf = xcdat_open(model_lf_path) - except Exception: - ds_lf = None - - if not ds_lf: - lf_array = create_land_sea_mask(ds_lf, method="pcmdi") - ds_lf = lf_array.to_dataset().compute() - ds_lf = ds_lf.rename_vars({"lsmask": "sftlf"}) - - # use pcmdi mask - # lf_array = create_land_sea_mask(ds_lf, method="pcmdi") - # ds_lf = lf_array.to_dataset().compute() - # ds_lf = ds_lf.rename_vars({"lsmask": "sftlf"}) - - if model in [ "EC-EARTH" ]: #, "BNU-ESM" ]: - ds_lf = ds_lf.isel(lat=slice(None, None, -1)) - lf = ds_lf.sftlf.sel(lat=slice(-90, 90)) # land frac file must be global - - # ------------------------------------------------- - # Loop start - Realization - # ------------------------------------------------- - for model_path in model_path_list: - try: - if model == "obs": - run = "obs" - else: - if realization in ["all", "All", "ALL", "*"]: - run_index = modpath.split(".").index("%(realization)") - run = model_path.split("/")[-1].split(".")[run_index] - else: - run = realization - if run not in monsoon_stat_dic["RESULTS"][model]: - monsoon_stat_dic["RESULTS"][model][run] = {} - print("\n") - print(" --- ", run, " ---") - - # Get time coordinate information - print("model_path = ", model_path) - - - dc = xcdat_open(model_path, decode_times=True) - dc['time'].attrs['axis'] = 'T' - dc['time'].attrs['standard_name'] = 'time' - dc = xr.decode_cf(dc, decode_times=True) - dc = dc.bounds.add_missing_bounds("X") - dc = dc.bounds.add_missing_bounds("Y") - dc = dc.bounds.add_missing_bounds("T") - - dc = dc.assign_coords({"lon": lf.lon, "lat": lf.lat}) - c = xc.center_times(dc) - eday = pick_year_last_day(dc) - - - # Get starting and ending year and month - startYear = c.time.values[0].year - startMonth = c.time.values[0].month - endYear = c.time.values[-1].year - endMonth = c.time.values[-1].month - - # Adjust years to consider only when they - # have entire calendar months - if startMonth > 1: - startYear += 1 - if endMonth < 12: - endYear -= 1 - - # Final selection of starting and ending years - startYear = max(syear, startYear) - endYear = min(eyear, endYear) - - # Check calendar (just checking..) 
-
-            if debug:
-                print("debug: startYear: ", type(startYear), startYear)
-                print("debug: startMonth: ", type(startMonth), startMonth)
-                print("debug: endYear: ", type(endYear), endYear)
-                print("debug: endMonth: ", type(endMonth), endMonth)
-                endYear = startYear + 1
-
-            # Prepare archiving individual year pentad time series for composite
-            list_pentad_time_series = {}
-            list_pentad_time_series_cumsum = {}  # Cumulative time series
-            for region in list_monsoon_regions:
-                list_pentad_time_series[region] = []
-                list_pentad_time_series_cumsum[region] = []
-
-            # Write individual year time series for each monsoon domain
-            # in a netCDF file
-            output_filename = "{}_{}_{}_{}_{}_{}-{}".format(
-                mip, model, exp, run, "monsoon_sperber", startYear, endYear
-            )
-            if nc_out:
-                output_filename = "{}_{}_{}_{}_{}_{}-{}".format(
-                    mip, model, exp, run, "monsoon_sperber", startYear, endYear
-                )
-
-                file_path = os.path.join(
-                    outdir(output_type="diagnostic_results"),
-                    output_filename + ".nc",
-                )
-                try:
-                    fout = xr.open_dataset(
-                        file_path, mode="a"
-                    )  # 'a' stands for append mode
-                except FileNotFoundError:
-                    fout = xr.Dataset()
-
-            # Plotting setup
-            if plot:
-                ax = {}
-                if len(list_monsoon_regions) > 1:
-                    nrows = math.ceil(len(list_monsoon_regions) / 2.0)
-                    ncols = 2
-                else:
-                    nrows = 1
-                    ncols = 1
-
-                fig = plt.figure(figsize=[6.4, 6.4])
-                plt.subplots_adjust(hspace=0.25)
-
-                for i, region in enumerate(list_monsoon_regions):
-                    ax[region] = plt.subplot(nrows, ncols, i + 1)
-                    ax[region].set_ylim(0, 1)
-                    # ax[region].set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
-                    # ax[region].set_xticks([0, 10, 20, 30, 40, 50, 60, 70])
-                    ax[region].margins(x=0)
-                    print(
-                        "plot: region",
-                        region,
-                        "nrows",
-                        nrows,
-                        "ncols",
-                        ncols,
-                        "index",
-                        i + 1,
-                    )
-                    if nrows > 1 and math.ceil((i + 1) / float(ncols)) < nrows:
-                        ax[region].set_xticklabels([])
-                    if ncols > 1 and (i + 1) % 2 == 0:
-                        ax[region].set_yticklabels([])
-
-                fig.text(0.5, 0.04, "pentad count", ha="center")
-                fig.text(
-                    0.03,
-                    0.5,
-                    "accumulative pentad precip fraction",
-                    va="center",
-                    rotation="vertical",
-                )
-
-            # -------------------------------------------------
-            # Loop start - Year
-            # -------------------------------------------------
-            temporary = {}
-            print(
-                "========== model = "
-                + model
-                + " ==============================================================================="
-            )
-            print("\n")
-            # year loop, endYear+1 to include last year
-            for year in range(startYear, endYear + 1):
-                print("\n")
-                print(" year = ", year)
-                print("\n")
-                d = dc.pr.sel(
-                    time=slice(
-                        # str(year) + "-01-01 00:00:00", str(year) + "-12-31 23:59:59"
-                        str(year) + "-01-01 00:00:00",
-                        str(year) + f"-12-{eday} 23:59:59",
-                    ),
-                    lat=slice(-90, 90),
-                )
-                # unit adjust
-                if UnitsAdjust[0]:
-                    """Below two lines are identical to following:
-                    d = MV2.multiply(d, 86400.)
- d.units = 'mm/d' - """ - d.values = d.values * 86400.0 - d["units"] = units - - - # variable for over land only - d_land = model_land_only(model, d, lf, debug=debug) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - # Loop start - Monsoon region - # - - - - - - - - - - - - - - - - - - - - - - - - - - - regions_specs = load_regions_specs() - - for region in list_monsoon_regions: - print("\n") - print(" region = ", region) - print("\n") - # extract for monsoon region - if region in ["GoG", "NAmo"]: - # all grid point rainfall - d_sub_ds = region_subset(dc, regions_specs, region=region) - # must be entire calendar years - d_sub_pr = d_sub_ds.pr.sel( - time=slice( - str(year) + "-01-01 00:00:00", - # str(year) + "-12-31 23:59:59", - str(year) + f"-12-{eday} 23:59:59", - ) - ) - - d_sub_pr.values = d_sub_pr.values * 86400.0 - d_sub_pr["units"] = units - - - else: - # land-only rainfall - - d_sub_ds = region_subset(dc, regions_specs, region=region) - d_sub_pr = d_sub_ds.pr.sel( - time=slice( - str(year) + "-01-01 00:00:00", - # str(year) + "-12-31 23:59:59", - str(year) + f"-12-{eday} 23:59:59", - ) - ) - - lf_sub_ds = region_subset( - ds_lf, regions_specs, region=region - ) - lf_sub = lf_sub_ds.sftlf - d_sub_pr = model_land_only( - model, d_sub_pr, lf_sub, debug=debug - ) - - - d_sub_pr.values = d_sub_pr.values * 86400.0 - d_sub_pr["units"] = units - - - # Area average - - ds_sub_pr = d_sub_pr.to_dataset().compute() - dc = dc.bounds.add_missing_bounds("X") - ds_sub_pr = ds_sub_pr.bounds.add_missing_bounds("X") - ds_sub_pr = ds_sub_pr.bounds.add_missing_bounds("Y") - ds_sub_pr = ds_sub_pr.bounds.add_missing_bounds("T") - - - if "lat_bnds" not in ds_sub_pr.variables: - lat_bnds = dc["lat_bnds"].sel(lat=ds_sub_pr["lat"]) - ds_sub_pr["lat_bnds"] = lat_bnds - - ds_sub_aave = ds_sub_pr.spatial.average( - "pr", axis=["X", "Y"], weights="generate" - ).compute() - d_sub_aave = ds_sub_aave.pr - - - - if debug: - print("debug: region:", region) - - # Southern Hemisphere monsoon domain - # set time series as 7/1~6/30 - if region in ["AUS", "SAmo"]: - if year == startYear: - start_t = str(year) + "-07-01 00:00:00" - # end_t = str(year) + "-12-31 23:59:59" - end_t = str(year) + f"-12-{eday} 23:59:59" - temporary[region] = d_sub_aave.sel( - time=slice(start_t, end_t) - ) - - continue - else: - # n-1 year 7/1~12/31 - part1 = copy.copy(temporary[region]) - # n year 1/1~6/30 - part2 = d_sub_aave.sel( - time=slice( - str(year) + "-01-01 00:00:00", - str(year) + "-06-30 23:59:59", - ) - ) - start_t = str(year) + "-07-01 00:00:00" - # end_t = str(year) + "-12-31 23:59:59" - end_t = str(year) + f"-12-{eday} 23:59:59" - temporary[region] = d_sub_aave.sel( - time=slice(start_t, end_t) - ) - - d_sub_aave = xr.concat([part1, part2], dim="time") - - - if debug: - print( - "debug: ", - region, - year, - ) - # get pentad time series - list_d_sub_aave_chunks = list( - divide_chunks_advanced(d_sub_aave, n, debug=debug) - ) - - pentad_time_series = [] - time_coords = np.array([], dtype="datetime64") - - for d_sub_aave_chunk in list_d_sub_aave_chunks: - # ignore when chunk length is shorter than defined - if d_sub_aave_chunk.shape[0] >= n: - aa = d_sub_aave_chunk.to_numpy() - aa_mean = np.mean(aa) - ave_chunk = d_sub_aave_chunk.mean( - axis=0, skipna=True - ).compute() - pentad_time_series.append(float(ave_chunk)) - datetime_str = str(d_sub_aave_chunk["time"][0].values) - datetime = pd.to_datetime([datetime_str[:10]]) - time_coords = np.concatenate([time_coords, datetime]) - time_coords = pd.to_datetime(time_coords) 
- - pentad_time_series = xr.DataArray( - pentad_time_series, - dims="time", - coords={"time": time_coords}, - ) - - - if debug: - print( - "debug: pentad_time_series length: ", - len(pentad_time_series), - ) - - # Keep pentad time series length in consistent - ref_length = int(365 / n) - if len(pentad_time_series) < ref_length: - - pentad_time_series = pentad_time_series.interp( - time=pd.date_range(time_coords[0], time_coords[-1], periods=ref_length) - ) - - time_coords = pentad_time_series.coords["time"] - - - pentad_time_series_cumsum = np.cumsum(pentad_time_series) - pentad_time_series = xr.DataArray( - pentad_time_series, - dims="time", - name=region + "_" + str(year), - ) - pentad_time_series.attrs["units"] = str(d.units.values) - - pentad_time_series_cumsum = xr.DataArray( - pentad_time_series_cumsum, - dims="time", - name=region + "_" + str(year) + "_cumsum", - ) - pentad_time_series_cumsum.attrs["units"] = str(d.units.values) - pentad_time_series_cumsum.coords["time"] = time_coords - - - if nc_out: - # Archive individual year time series in netCDF file - pentad_time_series.to_netcdf(file_path, mode="a") - pentad_time_series_cumsum.to_netcdf(file_path, mode="a") - - """ - if plot: - # Add grey line for individual year in plot - if year == startYear: - label = 'Individual yr' - else: - label = '' - ax[region].plot( - np.array(pentad_time_series_cumsum), - c='grey', label=label) - """ - - # Append individual year: save for following composite - list_pentad_time_series[region].append(pentad_time_series) - list_pentad_time_series_cumsum[region].append( - pentad_time_series_cumsum - ) - - # --- Monsoon region loop end - # --- Year loop end - dc.close() - - # ------------------------------------------------- - # Loop start: Monsoon region without year: Composite - # ------------------------------------------------- - if debug: - print("debug: composite start") - - for region in list_monsoon_regions: - # Get composite for each region - - composite_pentad_time_series = np.array( - list_pentad_time_series[region] - ).mean(axis=0) - - # Get accumulation ts from the composite - composite_pentad_time_series_cumsum = np.cumsum( - composite_pentad_time_series - ) - - # Maintain axis information - - # - - - - - - - - - - - - # Metrics for composite - # - - - - - - - - - - - - - metrics_result = sperber_metrics( - composite_pentad_time_series_cumsum, region, debug=debug - ) - - # Normalized cummulative pentad time series - composite_pentad_time_series_cumsum_normalized = metrics_result[ - "frac_accum" - ] - - composite_pentad_time_series = xr.DataArray( - composite_pentad_time_series, dims="time", name=region + "_comp" - ) - composite_pentad_time_series.attrs["units"] = str(d.units) - composite_pentad_time_series.coords["time"] = time_coords - - composite_pentad_time_series_cumsum = xr.DataArray( - composite_pentad_time_series_cumsum, - dims="time", - name=region + "_comp_cumsum", - ) - composite_pentad_time_series_cumsum.attrs["units"] = str(d.units) - composite_pentad_time_series_cumsum.coords["time"] = time_coords - - composite_pentad_time_series_cumsum_normalized = xr.DataArray( - composite_pentad_time_series_cumsum_normalized, - dims="time", - name=region + "_comp_cumsum_fraction", - ) - composite_pentad_time_series_cumsum_normalized.attrs["units"] = str( - d.units - ) - composite_pentad_time_series_cumsum_normalized.coords[ - "time" - ] = time_coords - - if model == "obs": - dict_obs_composite[reference_data_name][region] = {} - dict_obs_composite[reference_data_name][ - region - ] = 
composite_pentad_time_series_cumsum_normalized - - # Archive as dict for JSON - if model == "obs": - dict_head = monsoon_stat_dic["REF"][reference_data_name] - else: - dict_head = monsoon_stat_dic["RESULTS"][model][run] - # generate key if not there - if region not in list(dict_head.keys()): - dict_head[region] = {} - # generate keys and save for statistics - dict_head[region]["onset_index"] = metrics_result["onset_index"] - dict_head[region]["decay_index"] = metrics_result["decay_index"] - dict_head[region]["slope"] = metrics_result["slope"] - dict_head[region]["duration"] = metrics_result["duration"] - - # Archice in netCDF file - if nc_out: - composite_pentad_time_series.to_netcdf(file_path, mode="a") - composite_pentad_time_series_cumsum.to_netcdf( - file_path, mode="a" - ) - composite_pentad_time_series_cumsum_normalized.to_netcdf( - file_path, mode="a" - ) - - if region == list_monsoon_regions[-1]: - fout.close() - - # Add line in plot - if plot: - if model != "obs": - # model - ax[region].plot( - np.array( - composite_pentad_time_series_cumsum_normalized - ), - c="red", - label=model, - ) - for idx in [ - metrics_result["onset_index"], - metrics_result["decay_index"], - ]: - ax[region].axvline( - x=idx, - ymin=0, - ymax=composite_pentad_time_series_cumsum_normalized[ - idx - ].item(), - c="red", - ls="--", - ) - - # obs - if model == "obs": - ax[region].plot( - np.array( - dict_obs_composite[reference_data_name][region] - ), - c="blue", - label=reference_data_name, - ) - for idx in [ - monsoon_stat_dic["REF"][reference_data_name][region][ - "onset_index" - ], - monsoon_stat_dic["REF"][reference_data_name][region][ - "decay_index" - ], - ]: - ax[region].axvline( - x=idx, - ymin=0, - ymax=dict_obs_composite[reference_data_name][ - region - ][idx].item(), - c="blue", - ls="--", - ) - # title - ax[region].set_title(region) - if region == list_monsoon_regions[0]: - ax[region].legend(loc=2) - if region == list_monsoon_regions[-1]: - if model == "obs": - data_name = "OBS: " + reference_data_name - else: - data_name = ", ".join([mip.upper(), model, exp, run]) - fig.suptitle( - "Precipitation pentad time series\n" - + "Monsoon domain composite accumulations\n" - + ", ".join( - [data_name, str(startYear) + "-" + str(endYear)] - ) - ) - plt.subplots_adjust(top=0.85) - plt.savefig( - os.path.join( - outdir(output_type="graphics"), - output_filename + ".png", - ) - ) - plt.close() - - # ================================================= - # Write dictionary to json file - # (let the json keep overwritten in model loop) - # ------------------------------------------------- - JSON = pcmdi_metrics.io.base.Base( - outdir(output_type="metrics_results"), json_filename - ) - JSON.write( - monsoon_stat_dic, - json_structure=["model", "realization", "monsoon_region", "metric"], - sort_keys=True, - indent=4, - separators=(",", ": "), - ) - if cmec: - JSON.write_cmec(indent=4, separators=(",", ": ")) - - except Exception as err: - if debug: - raise - else: - print("warning: faild for ", model, run, err) - pass - # --- Realization loop end - except Exception as err: - if debug: - raise - else: - print("warning: faild for ", model, err) - pass -# --- Model loop end - -if not debug: - sys.exit(0) diff --git a/pcmdi_metrics/monsoon_sperber/lib/argparse_functions.py b/pcmdi_metrics/monsoon_sperber/lib/argparse_functions.py index 6fa42d61a..a51cc81a3 100644 --- a/pcmdi_metrics/monsoon_sperber/lib/argparse_functions.py +++ b/pcmdi_metrics/monsoon_sperber/lib/argparse_functions.py @@ -54,7 +54,9 @@ def 
AddParserArgument(P): "--meyear", dest="meyear", type=int, help="End year for model data set" ) P.add_argument("--modnames", type=str, default=None, help="List of models") - P.add_argument("--list_monsoon_regions", type=str, default=None, help="List of regions") + P.add_argument( + "--list_monsoon_regions", type=str, default=None, help="List of regions" + ) P.add_argument( "-r", "--realization", diff --git a/pcmdi_metrics/monsoon_sperber/lib/divide_chunks.py b/pcmdi_metrics/monsoon_sperber/lib/divide_chunks.py index efd031175..8a215f4f8 100644 --- a/pcmdi_metrics/monsoon_sperber/lib/divide_chunks.py +++ b/pcmdi_metrics/monsoon_sperber/lib/divide_chunks.py @@ -30,7 +30,7 @@ def divide_chunks_advanced(data, n, debug=False): day = day.values calendar = "gregorian" if debug: - #print("month = ", month, "day = ", day) + # print("month = ", month, "day = ", day) print("debug: first day of year is " + str(month) + "/" + str(day)) if month not in [1, 7] or day != 1: sys.exit( diff --git a/pcmdi_metrics/monsoon_sperber/lib/model_land_only.py b/pcmdi_metrics/monsoon_sperber/lib/model_land_only.py index 2dd47cac2..b1319fb65 100644 --- a/pcmdi_metrics/monsoon_sperber/lib/model_land_only.py +++ b/pcmdi_metrics/monsoon_sperber/lib/model_land_only.py @@ -9,7 +9,7 @@ def model_land_only(model, model_timeseries, lf, debug=False): # - - - - - - - - - - - - - - - - - - - - - - - - - if debug: - #plot_map(model_timeseries[0], "_".join(["test", model, "beforeMask.png"])) + # plot_map(model_timeseries[0], "_".join(["test", model, "beforeMask.png"])) print("debug: plot for beforeMask done") # Check land fraction variable to see if it meet criteria @@ -33,7 +33,7 @@ def model_land_only(model, model_timeseries, lf, debug=False): model_timeseries_masked = model_timeseries.where(lf > 90) if debug: - #plot_map(model_timeseries_masked[0], "_".join(["test", model, "afterMask.png"])) + # plot_map(model_timeseries_masked[0], "_".join(["test", model, "afterMask.png"])) print("debug: plot for afterMask done") return model_timeseries_masked diff --git a/pcmdi_metrics/monsoon_sperber/param/Bo_param.py b/pcmdi_metrics/monsoon_sperber/param/Bo_param.py index 5ec7d9522..ca0d9bd6e 100644 --- a/pcmdi_metrics/monsoon_sperber/param/Bo_param.py +++ b/pcmdi_metrics/monsoon_sperber/param/Bo_param.py @@ -14,9 +14,9 @@ # ------------------------------------------------- update_json = False debug = False -#debug = True +# debug = True -#list_monsoon_regions = ["AIR", "AUS", "Sahel", "GoG", "NAmo", "SAmo"] +# list_monsoon_regions = ["AIR", "AUS", "Sahel", "GoG", "NAmo", "SAmo"] list_monsoon_regions = ["AUS"] # ================================================= # Observation @@ -41,7 +41,7 @@ modpath = "/work/lee1043/ESGF/xmls/cmip5/historical/day/pr/cmip5.%(model).%(exp).%(realization).day.pr.xml" modpath_lf = "/work/lee1043/ESGF/xmls/cmip5/historical/fx/sftlf/cmip5.%(model).historical.r0i0p0.fx.sftlf.xml" -#/p/css03/scratch/published-older/cmip5/output1/CSIRO-BOM/ACCESS1-0/historical/day/atmos/day/r1i1p1/v4/pr/pr_day_ACCESS1-0_historical_r1i1p1_19750101-19991231.nc +# /p/css03/scratch/published-older/cmip5/output1/CSIRO-BOM/ACCESS1-0/historical/day/atmos/day/r1i1p1/v4/pr/pr_day_ACCESS1-0_historical_r1i1p1_19750101-19991231.nc # modnames = ['ACCESS1-0', 'ACCESS1-3', 'BCC-CSM1-1', 'BCC-CSM1-1-M', 'BNU-ESM', 'CanCM4', 'CanESM2', 'CCSM4', 'CESM1-BGC', 'CESM1-CAM5', 'CESM1-FASTCHEM', 'CMCC-CESM', 'CMCC-CM', 'CMCC-CMS', 'CNRM-CM5', 'CSIRO-Mk3-6-0', 'EC-EARTH', 'FGOALS-g2', 'GFDL-CM3', 'GFDL-ESM2G', 'GFDL-ESM2M', 'GISS-E2-H', 'GISS-E2-R', 
'HadGEM2-AO', 'HadGEM2-CC', 'HadGEM2-ES', 'INMCM4', 'IPSL-CM5A-LR', 'IPSL-CM5A-MR', 'IPSL-CM5B-LR', 'MIROC-ESM', 'MIROC-ESM-CHEM', 'MIROC4h', 'MIROC5', 'MPI-ESM-MR', 'MPI-ESM-P', 'MRI-CGCM3', 'MRI-ESM1', 'NorESM1-M']  # noqa

@@ -60,7 +60,7 @@
 # =================================================
 # Output
 # -------------------------------------------------
-#pmprdir = "/p/user_pub/pmp/pmp_results/pmp_v1.1.2"
+# pmprdir = "/p/user_pub/pmp/pmp_results/pmp_v1.1.2"
 pmprdir = "/p/user_pub/climate_work/dong12/PMP_result/"

 case_id = "{:v%Y%m%d}".format(datetime.datetime.now())
diff --git a/pcmdi_metrics/monsoon_wang/basic_monsoon_wang_param.py b/pcmdi_metrics/monsoon_wang/basic_monsoon_wang_param.py
new file mode 100644
index 000000000..2755464b1
--- /dev/null
+++ b/pcmdi_metrics/monsoon_wang/basic_monsoon_wang_param.py
@@ -0,0 +1,57 @@
+#
+#  OPTIONS ARE SET BY USER IN THIS FILE AS INDICATED BELOW BY:
+#
+#
+
+# LIST OF MODEL VERSIONS TO BE TESTED
+
+# trouble model: ICON-ESM-LR
+
+modnames = [
+    "ACCESS-CM2",
+    "ACCESS-ESM1-5",
+    "AWI-ESM-1-1-LR",
+    "BCC-CSM2-MR",
+    "BCC-ESM1",
+    "CAMS-CSM1-0",
+    "CanESM5",
+    "CAS-ESM2-0",
+    "CESM2",
+    "CESM2-WACCM",
+    "CIESM",
+    "CMCC-ESM2",
+    "E3SM-1-0",
+    "E3SM-1-1",
+    "E3SM-2-0",
+    "EC-Earth3",
+    "FGOALS-f3-L",
+    "GFDL-CM4",
+    "GFDL-ESM4",
+    "GISS-E2-1-G",
+    "GISS-E2-1-H",
+    "IPSL-CM5A2-INCA",
+    "IPSL-CM6A-LR",
+    "KIOST-ESM",
+    "MIROC6",
+    "MPI-ESM1-2-LR",
+    "MRI-ESM2-0",
+    "NESM3",
+    "NorESM2-LM",
+    "TaiESM1",
+]
+
+
+# ROOT PATH FOR MODELS CLIMATOLOGIES
+test_data_path = "/p/user_pub/pmp/pmp_results/pmp_v1.1.2/diagnostic_results/CMIP_CLIMS/cmip6/historical/v20230823/pr/cmip6.historical.%(model).r1i1p1f1.mon.pr.198101-200512.AC.v20230823.nc"
+
+# ROOT PATH FOR OBSERVATIONS
+reference_data_path = "/p/user_pub/PCMDIobs/obs4MIPs_legacy/PCMDIobs2_clims/atmos/pr/TRMM-3B43v-7/pr_mon_TRMM-3B43v-7_BE_gn_199801-201712.v20200421.AC.nc"
+
+# DIRECTORY WHERE TO PUT RESULTS
+results_dir = "$OUTPUT_DIR$/monsoon_wang"
+
+# Threshold
+threshold = 2.5 / 86400
+
+# monsoon domain mask based on observations
+obs_mask = True
diff --git a/pcmdi_metrics/monsoon_wang/monsoon_precip_index_fncs.py b/pcmdi_metrics/monsoon_wang/monsoon_precip_index_fncs.py
index 5fc2f81f6..b36ebdb1f 100644
--- a/pcmdi_metrics/monsoon_wang/monsoon_precip_index_fncs.py
+++ b/pcmdi_metrics/monsoon_wang/monsoon_precip_index_fncs.py
@@ -1,30 +1,65 @@
-import MV2
-import numpy
+from typing import Union
+
+import numpy as np
+import xarray as xr

 # SEASONAL RANGE - USING ANNUAL CYCLE CLIMATOLGIES 0=Jan, 11=Dec


+def da_to_ds(d: Union[xr.Dataset, xr.DataArray], var: str = "variable") -> xr.Dataset:
+    """Convert xarray DataArray to Dataset
+
+    Parameters
+    ----------
+    d : Union[xr.Dataset, xr.DataArray]
+        Input dataArray. If a dataset is given, no processing is done
+    var : str, optional
+        Name of dataArray, by default "variable"
+
+    Returns
+    -------
+    xr.Dataset
+        xarray Dataset
+
+    Raises
+    ------
+    TypeError
+        Raised when the given input is not an xarray object
+    """
+    if isinstance(d, xr.Dataset):
+        return d.copy()
+    elif isinstance(d, xr.DataArray):
+        return d.to_dataset(name=var).bounds.add_missing_bounds().copy()
+    else:
+        raise TypeError(
+            "Input must be an instance of either xarray.DataArray or xarray.Dataset"
+        )
+
+
+def regrid(da_in, da_grid, data_var="pr"):
+    ds_in = da_to_ds(da_in, data_var)
+    ds_grid = da_to_ds(da_grid, data_var)
+
+    ds_out = ds_in.regridder.horizontal(data_var, ds_grid, tool="regrid2")
+    da_out = ds_out[data_var]
+
+    return da_out
+
+
 def compute_season(data, season_indices, weights):
-    out = numpy.ma.zeros(data.shape[1:], dtype=data.dtype)
+    out = np.ma.zeros(data.shape[1:], dtype=data.dtype)
     N = 0
     for i in season_indices:
         out += data[i] * weights[i]
         N += weights[i]
-    out = MV2.array(out)
-    out.id = data.id
-    out.setAxisList(data.getAxisList()[1:])
     return out / N


 def mpd(data):
     """Monsoon precipitation intensity and annual range calculation

-    .. describe:: Input
-
-    *  data
-
-        * Assumes climatology array with 12 times step first one January
-
     """
     months_length = [
         31.0,
@@ -44,24 +79,22 @@ def mpd(data):
     ndjfm = compute_season(data, [10, 11, 0, 1, 2], months_length)
     ann = compute_season(data, list(range(12)), months_length)

-    annrange = MV2.subtract(mjjas, ndjfm)
+    data_map = data.isel(time=0)
+
+    annrange = mjjas - ndjfm

-    lat = annrange.getAxis(0)
-    i, e = lat.mapInterval((-91, 0, "con"))
-    if i > e:  # reveresedlats
-        tmp = i + 1
-        i = e + 1
-        e = tmp
+    da_annrange = xr.DataArray(annrange, coords=data_map.coords, dims=data_map.dims)
+    da_annrange = da_annrange.where(da_annrange.lat >= 0, da_annrange * -1)

-    annrange[slice(i, e)] = -annrange[slice(i, e)]
-    annrange.id = data.id + "_ar"
-    annrange.longname = "annual range"
+    # out= initializes cells excluded by where= (zero annual total) to 0
+    ann = np.asarray(ann)
+    mpi = np.divide(
+        da_annrange.values, ann, out=np.zeros_like(da_annrange.values), where=ann.astype(bool)
+    )

-    mpi = MV2.divide(annrange, ann)
-    mpi.id = data.id + "_int"
-    mpi.longname = "intensity"
+    da_mpi = xr.DataArray(mpi, coords=data_map.coords, dims=data_map.dims)

-    return annrange, mpi
+    return da_annrange, da_mpi


 def mpi_skill_scores(annrange_mod_dom, annrange_obs_dom, threshold=2.5 / 86400.0):
@@ -82,29 +115,43 @@ def mpi_skill_scores(annrange_mod_dom, annrange_obs_dom, threshold=2.5 / 86400.0
     * threshold in same units as inputs

     """
-    mt = numpy.ma.greater(annrange_mod_dom, threshold)
-    ot = numpy.ma.greater(annrange_obs_dom, threshold)
+    mt = np.ma.greater(annrange_mod_dom, threshold)
+    ot = np.ma.greater(annrange_obs_dom, threshold)

     hitmap = mt * ot  # only where both mt and ot are True
     hit = float(hitmap.sum())

-    xor = numpy.ma.logical_xor(mt, ot)
+    xor = np.ma.logical_xor(mt, ot)
     missmap = xor * ot
-    missed = float(MV2.sum(missmap))
+    missed = float(missmap.sum())

     falarmmap = xor * mt
-    falarm = float(MV2.sum(falarmmap))
+    falarm = float(falarmmap.sum())

     if (hit + missed + falarm) > 0.0:
         score = hit / (hit + missed + falarm)
     else:
         score = 1.0e20

-    hitmap.id = "hit"
-    missmap.id = "miss"
-    falarmmap.id = "false_alarm"
-
-    for a in [hitmap, missmap, falarmmap]:
-        a.setAxisList(annrange_mod_dom.getAxisList())
-
-    return hit, missed, falarm, score, hitmap, missmap, falarmmap
+    xr_hitmap = xr.DataArray(
+        hitmap,
+        name="hitmap",
+        coords={"lat": annrange_mod_dom.lat, "lon": annrange_mod_dom.lon},
+        dims=["lat", "lon"],
+    )
+
+    xr_missmap = xr.DataArray(
+        missmap,
+        name="missmap",
+        coords={"lat": annrange_mod_dom.lat, "lon": annrange_mod_dom.lon},
+        dims=["lat", "lon"],
+    )
+
+    xr_falarmmap = xr.DataArray(
+        falarmmap,
+        name="falarmmap",
+        coords={"lat": annrange_mod_dom.lat, "lon": annrange_mod_dom.lon},
+        dims=["lat", "lon"],
+    )
+
+    return hit, missed, falarm, score, xr_hitmap, xr_missmap, xr_falarmmap
diff --git a/pcmdi_metrics/monsoon_wang/monsoon_wang_driver.py b/pcmdi_metrics/monsoon_wang/monsoon_wang_driver.py
index 1fed57ec8..623764a1a 100644
--- a/pcmdi_metrics/monsoon_wang/monsoon_wang_driver.py
+++ b/pcmdi_metrics/monsoon_wang/monsoon_wang_driver.py
@@ -2,15 +2,18 @@
 import collections
 import os
+import sys

 import cdms2
 import numpy
-from genutil import statistics
+import numpy as np
+import xarray as xr
+from monsoon_precip_index_fncs import mpd, mpi_skill_scores, regrid

 import pcmdi_metrics
 from pcmdi_metrics import resources
+from pcmdi_metrics.io import load_regions_specs, region_subset
 from pcmdi_metrics.mean_climate.lib.pmp_parser import PMPParser
-from pcmdi_metrics.monsoon_wang import mpd, mpi_skill_scores
 from pcmdi_metrics.utils import StringConstructor


@@ -21,6 +24,7 @@ def create_monsoon_wang_parser():
     P.use("--results_dir")
     P.use("--reference_data_path")
     P.use("--test_data_path")
+    P.use("--obs_mask")

     P.add_argument(
         "--outnj",
@@ -103,17 +107,23 @@ def monsoon_wang_runner(args):
     #  PMP monthly default PR obs
     cdms2.axis.longitude_aliases.append("longitude_prclim_mpd")
     cdms2.axis.latitude_aliases.append("latitude_prclim_mpd")
-    fobs = cdms2.open(args.reference_data_path)
-    dobs_orig = fobs(args.obsvar)
+    fobs = xr.open_dataset(args.reference_data_path, decode_times=False)
+    dobs_orig = fobs[args.obsvar]
     fobs.close()

-    obsgrid = dobs_orig.getGrid()
-
     ########################################
     # FCN TO COMPUTE GLOBAL ANNUAL RANGE AND MONSOON PRECIP INDEX
     annrange_obs, mpi_obs = mpd(dobs_orig)
+
+    # create monsoon domain mask based on observations: annual range > 2.5 mm/day
+    if args.obs_mask:
+        domain_mask_obs = xr.where(annrange_obs > thr, 1, 0)
+        domain_mask_obs.name = "mask"
+        mpi_obs = mpi_obs.where(domain_mask_obs)
+        nout_mpi_obs = os.path.join(outpathdata, "mpi_obs_masked.nc")
+
     #########################################
     # SETUP WHERE TO OUTPUT RESULTING DATA (netcdf)
     nout = os.path.join(
@@ -149,6 +159,7 @@ def monsoon_wang_runner(args):

     if len(gmods) == 0:
         raise RuntimeError("No model file found!")
+
     #########################################

     egg_pth = resources.resource_path()
@@ -163,6 +174,7 @@ def monsoon_wang_runner(args):
         globals,
         locals,
     )
+
     regions_specs = locals["regions_specs"]

     doms = ["AllMW", "AllM", "NAMM", "SAMM", "NAFM", "SAFM", "ASM", "AUSM"]
@@ -182,37 +194,47 @@

         mpi_stats_dic[mod] = {}

-        print(
-            "******************************************************************************************"
-        )
-        print(modelFile)
-        f = cdms2.open(modelFile)
-        d_orig = f(var)
+        print("modelFile = ", modelFile)
+        f = xr.open_dataset(modelFile)
+        d_orig = f[var]

         annrange_mod, mpi_mod = mpd(d_orig)
-        annrange_mod = annrange_mod.regrid(
-            obsgrid, regridTool="regrid2", regridMethod="conserve", mkCyclic=True
-        )
-        mpi_mod = mpi_mod.regrid(
-            obsgrid, regridTool="regrid2", regridMethod="conserve", mkCyclic=True
-        )
+
+        domain_mask_mod = xr.where(annrange_mod > thr, 1, 0)
+        mpi_mod = mpi_mod.where(domain_mask_mod)
+
+        annrange_obs = regrid(annrange_obs, annrange_mod)
+
+        mpi_obs = regrid(mpi_obs, mpi_mod)
+
+        regions_specs = load_regions_specs()

         for dom in doms:
             mpi_stats_dic[mod][dom] = {}

-            reg_sel = regions_specs[dom]["domain"]
+            print("dom = ", dom)
+
+            mpi_obs_reg = region_subset(mpi_obs, dom)
+            mpi_obs_reg_sd = mpi_obs_reg.std(dim=["lat", "lon"])
+            mpi_mod_reg = region_subset(mpi_mod, dom)
+
+            da1_flat = mpi_mod_reg.values.ravel()
+            da2_flat = mpi_obs_reg.values.ravel()

-            mpi_obs_reg = mpi_obs(reg_sel)
-            mpi_obs_reg_sd = float(statistics.std(mpi_obs_reg, axis="xy"))
-            mpi_mod_reg = mpi_mod(reg_sel)
+            cor = np.ma.corrcoef(
+                np.ma.masked_invalid(da1_flat), np.ma.masked_invalid(da2_flat)
+            )[0, 1]
+
+            squared_diff = (mpi_mod_reg - mpi_obs_reg) ** 2
+            mean_squared_error = squared_diff.mean(skipna=True)
+            rms = np.sqrt(mean_squared_error)

-            cor = float(statistics.correlation(mpi_mod_reg, mpi_obs_reg, axis="xy"))
-            rms = float(statistics.rms(mpi_mod_reg, mpi_obs_reg, axis="xy"))
             rmsn = rms / mpi_obs_reg_sd

             #  DOMAIN SELECTED FROM GLOBAL ANNUAL RANGE FOR MODS AND OBS
+
+            annrange_mod_dom = region_subset(annrange_mod, dom)
+            annrange_obs_dom = region_subset(annrange_obs, dom)

-            annrange_mod_dom = annrange_mod(reg_sel)
-            annrange_obs_dom = annrange_obs(reg_sel)

             # SKILL SCORES
             #  HIT/(HIT + MISSED + FALSE ALARMS)
@@ -227,15 +249,23 @@
             mpi_stats_dic[mod][dom]["threat_score"] = format(score, sig_digits)

             # SAVE ANNRANGE AND HIT MISS AND FALSE ALARM FOR EACH MOD DOM
-            fm = os.path.join(nout, "_".join([mod, dom, "wang-monsoon.nc"]))
-            g = cdms2.open(fm, "w")
-            g.write(annrange_mod_dom)
-            g.write(hitmap, dtype=numpy.int32)
-            g.write(missmap, dtype=numpy.int32)
-            g.write(falarmmap, dtype=numpy.int32)
-            g.close()
+            fm = os.path.join(nout, "_".join([mod, dom, "wang-monsoon_xcdat.nc"]))
+            ds_out = xr.Dataset(
+                {
+                    "obsmap": annrange_obs_dom,
+                    "modmap": annrange_mod_dom,
+                    "hitmap": hitmap,
+                    "missmap": missmap,
+                    "falarmmap": falarmmap,
+                }
+            )
+            ds_out.to_netcdf(fm)

         f.close()

+        if np.isnan(cor):
+            print("invalid correlation values")
+            sys.exit()
+
     # OUTPUT METRICS TO JSON FILE
     OUT = pcmdi_metrics.io.base.Base(os.path.abspath(jout), json_filename)
diff --git a/pcmdi_metrics/monsoon_wang/run.py b/pcmdi_metrics/monsoon_wang/run.py
new file mode 100644
index 000000000..004963366
--- /dev/null
+++ b/pcmdi_metrics/monsoon_wang/run.py
@@ -0,0 +1,5 @@
+from monsoon_wang_driver import create_monsoon_wang_parser, monsoon_wang_runner
+
+P = create_monsoon_wang_parser()
+args = P.get_parameter(argparse_vals_only=False)
+monsoon_wang_runner(args)
diff --git a/share/DefArgsCIA.json b/share/DefArgsCIA.json
index cd33a055d..8507f33ba 100644
--- a/share/DefArgsCIA.json
+++ b/share/DefArgsCIA.json
@@ -163,4 +163,4 @@
     ],
     "help":"A list of variables to be processed"
     }
-}
\ No newline at end of file
+}
diff --git a/share/default_regions.py b/share/default_regions.py
index 07b2382fc..72eb7804a 100755
--- a/share/default_regions.py
+++ b/share/default_regions.py
@@ -50,7 +50,8 @@
     },
     # North African Monsoon
     "NAFM": {
-        "domain": cdutil.region.domain(latitude=(0.0, 45.0), longitude=(310.0, 60.0))
+        # "domain": cdutil.region.domain(latitude=(0.0, 45.0), longitude=(310.0, 60.0))
+        "domain": cdutil.region.domain(latitude=(0.0, 45.0), longitude=(-50, 60.0))
     },
     # South African Monsoon
     "SAFM": {
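
Usage sketch for the refactored I/O helpers introduced by this patch (illustration only, not part of the diff). The file name "pr_clim.nc" is a hypothetical placeholder for any CF-compliant monthly precipitation climatology; the imports mirror those used in monsoon_wang_driver.py and the new pcmdi_metrics/io exports above.

    import xarray as xr
    from pcmdi_metrics.io import da_to_ds, load_regions_specs, region_subset

    ds = xr.open_dataset("pr_clim.nc")  # placeholder input file

    # DataArray in, DataArray out: region_subset converts to a Dataset
    # internally (via da_to_ds) and returns the same type it was given.
    pr_asm = region_subset(ds["pr"], "ASM", data_var="pr")

    # Regions defined with negative longitudes (e.g. NAFM, now (-50, 60.0))
    # trigger a swap of the longitude axis to (-180, 180) before slicing,
    # so inputs on a (0, 360) grid are handled as well.
    pr_nafm = region_subset(
        ds["pr"], "NAFM", data_var="pr", regions_specs=load_regions_specs()
    )

    # da_to_ds wraps a bare DataArray into a Dataset and adds missing
    # bounds, which is what the regrid() helper relies on internally.
    ds_pr = da_to_ds(ds["pr"], "pr")

One design note: keeping region_subset type-preserving lets the Wang driver pass the mpd() outputs (DataArrays) straight through without the explicit to_dataset/getGrid bookkeeping the old cdms2 code needed.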