Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/evaluate/eval_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ evaluation:
summary_plots : true
summary_dir: "./plots/"
plot_ensemble: "members" #supported: false, "std", "minmax", "members"
plot_score_maps: false # plot scores on 2D maps; slows down score computation
print_summary: false #print out score values on screen. it can be verbose
log_scale: false
add_grid: false
Expand Down
15 changes: 10 additions & 5 deletions packages/evaluate/src/weathergen/evaluate/clim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,13 @@ def align_clim_data(
for fstep, target_data in target_output.items():
samples = np.unique(target_data.sample.values)
for sample in tqdm(samples, f"Aligning climatology for forecast step {fstep}"):
sample_mask = target_data.sample.values == sample
timestamp = target_data.valid_time.values[sample_mask][0]
sel_key = "sample" if "sample" in target_data.dims else "ipoint"
sel_val = (
sample if "sample" in target_data.dims else (target_data.sample.values == sample)
)
sel_mask = {sel_key: sel_val}

timestamp = target_data.sel(sel_mask).valid_time.values[0]
# Prepare climatology data for each sample
matching_time_idx = match_climatology_time(timestamp, clim_data)

Expand All @@ -141,8 +146,8 @@ def align_clim_data(
)
.transpose("grid_points", "channels") # dimensions specific to anemoi
)
target_lats = target_data.loc[{"ipoint": sample_mask}].lat.values
target_lons = target_data.loc[{"ipoint": sample_mask}].lon.values
target_lats = target_data.loc[sel_mask].lat.values
target_lons = target_data.loc[sel_mask].lon.values
# check if target coords match cached target coords
# if they do, use cached clim_indices
if (
Expand Down Expand Up @@ -174,7 +179,7 @@ def align_clim_data(
clim_values = prepared_clim_data.isel(grid_points=clim_indices).values
try:
if len(samples) > 1:
aligned_clim_data[fstep].loc[{"ipoint": sample_mask}] = clim_values
aligned_clim_data[fstep].loc[sel_mask] = clim_values
else:
aligned_clim_data[fstep] = clim_values
except (ValueError, IndexError) as e:
Expand Down
180 changes: 136 additions & 44 deletions packages/evaluate/src/weathergen/evaluate/io_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict[str, str] |
self.run_id = run_id
self.private_paths = private_paths
self.streams = eval_cfg.streams.keys()
self.data = None
# TODO: propagate it to the other functions using global plotting opts
self.global_plotting_options = eval_cfg.get("global_plotting_options", {})

# If results_base_dir and model_base_dir are not provided, default paths are used
self.model_base_dir = self.eval_cfg.get("model_base_dir", None)
Expand Down Expand Up @@ -130,6 +131,13 @@ def get_ensemble(self, stream: str | None = None) -> list[str]:
"""Placeholder implementation ensemble member names getter. Override in subclass."""
return list()

def is_regular(self, stream: str) -> bool:
    """
    Default check for regularly spaced lat/lon coordinates.

    Base-class placeholder that always reports a regular grid; concrete
    reader subclasses are expected to override this with a real check.

    Parameters
    ----------
    stream : str
        Name of the stream to check.

    Returns
    -------
    bool
        Always True in this placeholder implementation.
    """
    return True

def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray:
    """
    Base-class placeholder for loading pre-computed scores.

    Subclasses override this to return the stored scores for the given
    run, stream, region and metric; this default version returns None.
    """
    return None
Expand Down Expand Up @@ -496,9 +504,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non

if not self.fname_zarr.exists() or not self.fname_zarr.is_dir():
_logger.error(f"Zarr file {self.fname_zarr} does not exist.")
# raise FileNotFoundError(
# f"Zarr file {self.fname_zarr} does not exist or is not a directory."
# )
raise FileNotFoundError(
f"Zarr file {self.fname_zarr} does not exist or is not a directory."
)

def get_inference_config(self):
"""
Expand Down Expand Up @@ -610,8 +618,7 @@ def get_data(

for fstep in fsteps:
_logger.info(f"RUN {self.run_id} - {stream}: Processing fstep {fstep}...")
da_tars_fs, da_preds_fs = [], []
pps = []
da_tars_fs, da_preds_fs, pps = [], [], []

for sample in tqdm(samples, desc=f"Processing {self.run_id} - {stream} - {fstep}"):
out = zio.get_data(sample, stream, fstep)
Expand Down Expand Up @@ -642,60 +649,57 @@ def get_data(
_logger.debug(f"Selecting ensemble members {ensemble}.")
pred = pred.sel(ens=ensemble)

if ensemble == ["mean"]:
_logger.debug("Averaging over ensemble members.")
pred = pred.mean("ens", keepdims=True)
else:
_logger.debug(f"Selecting ensemble members {ensemble}.")
pred = pred.sel(ens=ensemble)

da_tars_fs.append(target.squeeze())
da_preds_fs.append(pred.squeeze())

if len(da_tars_fs) > 0:
fsteps_final.append(fstep)
if not da_tars_fs:
_logger.info(
f"[{self.run_id} - {stream}] No valid data found for fstep {fstep}."
)
continue

fsteps_final.append(fstep)

_logger.debug(
f"Concatenating targets and predictions for stream {stream}, "
f"forecast_step {fstep}..."
)

if da_tars_fs:
# faster processing
if self.is_regular(stream):
# Efficient concatenation for regular grid
da_preds_fs = _force_consistent_grids(da_preds_fs)
da_tars_fs = _force_consistent_grids(da_tars_fs)

else:
# Irregular (scatter) case. concatenate over ipoint
da_tars_fs = xr.concat(da_tars_fs, dim="ipoint")
da_preds_fs = xr.concat(da_preds_fs, dim="ipoint")
if len(samples) == 1:
# Ensure sample coordinate is repeated along ipoint even if only one sample
da_tars_fs = da_tars_fs.assign_coords(
sample=(
"ipoint",
np.repeat(da_tars_fs.sample.values, len(da_tars_fs.ipoint)),
)
)
da_preds_fs = da_preds_fs.assign_coords(
sample=(
"ipoint",
np.repeat(da_preds_fs.sample.values, len(da_preds_fs.ipoint)),
)
)

if set(channels) != set(all_channels):
_logger.debug(
f"Restricting targets and predictions to channels {channels} "
f"for stream {stream}..."
if len(samples) == 1:
_logger.debug("Repeating sample coordinate for single-sample case.")
for da in (da_tars_fs, da_preds_fs):
da.assign_coords(
sample=("ipoint", np.repeat(da.sample.values, da.sizes["ipoint"]))
)

da_tars_fs, da_preds_fs, channels = dc.get_derived_channels(
da_tars_fs, da_preds_fs
)
if set(channels) != set(all_channels):
_logger.debug(
f"Restricting targets and predictions to channels {channels} "
f"for stream {stream}..."
)

da_tars_fs = da_tars_fs.sel(channel=channels)
da_preds_fs = da_preds_fs.sel(channel=channels)
da_tars_fs, da_preds_fs, channels = dc.get_derived_channels(
da_tars_fs, da_preds_fs
)

da_tars.append(da_tars_fs)
da_preds.append(da_preds_fs)
da_tars_fs = da_tars_fs.sel(channel=channels)
da_preds_fs = da_preds_fs.sel(channel=channels)

if return_counts:
points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps)
da_tars.append(da_tars_fs)
da_preds.append(da_preds_fs)
if return_counts:
points_per_sample.loc[{"forecast_step": fstep}] = np.array(pps)

# Safer than a list
da_tars = {fstep: da for fstep, da in zip(fsteps_final, da_tars, strict=True)}
Expand Down Expand Up @@ -796,14 +800,65 @@ def get_channels(self, stream: str) -> list[str]:
return all_channels

def get_ensemble(self, stream: str | None = None) -> list[str]:
    """Return the ensemble member names available for a stream.

    Parameters
    ----------
    stream : str | None
        Name of the stream whose ensemble members are requested.

    Returns
    -------
    list[str]
        Names of the ensemble members, read from the "ens" coordinate of
        the prediction for the first sample and first forecast step.
    """
    _logger.debug(f"Getting ensembles for stream {stream}...")

    # TODO: improve this to get ensemble from io class
    with ZarrIO(self.fname_zarr) as zio:
        first_step = zio.forecast_steps[0]
        sample_out = zio.get_data(0, stream, first_step)
        members = sample_out.prediction.as_xarray().coords["ens"].values
        return list(members)

# TODO: improve this
def is_regular(self, stream: str) -> bool:
    """Check if the latitude and longitude coordinates are regularly spaced for a given stream.

    The check is a heuristic: it fetches the prediction for two distinct
    (sample, forecast_step) pairs and compares their lat/lon coordinate
    sets (shape plus sorted values). Identical coordinate sets across
    samples/steps are treated as evidence of a regular grid.

    Parameters
    ----------
    stream : str
        The name of the stream to check for regular spacing.

    Returns
    -------
    bool
        True if the stream is regularly spaced. False otherwise.
    """
    _logger.debug(f"Checking regular spacing for stream {stream}...")

    with ZarrIO(self.fname_zarr) as zio:
        # NOTE(review): the first fetch uses the literal sample index 0,
        # while the second uses zio.samples[...] — confirm these index the
        # same space if sample ids do not start at 0.
        dummy = zio.get_data(0, stream, zio.forecast_steps[0])

        # Use a second, different sample/forecast step when available;
        # with a single sample and step this degenerates to comparing the
        # data with itself (always "regular").
        sample_idx = zio.samples[1] if len(zio.samples) > 1 else zio.samples[0]
        fstep_idx = (
            zio.forecast_steps[1] if len(zio.forecast_steps) > 1 else zio.forecast_steps[0]
        )
        dummy1 = zio.get_data(sample_idx, stream, fstep_idx)

        da = dummy.prediction.as_xarray()
        da1 = dummy1.prediction.as_xarray()

        # Sorting before np.allclose makes the comparison order-independent:
        # only the set of coordinate values matters, not the point ordering.
        if (
            da["lat"].shape != da1["lat"].shape
            or da["lon"].shape != da1["lon"].shape
            or not (
                np.allclose(sorted(da["lat"].values), sorted(da1["lat"].values))
                and np.allclose(sorted(da["lon"].values), sorted(da1["lon"].values))
            )
        ):
            _logger.debug("Latitude and/or longitude coordinates are not regularly spaced.")
            return False

        _logger.debug("Latitude and longitude coordinates are regularly spaced.")
        return True

def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
"""
Load the pre-computed scores for a given run, stream and metric and epoch.
Expand Down Expand Up @@ -859,3 +914,40 @@ def get_inference_stream_attr(self, stream_name: str, key: str, default=None):
if stream.get("name") == stream_name:
return stream.get(key, default)
return default


################### Helper functions ########################


def _force_consistent_grids(ref: list[xr.DataArray]) -> xr.DataArray:
"""
Force all samples to share the same ipoint order.

Parameters
----------
ref:
Input dataset
Returns
-------
xr.DataArray
Returns a Dataset where all samples have the same lat lon and ipoint ordering
"""

# Pick first sample as reference
ref_lat = ref[0].lat
ref_lon = ref[0].lon

sort_idx = np.lexsort((ref_lon.values, ref_lat.values))
npoints = sort_idx.size
aligned = []
for a in ref:
a_sorted = a.isel(ipoint=sort_idx)

a_sorted = a_sorted.assign_coords(
ipoint=np.arange(npoints),
lat=("ipoint", ref_lat.values[sort_idx]),
lon=("ipoint", ref_lon.values[sort_idx]),
)
aligned.append(a_sorted)

return xr.concat(aligned, dim="sample")
Loading