Add functions #1

Open · wants to merge 6 commits into main
158 changes: 152 additions & 6 deletions climetlab_weatherbench/main.py
@@ -7,15 +7,110 @@
# nor does it submit to any jurisdiction.
from __future__ import annotations

import logging

import climetlab as cml
from climetlab import Dataset
from climetlab.decorators import normalize
from climetlab.sources.multi import MultiSource

LOG = logging.getLogger(__name__)

__version__ = "0.1.0"
__version__ = "0.2.0"

URL = "https://storage.ecmwf.europeanweather.cloud"

PATTERN = "{url}/WeatherBench/" "{parameter}_{year}_5.625deg.nc"
PATTERN = "{url}/WeatherBench/{parameter}hPa_{year}_5.625deg.nc"

ALL_MONTHS = [f"{month:02d}" for month in range(1, 13)]
ALL_DAYS = [f"{day:02d}" for day in range(1, 32)]


class Request:
    sources = None


class CDSRequest(Request):
    def __init__(self, year, parameter, grid):
        request = {
            "product_type": "reanalysis",
            "format": "netcdf",
            "year": f"{year}",
            "month": ALL_MONTHS,
            "day": ALL_DAYS,
            "time": [
                "00:00",
                "06:00",
                "12:00",
                "18:00",
            ],
            "grid": [grid, grid],
        }
        if "_" in parameter:
            # e.g. "temperature_850" -> variable "temperature" on pressure level 850
            param_split = parameter.split("_")
            variable = param_split[0]
            level = param_split[1]
            request["pressure_level"] = level
            cds_source = "reanalysis-era5-pressure-levels"
        else:
            variable = parameter
            cds_source = "reanalysis-era5-single-levels"
        request["variable"] = variable

        self.source = cml.load_source("cds", cds_source, request)


class UrlRequest(Request):
    @normalize("year", list(range(1979, 2019)), multiple=False)
    @normalize("grid", [5.625], multiple=False)
    def __init__(self, year, parameter, grid):
        request = dict(parameter=parameter, url=URL, year=year)
        self.source = cml.load_source("url-pattern", PATTERN, request)


class Main(Dataset):
@@ -45,7 +140,58 @@ class Main(Dataset):

    dataset = None

-    @normalize("parameter", ["geopotential_500hPa", "temperature_850hPa"])
-    def __init__(self, year, parameter):
-        request = dict(parameter=parameter, url=URL, year=year)
-        self.source = cml.load_source("url-pattern", PATTERN, request)
    @normalize(
        "parameter",
        ["geopotential_500", "temperature_850"],
        aliases={
            "temperature_850Hpa": "temperature_850",
            "geopotential_500Hpa": "geopotential_500",
        },
        multiple=True,
    )
    @normalize("year", list(range(1979, 2022)), multiple=True)
    @normalize("grid", [5.625, 0.1, 0.25])  # TODO: give real values here.
    # @normalize("level", [500, 850])
    # def __init__(self, year, parameter, grid=5.625, level):
    def __init__(self, year, parameter, grid=5.625):
        sources = []
        for p in parameter:
            sources_many_years = []
            for y in year:
                sources_many_years += self._get_sources(y, p, grid)
            sources.append(
                MultiSource(
                    sources_many_years,
                    merger="concat(dim=time)",
                )
            )

        # Merging manually later because we need a special option to merge.
        self.source = MultiSource(sources)
        # self.source = MultiSource(sources, merge='merge()')

    def _get_sources(self, year, p, grid) -> list:
        for cls in [UrlRequest, CDSRequest]:
            try:
                s = cls(year, p, grid).source
            except ValueError as e:
                LOG.debug(str(e))
                continue

            # This hack bypasses the multi-source merge.
            if isinstance(s, MultiSource):
                return s.sources
            return [s]

        LOG.error("Cannot find data for %s %s %s", year, p, grid)
        return []  # or raise an exception?

    def to_xarray(self, **kwargs):
        options = dict(xarray_open_mfdataset_kwargs=dict(compat="override"))
        options.update(kwargs)
        ds = self.source.to_xarray(**options)
        if "level" in ds.variables:
            ds = ds.drop("level")
        return ds
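A hedged usage sketch for the reworked dataset class. The dataset name `weatherbench` is an assumption based on the package name and should be checked against the plugin's entry points.

```python
# Hypothetical usage; the dataset name "weatherbench" is assumed, not confirmed.
import climetlab as cml

ds = cml.load_dataset(
    "weatherbench",
    year=[2017, 2018],
    parameter="temperature_850",
    grid=5.625,
)
xds = ds.to_xarray()
print(xds)
```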
118 changes: 118 additions & 0 deletions climetlab_weatherbench/weatherbench_score.py
@@ -0,0 +1,118 @@
"""
Functions for evaluating forecasts.
"""
import numpy as np
import xarray as xr


def load_test_data(path, var, years=slice("2017", "2018")):
    """
    Load the test dataset. If var is 'z', return z500; if 't', return t850.

    Args:
        path: Path to the netCDF files.
        var: Variable name. Geopotential = 'z', temperature = 't'.
        years: Time slice to select.

    Returns:
        dataset: Concatenated dataset for the selected years.
    """
    ds = xr.open_mfdataset(f"{path}/*.nc", combine="by_coords")[var]
    if var in ["z", "t"]:
        if len(ds["level"].dims) > 0:
            try:
                ds = ds.sel(level=500 if var == "z" else 850).drop("level")
            except ValueError:
                ds = ds.drop("level")
        else:
            assert (
                ds["level"].values == 500 if var == "z" else ds["level"].values == 850
            )
    return ds.sel(time=years)
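A hypothetical call; the directory layout below is illustrative, not fixed by this module.

```python
# Assumes ./data/geopotential_500 holds yearly netCDF files (illustrative path).
z500_test = load_test_data("./data/geopotential_500", var="z")
print(z500_test)  # z500 restricted to the 2017-2018 test slice
```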


def compute_weighted_rmse(da_fc, da_true, mean_dims=xr.ALL_DIMS):
    """
    Compute the RMSE with latitude weighting from two xr.DataArrays.

    Args:
        da_fc (xr.DataArray): Forecast. Time coordinate must be validation time.
        da_true (xr.DataArray): Truth.
        mean_dims: Dimensions over which to average the score.

    Returns:
        rmse: Latitude-weighted root mean squared error.
    """
    error = da_fc - da_true
    weights_lat = np.cos(np.deg2rad(error.lat))
    weights_lat /= weights_lat.mean()
    rmse = np.sqrt((error**2 * weights_lat).mean(mean_dims))
    return rmse
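A self-contained sanity check of the latitude weighting on synthetic data: with a constant error of 1, the normalized cos(lat) weights average to 1, so the weighted RMSE is exactly 1. `mean_dims=...` is passed explicitly because `xr.ALL_DIMS` is deprecated in newer xarray versions in favor of the ellipsis.

```python
import numpy as np
import xarray as xr

lat = np.array([0.0, 60.0])  # equator and 60N: weights cos(lat) = [1.0, 0.5]
da_true = xr.DataArray(np.zeros((2, 2)), dims=("time", "lat"), coords={"lat": lat})
da_fc = da_true + 1.0  # constant error of 1 everywhere

# mean_dims=... averages over all dimensions (ellipsis form, works on newer xarray).
print(float(compute_weighted_rmse(da_fc, da_true, mean_dims=...)))  # 1.0
```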


def compute_weighted_acc(da_fc, da_true, mean_dims=xr.ALL_DIMS):
    """
    Compute the ACC with latitude weighting from two xr.DataArrays.
    WARNING: Does not work if the datasets contain NaNs.

    Args:
        da_fc (xr.DataArray): Forecast. Time coordinate must be validation time.
        da_true (xr.DataArray): Truth.
        mean_dims: Dimensions over which to average the score.
            (Currently unused: the correlation below sums over all dimensions.)

    Returns:
        acc: Latitude-weighted anomaly correlation coefficient.
    """
    clim = da_true.mean("time")
    try:
        t = np.intersect1d(da_fc.time, da_true.time)
        fa = da_fc.sel(time=t) - clim
    except AttributeError:
        t = da_true.time.values
        fa = da_fc - clim
    a = da_true.sel(time=t) - clim

    weights_lat = np.cos(np.deg2rad(da_fc.lat))
    weights_lat /= weights_lat.mean()
    w = weights_lat

    fa_prime = fa - fa.mean()
    a_prime = a - a.mean()

    acc = np.sum(w * fa_prime * a_prime) / np.sqrt(
        np.sum(w * fa_prime**2) * np.sum(w * a_prime**2)
    )
    return acc
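In formula form, the quantity computed above is the centred, latitude-weighted anomaly correlation,

$$\mathrm{ACC} = \frac{\sum w\,f'a'}{\sqrt{\sum w\,f'^{2}\,\sum w\,a'^{2}}}, \qquad w = \frac{\cos(\mathrm{lat})}{\overline{\cos(\mathrm{lat})}},$$

where $f'$ and $a'$ are the forecast and truth anomalies relative to the truth climatology, each with its mean removed.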


def compute_weighted_mae(da_fc, da_true, mean_dims=xr.ALL_DIMS):
    """
    Compute the MAE with latitude weighting from two xr.DataArrays.

    Args:
        da_fc (xr.DataArray): Forecast. Time coordinate must be validation time.
        da_true (xr.DataArray): Truth.
        mean_dims: Dimensions over which to average the score.

    Returns:
        mae: Latitude-weighted mean absolute error.
    """
    error = da_fc - da_true
    weights_lat = np.cos(np.deg2rad(error.lat))
    weights_lat /= weights_lat.mean()
    mae = (np.abs(error) * weights_lat).mean(mean_dims)
    return mae


def evaluate_iterative_forecast(da_fc, da_true, func, mean_dims=xr.ALL_DIMS):
    """
    Compute an iterative score (given by func) with latitude weighting from two xr.DataArrays.

    Args:
        da_fc (xr.DataArray): Iterative forecast. Time coordinate must be initialization time.
        da_true (xr.DataArray): Truth.
        func: Scoring function, e.g. compute_weighted_rmse.
        mean_dims: Dimensions over which to average the score.

    Returns:
        score: Latitude-weighted score as a function of lead time.
    """
    scores = []
    for f in da_fc.lead_time:
        fc = da_fc.sel(lead_time=f)
        # Shift from initialization time to valid time by adding the lead time in hours.
        fc["time"] = fc.time + np.timedelta64(int(f), "h")
        scores.append(func(fc, da_true, mean_dims))
    return xr.concat(scores, "lead_time")
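A hypothetical call, assuming `da_fc` and `da_true` already exist with the shapes described in the docstring:

```python
# da_fc: dims (time, lead_time, lat, lon); "time" is the initialization time
# and "lead_time" is in hours. mean_dims=... is the ellipsis form of
# "average over everything" (see the RMSE example above).
rmse_by_lead = evaluate_iterative_forecast(
    da_fc, da_true, compute_weighted_rmse, mean_dims=...
)
print(rmse_by_lead)  # one latitude-weighted RMSE per lead time
```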