diff --git a/climetlab_weatherbench/main.py b/climetlab_weatherbench/main.py
index 05d1561..d039167 100644
--- a/climetlab_weatherbench/main.py
+++ b/climetlab_weatherbench/main.py
@@ -7,15 +7,110 @@
 # nor does it submit to any jurisdiction.
 from __future__ import annotations
 
+import logging
+
 import climetlab as cml
 from climetlab import Dataset
 from climetlab.decorators import normalize
+from climetlab.sources.multi import MultiSource
+
+LOG = logging.getLogger(__name__)
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"
 
 URL = "https://storage.ecmwf.europeanweather.cloud"
 
-PATTERN = "{url}/WeatherBench/" "{parameter}_{year}_5.625deg.nc"
+PATTERN = "{url}/WeatherBench/{parameter}hPa_{year}_5.625deg.nc"
+
+# Zero-padded month and day strings ("01".."12", "01".."31") used to request
+# full years from the CDS API.
+ALL_MONTHS = [f"{month:02d}" for month in range(1, 13)]
+ALL_DAYS = [f"{day:02d}" for day in range(1, 32)]
+
+
+class Request:
+    sources = None
+
+
+class CDSRequest(Request):
+    def __init__(self, year, parameter, grid) -> None:
+        request = {
+            "product_type": "reanalysis",
+            "format": "netcdf",
+            "year": f"{year}",
+            "month": ALL_MONTHS,
+            "day": ALL_DAYS,
+            "time": [
+                "00:00",
+                "06:00",
+                "12:00",
+                "18:00",
+            ],
+            "grid": [grid, grid],
+        }
+        if "_" in parameter:
+            # Parameters such as "temperature_850" encode the pressure level.
+            variable, level = parameter.split("_")
+            request["pressure_level"] = level
+            cds_source = "reanalysis-era5-pressure-levels"
+        else:
+            variable = parameter
+            cds_source = "reanalysis-era5-single-levels"
+        request["variable"] = variable
+
+        self.source = cml.load_source("cds", cds_source, request)
+
+
+class UrlRequest(Request):
+    @normalize("year", list(range(1979, 2019)), multiple=False)
+    @normalize("grid", [5.625], multiple=False)
+    def __init__(self, year, parameter, grid) -> None:
+        request = dict(parameter=parameter, url=URL, year=year)
+        self.source = cml.load_source("url-pattern", PATTERN, request)
 
 
 class Main(Dataset):
@@ -45,7 +140,58 @@ class Main(Dataset):
 
     dataset = None
 
-    @normalize("parameter", ["geopotential_500hPa", "temperature_850hPa"])
-    def __init__(self, year, parameter):
-        request = dict(parameter=parameter, url=URL, year=year)
-        self.source = cml.load_source("url-pattern", PATTERN, request)
+    @normalize(
+        "parameter",
+        ["geopotential_500", "temperature_850"],
+        aliases={
+            "temperature_850Hpa": "temperature_850",
+            "geopotential_500Hpa": "geopotential_500",
+        },
+        multiple=True,
+    )
+    @normalize("year", list(range(1979, 2022)), multiple=True)
+    @normalize("grid", [5.625, 0.1, 0.25])  # TODO give real values here.
+    # @normalize("level", [500,850])
+    # def __init__(self, year, parameter, grid=5.625, level):
+    def __init__(self, year, parameter, grid=5.625):
+        sources = []
+        for p in parameter:
+            sources_many_years = []
+            for y in year:
+                sources_many_years += self._get_sources(y, p, grid)
+            sources.append(
+                MultiSource(
+                    sources_many_years,
+                    merger="concat(dim=time)",
+                )
+            )
+
+        # Merging manually later because we need a special option to merge
+        # (see to_xarray below).
+        self.source = MultiSource(sources)
+        # self.source = MultiSource(sources, merge='merge()')
+
+    def _get_sources(self, year, p, grid) -> list:
+        # Try the WeatherBench URL first, then fall back to the CDS API.
+        for cls in [UrlRequest, CDSRequest]:
+            try:
+                s = cls(year, p, grid).source
+            except ValueError as e:
+                LOG.debug(str(e))
+                continue
+
+            # This hack bypasses the MultiSource merge.
+            if isinstance(s, MultiSource):
+                return s.sources
+            else:
+                return [s]
+
+        LOG.error("Cannot find data for %s %s %s", year, p, grid)
+        return []  # or raise an exception?
+
+    def to_xarray(self, **kwargs):
+        options = dict(xarray_open_mfdataset_kwargs=dict(compat="override"))
+        options.update(kwargs)
+        ds = self.source.to_xarray(**options)
+        if "level" in ds.variables:
+            ds = ds.drop("level")
+        return ds
diff --git a/climetlab_weatherbench/weatherbench_score.py b/climetlab_weatherbench/weatherbench_score.py
new file mode 100644
index 0000000..d6b5585
--- /dev/null
+++ b/climetlab_weatherbench/weatherbench_score.py
@@ -0,0 +1,118 @@
+"""
+Functions for evaluating forecasts.
+"""
+import numpy as np
+import xarray as xr
+
+
+def load_test_data(path, var, years=slice("2017", "2018")):
+    """
+    Load the test dataset. If z, return z500; if t, return t850.
+
+    Args:
+        path: Path to nc files
+        var: variable. Geopotential = 'z', Temperature = 't'
+        years: slice for time window
+
+    Returns:
+        dataset: Concatenated dataset for the selected years
+    """
+    ds = xr.open_mfdataset(f"{path}/*.nc", combine="by_coords")[var]
+    if var in ["z", "t"]:
+        if len(ds["level"].dims) > 0:
+            try:
+                ds = ds.sel(level=500 if var == "z" else 850).drop("level")
+            except ValueError:
+                ds = ds.drop("level")
+        else:
+            assert (
+                ds["level"].values == 500 if var == "z" else ds["level"].values == 850
+            )
+    return ds.sel(time=years)
+
+
+def compute_weighted_rmse(da_fc, da_true, mean_dims=xr.ALL_DIMS):
+    """
+    Compute the RMSE with latitude weighting from two xr.DataArrays.
+
+    Args:
+        da_fc (xr.DataArray): Forecast. Time coordinate must be validation time.
+        da_true (xr.DataArray): Truth.
+        mean_dims: dimensions over which to average score
+    Returns:
+        rmse: Latitude weighted root mean squared error
+    """
+    error = da_fc - da_true
+    weights_lat = np.cos(np.deg2rad(error.lat))
+    weights_lat /= weights_lat.mean()
+    rmse = np.sqrt((error ** 2 * weights_lat).mean(mean_dims))
+    return rmse
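+
+# Usage note (illustrative): compute_weighted_rmse(fc, obs) reduces over all
+# dimensions by default; passing mean_dims=["lat", "lon"] keeps one score per
+# time step, e.g. for plotting the error as a time series.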
+
+
+def compute_weighted_acc(da_fc, da_true, mean_dims=xr.ALL_DIMS):
+    """
+    Compute the ACC with latitude weighting from two xr.DataArrays.
+    WARNING: Does not work if datasets contain NaNs.
+
+    Args:
+        da_fc (xr.DataArray): Forecast. Time coordinate must be validation time.
+        da_true (xr.DataArray): Truth.
+        mean_dims: dimensions over which to average score
+    Returns:
+        acc: Latitude weighted acc
+    """
+
+    clim = da_true.mean("time")
+    try:
+        t = np.intersect1d(da_fc.time, da_true.time)
+        fa = da_fc.sel(time=t) - clim
+    except AttributeError:
+        t = da_true.time.values
+        fa = da_fc - clim
+    a = da_true.sel(time=t) - clim
+
+    weights_lat = np.cos(np.deg2rad(da_fc.lat))
+    weights_lat /= weights_lat.mean()
+    w = weights_lat
+
+    fa_prime = fa - fa.mean()
+    a_prime = a - a.mean()
+
+    acc = np.sum(w * fa_prime * a_prime) / np.sqrt(
+        np.sum(w * fa_prime**2) * np.sum(w * a_prime**2)
+    )
+    return acc
+
+
+def compute_weighted_mae(da_fc, da_true, mean_dims=xr.ALL_DIMS):
+    """
+    Compute the MAE with latitude weighting from two xr.DataArrays.
+
+    Args:
+        da_fc (xr.DataArray): Forecast. Time coordinate must be validation time.
+        da_true (xr.DataArray): Truth.
+        mean_dims: dimensions over which to average score
+    Returns:
+        mae: Latitude weighted mean absolute error
+    """
+    error = da_fc - da_true
+    weights_lat = np.cos(np.deg2rad(error.lat))
+    weights_lat /= weights_lat.mean()
+    mae = (np.abs(error) * weights_lat).mean(mean_dims)
+    return mae
+
+
+def evaluate_iterative_forecast(da_fc, da_true, func, mean_dims=xr.ALL_DIMS):
+    """
+    Compute an iterative score (given by func) with latitude weighting from two xr.DataArrays.
+
+    Args:
+        da_fc (xr.DataArray): Iterative forecast. Time coordinate must be initialization time.
+        da_true (xr.DataArray): Truth.
+        mean_dims: dimensions over which to average score
+    Returns:
+        score: Latitude weighted score
+    """
+    rmses = []
+    for f in da_fc.lead_time:
+        fc = da_fc.sel(lead_time=f)
+        fc["time"] = fc.time + np.timedelta64(int(f), "h")
+        rmses.append(func(fc, da_true, mean_dims))
+    return xr.concat(rmses, "lead_time")
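+
+
+if __name__ == "__main__":
+    # Minimal smoke test on synthetic data (illustrative only; a real
+    # evaluation should use WeatherBench files loaded via load_test_data).
+    import pandas as pd
+
+    times = pd.date_range("2018-01-01", periods=4, freq="6H")
+    lats = np.linspace(-87.1875, 87.1875, 32)
+    lons = np.arange(0.0, 360.0, 5.625)
+    shape = (times.size, lats.size, lons.size)
+
+    rng = np.random.default_rng(0)
+    truth = xr.DataArray(
+        rng.normal(size=shape),
+        coords={"time": times, "lat": lats, "lon": lons},
+        dims=("time", "lat", "lon"),
+    )
+    # Perturb the truth to obtain a toy "forecast".
+    forecast = truth + 0.1 * rng.normal(size=shape)
+
+    print("weighted RMSE:", float(compute_weighted_rmse(forecast, truth)))
+    print("weighted MAE: ", float(compute_weighted_mae(forecast, truth)))
+    print("weighted ACC: ", float(compute_weighted_acc(forecast, truth)))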
diff --git a/notebooks/demo_merge.ipynb b/notebooks/demo_merge.ipynb
new file mode 100644
index 0000000..1a66585
--- /dev/null
+++ b/notebooks/demo_merge.ipynb
@@ -0,0 +1,710 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "pleasant-howard",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#!pip install climetlab\n",
+    "#!pip install climetlab_weatherbench"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4fd0ab86",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import climetlab as cml\n",
+    "cmlds = cml.load_dataset(\n",
+    "    \"weatherbench\",\n",
+    "    year=2017,\n",
+    "    # year=[2018, 2017],  # not working\n",
+    "    parameter=[\"temperature_850\", \"geopotential_500\"]\n",
+    "    # parameter=\"temperature_850Hpa\",  # ok\n",
+    ")\n",
+    "ds = cmlds.to_xarray()\n",
+    "ds"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "affiliated-binding",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "now merging these:\n",
+      "Url(https://storage.ecmwf.europeanweather.cloud/WeatherBench/temperature_850hPa_2018_5.625deg.nc)\n",
+      "now merging these:\n",
+      "Url(https://storage.ecmwf.europeanweather.cloud/WeatherBench/geopotential_500hPa_2018_5.625deg.nc)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<xarray.Dataset>\n",
+       "Dimensions:  (lon: 64, lat: 32, time: 8760)\n",
+       "Coordinates:\n",
+       "  * lon      (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4\n",
+       "    level    int32 ...\n",
+       "  * lat      (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19\n",
+       "  * time     (time) datetime64[ns] 2018-01-01 ... 2018-12-31T23:00:00\n",
+       "Data variables:\n",
+       "    t        (time, lat, lon) float32 dask.array<chunksize=(8760, 32, 64), meta=np.ndarray>\n",
+       "    z        (time, lat, lon) float32 dask.array<chunksize=(8760, 32, 64), meta=np.ndarray>\n",
+       "Attributes:\n",
+       "    Conventions:  CF-1.6\n",
+       "    history:      2019-11-19 01:26:35 GMT by grib_to_netcdf-2.14.0: /opt/ecmw...
" + ], + "text/plain": [ + "\n", + "Dimensions: (lon: 64, lat: 32, time: 8760)\n", + "Coordinates:\n", + " * lon (lon) float64 0.0 5.625 11.25 16.88 ... 337.5 343.1 348.8 354.4\n", + " level int32 ...\n", + " * lat (lat) float64 -87.19 -81.56 -75.94 -70.31 ... 75.94 81.56 87.19\n", + " * time (time) datetime64[ns] 2018-01-01 ... 2018-12-31T23:00:00\n", + "Data variables:\n", + " t (time, lat, lon) float32 dask.array\n", + " z (time, lat, lon) float32 dask.array\n", + "Attributes:\n", + " Conventions: CF-1.6\n", + " history: 2019-11-19 01:26:35 GMT by grib_to_netcdf-2.14.0: /opt/ecmw..." + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import climetlab as cml\n", + "\n", + "acmlds = cml.load_dataset(\n", + " \"weatherbench\",\n", + " year=2018,\n", + " parameter=\"temperature_850\",\n", + ")\n", + "a = acmlds.to_xarray()\n", + "\n", + "bcmlds = cml.load_dataset(\n", + " \"weatherbench\",\n", + " year=2018,\n", + " parameter=\"geopotential_500\",\n", + ")\n", + "b = bcmlds.to_xarray()\n", + "\n", + "from climetlab.sources.multi import MultiSource\n", + "m = MultiSource([acmlds.source,bcmlds.source])\n", + "\n", + "m.to_xarray(xarray_open_mfdataset_kwargs=dict(compat='override'))\n", + "#m.to_xarray(xarray_open_mfdataset_kwargs={'merge_kwargs': dict(compat='override')})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "boring-bryan", + "metadata": {}, + "outputs": [], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "755d711c", + "metadata": {}, + "outputs": [], + "source": [ + "b" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5537fb39", + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr \n", + "xr.merge([a,b],**dict(compat='override'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b126147a", + "metadata": {}, + "outputs": [], + "source": [ + "m.to_xarray(xarray_open_mfdataset_kwargs={'merge_kwargs': dict(compat='override')})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f37834e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" + }, + "kernelspec": { + "display_name": "Python 3.9.12 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}