From d2fcc9625e1386b2d32bda3972214103e66ef6e3 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 15 Oct 2025 08:55:34 +0200 Subject: [PATCH 01/26] store lead time in OutptDataset and make it available to evaluation --- packages/common/src/weathergen/common/io.py | 38 ++++++++++++++++++--- src/weathergen/utils/validation_io.py | 1 + 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 3e7594d1c..9c0ce8e13 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -148,6 +148,7 @@ class OutputDataset: channels: list[str] geoinfo_channels: list[str] + lead_time: int @functools.cached_property def arrays(self) -> dict[str, zarr.Array | NDArray]: @@ -186,6 +187,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: "sample": [self.item_key.sample], "stream": [self.item_key.stream], "forecast_step": [self.item_key.forecast_step], + "lead_time": ("forecast_step", [self.lead_time]), "ipoint": self.datapoints, "channel": self.channels, # TODO: make sure channel names align with data "valid_time": ("ipoint", times.astype("datetime64[ns]")), @@ -285,6 +287,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset): def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset): dataset_group.attrs["channels"] = dataset.channels dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels + dataset_group.attrs["lead_time"] = dataset.lead_time def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset): for array_name, array in dataset.arrays.items(): # suffix is eg. 
data or coords @@ -302,6 +305,14 @@ def _create_dataset(self, group: zarr.Group, name: str, array: NDArray): ) group.create_dataset(name, data=array, chunks=chunks) + @functools.cached_property + def example_key(self) -> ItemKey: + sample, example_sample = next(self.data_root.groups()) + stream, example_stream = next(example_sample.groups()) + fstep, example_item = next(example_stream.groups()) + + return ItemKey(sample, fstep, stream) + @functools.cached_property def samples(self) -> list[int]: """Query available samples in this zarr store.""" @@ -320,8 +331,17 @@ def forecast_steps(self) -> list[int]: # assume stream/samples/forecast_steps are orthogonal _, example_sample = next(self.data_root.groups()) _, example_stream = next(example_sample.groups()) + return list(example_stream.group_keys()) + @functools.cached_property + def lead_times(self) -> list[int]: + """Calculate available lead times from available forecast steps and len_hrs.""" + example_prediction = self.load_zarr(self.example_key).prediction + len_hrs = example_prediction.lead_time // self.example_key.forecast_step + + return [step * len_hrs for step in self.forecast_steps] + @dataclasses.dataclass class DataCoordinates: @@ -365,6 +385,7 @@ class OutputBatchData: sample_start: int forecast_offset: int + len_hrs: int @functools.cached_property def samples(self): @@ -415,8 +436,10 @@ def extract(self, key: ItemKey) -> OutputItem: "Number of channel names does not align with prediction data." 
) + lead_time = self.len_hrs * key + if key.with_source: - source_dataset = self._extract_sources(offset_key.sample, stream_idx, key) + source_dataset = self._extract_sources(offset_key.sample, stream_idx, key, lead_time) else: source_dataset = None @@ -425,9 +448,15 @@ def extract(self, key: ItemKey) -> OutputItem: return OutputItem( key=key, source=source_dataset, - target=OutputDataset("target", key, target_data, **dataclasses.asdict(data_coords)), + target=OutputDataset( + "target", key, target_data, lead_time=lead_time, **dataclasses.asdict(data_coords) + ), prediction=OutputDataset( - "prediction", key, preds_data, **dataclasses.asdict(data_coords) + "prediction", + key, + preds_data, + lead_time=lead_time, + **dataclasses.asdict(data_coords), ), ) @@ -487,7 +516,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels) - def _extract_sources(self, sample, stream_idx, key): + def _extract_sources(self, sample, stream_idx, key, lead_time): channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] @@ -506,6 +535,7 @@ def _extract_sources(self, sample, stream_idx, key): np.asarray(source.geoinfos), channels, geoinfo_channels, + lead_time ) _logger.debug(f"source shape: {source_dataset.data.shape}") diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e28563132..c0939cf8c 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -61,6 +61,7 @@ def write_output( geoinfo_channels, sample_start, cf.forecast_offset, + cf.len_hrs ) with io.ZarrIO(config.get_path_output(cf, epoch)) as writer: From a4cb9a096eb57927201c2fa25c48326849e248f2 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 15 Oct 2025 11:22:35 +0200 Subject: [PATCH 02/26] ruffed --- packages/common/src/weathergen/common/io.py | 2 +- src/weathergen/utils/validation_io.py | 2 +- 2 
files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 9c0ce8e13..d4676590b 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -535,7 +535,7 @@ def _extract_sources(self, sample, stream_idx, key, lead_time): np.asarray(source.geoinfos), channels, geoinfo_channels, - lead_time + lead_time, ) _logger.debug(f"source shape: {source_dataset.data.shape}") diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index c0939cf8c..b7108de5d 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -61,7 +61,7 @@ def write_output( geoinfo_channels, sample_start, cf.forecast_offset, - cf.len_hrs + cf.len_hrs, ) with io.ZarrIO(config.get_path_output(cf, epoch)) as writer: From f1e2157e3bde557e2c3630f348d154b90c522c60 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 15 Oct 2025 12:17:53 +0200 Subject: [PATCH 03/26] addressed comments --- packages/common/src/weathergen/common/io.py | 23 +++++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index d4676590b..32a63ecb1 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -148,7 +148,8 @@ class OutputDataset: channels: list[str] geoinfo_channels: list[str] - lead_time: int + # lead time in hours defined as forecast step * length of forecast step (len_hours) + lead_time_hrs: int @functools.cached_property def arrays(self) -> dict[str, zarr.Array | NDArray]: @@ -187,7 +188,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: "sample": [self.item_key.sample], "stream": [self.item_key.stream], "forecast_step": [self.item_key.forecast_step], - "lead_time": ("forecast_step", [self.lead_time]), + 
"lead_time_hrs": ("forecast_step", [self.lead_time_hrs]), "ipoint": self.datapoints, "channel": self.channels, # TODO: make sure channel names align with data "valid_time": ("ipoint", times.astype("datetime64[ns]")), @@ -287,7 +288,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset): def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset): dataset_group.attrs["channels"] = dataset.channels dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels - dataset_group.attrs["lead_time"] = dataset.lead_time + dataset_group.attrs["lead_time_hrs"] = dataset.lead_time_hrs def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset): for array_name, array in dataset.arrays.items(): # suffix is eg. data or coords @@ -338,7 +339,7 @@ def forecast_steps(self) -> list[int]: def lead_times(self) -> list[int]: """Calculate available lead times from available forecast steps and len_hrs.""" example_prediction = self.load_zarr(self.example_key).prediction - len_hrs = example_prediction.lead_time // self.example_key.forecast_step + len_hrs = example_prediction.lead_time_hrs // self.example_key.forecast_step return [step * len_hrs for step in self.forecast_steps] @@ -385,7 +386,7 @@ class OutputBatchData: sample_start: int forecast_offset: int - len_hrs: int + t_window_len_hours: int @functools.cached_property def samples(self): @@ -436,7 +437,7 @@ def extract(self, key: ItemKey) -> OutputItem: "Number of channel names does not align with prediction data." 
) - lead_time = self.len_hrs * key + lead_time = self.t_window_len_hours * key.forecast_step if key.with_source: source_dataset = self._extract_sources(offset_key.sample, stream_idx, key, lead_time) @@ -449,13 +450,17 @@ def extract(self, key: ItemKey) -> OutputItem: key=key, source=source_dataset, target=OutputDataset( - "target", key, target_data, lead_time=lead_time, **dataclasses.asdict(data_coords) + "target", + key, + target_data, + lead_time_hrs=lead_time, + **dataclasses.asdict(data_coords), ), prediction=OutputDataset( "prediction", key, preds_data, - lead_time=lead_time, + lead_time_hrs=lead_time, **dataclasses.asdict(data_coords), ), ) @@ -516,7 +521,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels) - def _extract_sources(self, sample, stream_idx, key, lead_time): + def _extract_sources(self, sample, stream_idx, key, lead_time: int): channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] From b183e9d4026cc9beb9361369ca6d6a67dbfab53c Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 14:07:24 +0100 Subject: [PATCH 04/26] remove lead_time_hrs from OutputDataset --- packages/common/src/weathergen/common/io.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 32a63ecb1..6ae098705 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -149,7 +149,6 @@ class OutputDataset: channels: list[str] geoinfo_channels: list[str] # lead time in hours defined as forecast step * length of forecast step (len_hours) - lead_time_hrs: int @functools.cached_property def arrays(self) -> dict[str, zarr.Array | NDArray]: @@ -188,7 +187,6 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: "sample": 
[self.item_key.sample], "stream": [self.item_key.stream], "forecast_step": [self.item_key.forecast_step], - "lead_time_hrs": ("forecast_step", [self.lead_time_hrs]), "ipoint": self.datapoints, "channel": self.channels, # TODO: make sure channel names align with data "valid_time": ("ipoint", times.astype("datetime64[ns]")), @@ -288,7 +286,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset): def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset): dataset_group.attrs["channels"] = dataset.channels dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels - dataset_group.attrs["lead_time_hrs"] = dataset.lead_time_hrs + def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset): for array_name, array in dataset.arrays.items(): # suffix is eg. data or coords @@ -335,14 +333,6 @@ def forecast_steps(self) -> list[int]: return list(example_stream.group_keys()) - @functools.cached_property - def lead_times(self) -> list[int]: - """Calculate available lead times from available forecast steps and len_hrs.""" - example_prediction = self.load_zarr(self.example_key).prediction - len_hrs = example_prediction.lead_time_hrs // self.example_key.forecast_step - - return [step * len_hrs for step in self.forecast_steps] - @dataclasses.dataclass class DataCoordinates: @@ -437,10 +427,8 @@ def extract(self, key: ItemKey) -> OutputItem: "Number of channel names does not align with prediction data." 
) - lead_time = self.t_window_len_hours * key.forecast_step - if key.with_source: - source_dataset = self._extract_sources(offset_key.sample, stream_idx, key, lead_time) + source_dataset = self._extract_sources(offset_key.sample, stream_idx, key) else: source_dataset = None @@ -453,14 +441,12 @@ def extract(self, key: ItemKey) -> OutputItem: "target", key, target_data, - lead_time_hrs=lead_time, **dataclasses.asdict(data_coords), ), prediction=OutputDataset( "prediction", key, preds_data, - lead_time_hrs=lead_time, **dataclasses.asdict(data_coords), ), ) @@ -521,7 +507,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels) - def _extract_sources(self, sample, stream_idx, key, lead_time: int): + def _extract_sources(self, sample, stream_idx, key): channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] @@ -540,7 +526,6 @@ def _extract_sources(self, sample, stream_idx, key, lead_time: int): np.asarray(source.geoinfos), channels, geoinfo_channels, - lead_time, ) _logger.debug(f"source shape: {source_dataset.data.shape}") From e3a56280097efee23dbcd03099ac548c0217ac65 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 13:35:41 +0100 Subject: [PATCH 05/26] use type alias for union of `zarr.Array` and `NDArray` --- packages/common/src/weathergen/common/io.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 6ae098705..09ed0f6df 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -25,6 +25,7 @@ CHUNK_N_SAMPLES = 16392 type DType = np.float32 type NPDT64 = datetime64 +type ArrayType = zarr.Array | np.NDArray[DType] _logger = logging.getLogger(__name__) @@ -135,23 +136,23 @@ class OutputDataset: item_key: ItemKey # (datapoints, 
channels, ens) - data: zarr.Array | NDArray # wrong type => array like + data: ArrayType # wrong type => array like # (datapoints,) - times: zarr.Array | NDArray + times: ArrayType # (datapoints, 2) - coords: zarr.Array | NDArray + coords: ArrayType # (datapoints, geoinfos) geoinfos are stream dependent => 0 for most gridded data - geoinfo: zarr.Array | NDArray + geoinfo: ArrayType channels: list[str] geoinfo_channels: list[str] # lead time in hours defined as forecast step * length of forecast step (len_hours) @functools.cached_property - def arrays(self) -> dict[str, zarr.Array | NDArray]: + def arrays(self) -> dict[str, ArrayType]: """Iterate over the arrays and their names.""" return { "data": self.data, From c5b1d1940d33161efe0dc7df028da0c12f5eedf5 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 13:32:46 +0100 Subject: [PATCH 06/26] dataclass to store time range information --- packages/common/src/weathergen/common/io.py | 28 +++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 09ed0f6df..1274c77e5 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -36,6 +36,34 @@ def is_ndarray(obj: typing.Any) -> bool: return isinstance(obj, (np.ndarray)) # noqa: TID251 +@dataclasses.dataclass +class TimeRange: + start: np.datetime64 + end: np.datetime64 + + def __post_init__(self): + # ensure consistent type + self.start = self.start.astype("np.datetime64[ns]") + self.end = self.end.astype("np.datetime64[ns]") + + assert self.start < self.end + + def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange": + assert forecast_dt_hours > 0 and fstep >= 0 + offset = np.timedelta64(forecast_dt_hours * fstep, "h") + return TimeRange(self.start + offset, self.end + offset) + + def get_lead_time( + self, abs_time: np.datetime64 | NDArray[np.datetime64] + ) -> 
NDArray[np.datetime64]: + if isinstance(abs_time, np.datetime64): + abs_time = np.array([abs_time]) + + abs_time.astype("np.datetiem64[ns]") + assert all(abs_time > self.end) + return abs_time - self.end + + @dataclasses.dataclass class IOReaderData: """ From c29e9c94fe7f75abd7ae091717770c7781aacc11 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 13:41:44 +0100 Subject: [PATCH 07/26] add source interval to `OutputDataset` --- packages/common/src/weathergen/common/io.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 1274c77e5..20d6517e0 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -162,6 +162,7 @@ class OutputDataset: name: str item_key: ItemKey + source_interval: TimeRange # (datapoints, channels, ens) data: ArrayType # wrong type => array like @@ -179,6 +180,16 @@ class OutputDataset: geoinfo_channels: list[str] # lead time in hours defined as forecast step * length of forecast step (len_hours) + @classmethod + def create(cls, name, key, arrays: dict[str, ArrayType], attrs: dict[str, typing.Any]): + """ + Create Output dataset from dictonaries. 
+ """ + assert "source_interval" in attrs, "missing expected attribute 'source_interval'" + + source_interval = TimeRange(**attrs["source_interval"]) + return cls(name, key, source_interval, **arrays, **attrs) + @functools.cached_property def arrays(self) -> dict[str, ArrayType]: """Iterate over the arrays and their names.""" From 4dae35abe6ec232d42a5a378a81417485aaed005 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 14:15:17 +0100 Subject: [PATCH 08/26] store lead-time and source interval in xarray --- packages/common/src/weathergen/common/io.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 20d6517e0..5bec5a5e6 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -215,7 +215,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: additional_dims = (0, 1, 2) if len(data.shape) == 3 else (0, 1, 2, 5) expanded_data = da.expand_dims(data, axis=additional_dims) coords = da.from_zarr(self.coords).compute() - times = da.from_zarr(self.times).compute() + times = da.from_zarr(self.times).compute().astype("datetime64[ns]") geoinfo = da.from_zarr(self.geoinfo).compute() geoinfo = {name: ("ipoint", geoinfo[:, i]) for i, name in enumerate(self.geoinfo_channels)} # TODO: make sample, stream, forecast_step DataArray attribute, test how it @@ -225,11 +225,14 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: dims=["sample", "stream", "forecast_step", "ipoint", "channel", "ens"], coords={ "sample": [self.item_key.sample], + "source_interval_start": ("sample", [self.source_interval.start]), + "source_interval_end": ("sample", self.source_interval.end), "stream": [self.item_key.stream], "forecast_step": [self.item_key.forecast_step], "ipoint": self.datapoints, "channel": self.channels, # TODO: make sure channel names align with data - "valid_time": 
("ipoint", times.astype("datetime64[ns]")), + "valid_time": ("ipoint", times), + "lead_time": ("ipoint", self.source_interval.get_lead_time(times)), "lat": ("ipoint", coords[..., 0]), "lon": ("ipoint", coords[..., 1]), **geoinfo, From 8fcf028b11436e6868630d8b97016f45dfe663d3 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 14:16:53 +0100 Subject: [PATCH 09/26] correct de(serialization) for `OutputDataset` with source-interval --- packages/common/src/weathergen/common/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 5bec5a5e6..c395bac8f 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -298,7 +298,7 @@ def load_zarr(self, key: ItemKey) -> OutputItem: """Get datasets for a output item.""" group = self._get_group(key) datasets = { - name: OutputDataset(name, key, **dict(dataset.arrays()), **dataset.attrs) + name: OutputDataset.create(name, key, dict(dataset.arrays()), dataset.attrs) for name, dataset in group.groups() } datasets["key"] = key @@ -329,7 +329,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset): def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset): dataset_group.attrs["channels"] = dataset.channels dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels - + dataset_group.attrs["source_interval"] = dataclasses.asdict(dataset.source_interval) def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset): for array_name, array in dataset.arrays.items(): # suffix is eg. 
data or coords From 4fb096021a597dee7d2d4673ac40040fad0f18c3 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 14:18:33 +0100 Subject: [PATCH 10/26] correctly instantiate `OutputDataset` from `OutputBatchData` --- packages/common/src/weathergen/common/io.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index c395bac8f..1ce268166 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -446,6 +446,10 @@ def extract(self, key: ItemKey) -> OutputItem: stream_idx = self.streams[key.stream] datapoints = self._get_datapoints_per_sample(offset_key, stream_idx) + # TODO extract real source interval start/end times + dummy_date = np.datetime64("2020-01-01") + source_interval = TimeRange(dummy_date, dummy_date + np.timedelta64(5, "D")) + _logger.debug( f"forecast_step: {key.forecast_step} = {offset_key.forecast_step} (rel_step) + " + f"{self.forecast_offset} (forecast_offset)" @@ -471,7 +475,9 @@ def extract(self, key: ItemKey) -> OutputItem: ) if key.with_source: - source_dataset = self._extract_sources(offset_key.sample, stream_idx, key) + source_dataset = self._extract_sources( + offset_key.sample, stream_idx, key, source_interval + ) else: source_dataset = None @@ -483,12 +489,14 @@ def extract(self, key: ItemKey) -> OutputItem: target=OutputDataset( "target", key, + source_interval, target_data, **dataclasses.asdict(data_coords), ), prediction=OutputDataset( "prediction", key, + source_interval, preds_data, **dataclasses.asdict(data_coords), ), @@ -550,11 +558,13 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels) - def _extract_sources(self, sample, stream_idx, key): + def _extract_sources( + self, sample: int, stream_idx: int, key: ItemKey, source_interval: TimeRange 
+ ) -> OutputDataset: channels = self.source_channels[stream_idx] geoinfo_channels = self.geoinfo_channels[stream_idx] - source = self.sources[sample][stream_idx] + source: IOReaderData = self.sources[sample][stream_idx] assert source.data.shape[1] == len(channels), ( "Number of source channel names does not align with source data" @@ -563,6 +573,7 @@ def _extract_sources(self, sample, stream_idx, key): source_dataset = OutputDataset( "source", key, + source_interval, np.asarray(source.data), np.asarray(source.datetimes), np.asarray(source.coords), From f2bab79e7985377d827d95d7b6b59c5920e9ba07 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 14:19:42 +0100 Subject: [PATCH 11/26] remove attribute `t_window_len_hours` from `OutputBatchData` --- packages/common/src/weathergen/common/io.py | 1 - src/weathergen/utils/validation_io.py | 1 - 2 files changed, 2 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 1ce268166..aeb3083ec 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -419,7 +419,6 @@ class OutputBatchData: sample_start: int forecast_offset: int - t_window_len_hours: int @functools.cached_property def samples(self): diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index b7108de5d..e28563132 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -61,7 +61,6 @@ def write_output( geoinfo_channels, sample_start, cf.forecast_offset, - cf.len_hrs, ) with io.ZarrIO(config.get_path_output(cf, epoch)) as writer: From 9d5cd6d7b06e5708271102ec59496ddc4b7e9e20 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 14:22:34 +0100 Subject: [PATCH 12/26] calculate for output source windows from sample indices --- src/weathergen/train/trainer.py | 13 +++++++++---- src/weathergen/utils/validation_io.py | 10 ++++++++++ 2 files changed, 
19 insertions(+), 4 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3d847a671..c3b77f318 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -33,6 +33,7 @@ import weathergen.common.config as config from weathergen.common.config import Config from weathergen.datasets.multi_stream_data_sampler import MultiStreamDataSampler +from weathergen.datasets.stream_data import StreamData from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -683,11 +684,12 @@ def validate(self, epoch): self.model_params, batch, cf.forecast_offset, forecast_steps ) + streams_data: list[list[StreamData]] = batch[0] # compute loss and log output if bidx < cf.log_validation: loss_values = self.loss_calculator_val.compute_loss( preds=preds, - streams_data=batch[0], + streams_data=streams_data, ) # TODO: Move _prepare_logging into write_validation by passing streams_data @@ -701,9 +703,11 @@ def validate(self, epoch): preds=preds, forecast_offset=cf.forecast_offset, forecast_steps=cf.forecast_steps, - streams_data=batch[0], + streams_data=streams_data, ) - sources = [[item.source_raw for item in b] for b in batch[0]] + sources = [[item.source_raw for item in stream] for stream in streams_data] + # sample idx should be the same across streams => select first + sample_idxs = [item.sample_idx for item in streams_data[0]] write_output( self.cf, epoch, @@ -714,12 +718,13 @@ def validate(self, epoch): targets_coords_all, targets_times_all, targets_lens, + sample_idxs, ) else: loss_values = self.loss_calculator_val.compute_loss( preds=preds, - streams_data=batch[0], + streams_data=streams_data, ) self.loss_unweighted_hist += [loss_values.losses_all] diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index e28563132..5d46e5564 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ 
-11,6 +11,8 @@ import weathergen.common.config as config import weathergen.common.io as io +from weathergen.common.io import TimeRange +from weathergen.datasets.data_reader_base import TimeWindowHandler, str_to_datetime64 _logger = logging.getLogger(__name__) @@ -25,6 +27,7 @@ def write_output( targets_coords_all, targets_times_all, targets_lens, + sample_idxs, ): stream_names = [stream.name for stream in cf.streams] output_stream_names = cf.analysis_streams_output @@ -48,6 +51,13 @@ def write_output( assert len(stream_names) == len(preds_all[0]), "data does not match number of streams" assert len(stream_names) == len(sources[0]), "data does not match number of streams" + start_date = str_to_datetime64(cf.start_date_val) + end_date = str_to_datetime64(cf.end_date_val) + + twh = TimeWindowHandler(start_date, end_date, cf.len_hrs, cf.step_hrs) + source_windows = (twh.window(idx) for idx in sample_idxs) + source_intervals = [TimeRange(window.start, window.end) for window in source_windows] + data = io.OutputBatchData( sources, targets_all, From fef7456b979a11c286cf287de56fbdb926fa5376 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 14:31:33 +0100 Subject: [PATCH 13/26] add source_intervals as attribute to OutputBatchData --- packages/common/src/weathergen/common/io.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index aeb3083ec..8f66fff0c 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -393,6 +393,9 @@ class OutputBatchData: # sample, stream, tensor(datapoint, channel+coords) # => datapoints is accross all datasets per stream sources: list[list[IOReaderData]] + + # sample + source_intervals: list[TimeRange] # fstep, stream, redundant dim (size 1), tensor(sample x datapoint, channel) targets: list[list[list]] @@ -445,9 +448,7 @@ def extract(self, key: ItemKey) -> 
OutputItem: stream_idx = self.streams[key.stream] datapoints = self._get_datapoints_per_sample(offset_key, stream_idx) - # TODO extract real source interval start/end times - dummy_date = np.datetime64("2020-01-01") - source_interval = TimeRange(dummy_date, dummy_date + np.timedelta64(5, "D")) + source_interval = self.source_intervals[offset_key.sample] _logger.debug( f"forecast_step: {key.forecast_step} = {offset_key.forecast_step} (rel_step) + " From 8bbb267397a375f561b9d3415b81519afc18c4c0 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 29 Oct 2025 16:34:54 +0100 Subject: [PATCH 14/26] ruffed --- packages/common/src/weathergen/common/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 8f66fff0c..febf33d26 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -393,7 +393,7 @@ class OutputBatchData: # sample, stream, tensor(datapoint, channel+coords) # => datapoints is accross all datasets per stream sources: list[list[IOReaderData]] - + # sample source_intervals: list[TimeRange] From 2dcd4a9b70b3bca5ae10f16e4d064c5beb2f0f7f Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Mon, 3 Nov 2025 14:18:16 +0100 Subject: [PATCH 15/26] fix: pass source intervals to OutputBatchData --- src/weathergen/utils/validation_io.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/utils/validation_io.py b/src/weathergen/utils/validation_io.py index 5d46e5564..71d45f57e 100644 --- a/src/weathergen/utils/validation_io.py +++ b/src/weathergen/utils/validation_io.py @@ -60,6 +60,7 @@ def write_output( data = io.OutputBatchData( sources, + source_intervals, targets_all, preds_all, targets_coords_all, From bf8fcef68f8a9fe8aa2a8b0f92dd7365b0a1c227 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Tue, 4 Nov 2025 13:08:26 +0100 Subject: [PATCH 16/26] fix: use correct string in array.astype --- 
packages/common/src/weathergen/common/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index febf33d26..f54d5df69 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -43,8 +43,8 @@ class TimeRange: def __post_init__(self): # ensure consistent type - self.start = self.start.astype("np.datetime64[ns]") - self.end = self.end.astype("np.datetime64[ns]") + self.start = self.start.astype("datetime64[ns]") + self.end = self.end.astype("datetime64[ns]") assert self.start < self.end From 7784cbe43efb8af2cec3a2c1c72382ca9509640b Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Tue, 4 Nov 2025 15:00:27 +0100 Subject: [PATCH 17/26] fix: (de)serialize np.datetime64 --- packages/common/src/weathergen/common/io.py | 23 +++++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index f54d5df69..9f6e1bc05 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -36,18 +36,23 @@ def is_ndarray(obj: typing.Any) -> bool: return isinstance(obj, (np.ndarray)) # noqa: TID251 -@dataclasses.dataclass class TimeRange: - start: np.datetime64 - end: np.datetime64 - - def __post_init__(self): - # ensure consistent type - self.start = self.start.astype("datetime64[ns]") - self.end = self.end.astype("datetime64[ns]") + def __init__(self, start: NPDT64 | str, end: NPDT64 | str): + # ensure consistent type => convert serialized strings + self.start = np.datetime64(start, "ns") + self.end = np.datetime64(end, "ns") assert self.start < self.end + def as_dict(self): + """Convert instance to a JSON-serializable dict.""" + + # will output as "YYYY-MM-DDThh:mm:s.sssssssss" + return { + "start": str(self.start), + "end": str(self.end), + } + def 
forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange": assert forecast_dt_hours > 0 and fstep >= 0 offset = np.timedelta64(forecast_dt_hours * fstep, "h") @@ -329,7 +334,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset): def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset): dataset_group.attrs["channels"] = dataset.channels dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels - dataset_group.attrs["source_interval"] = dataclasses.asdict(dataset.source_interval) + dataset_group.attrs["source_interval"] = dataset.source_interval.as_dict() def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset): for array_name, array in dataset.arrays.items(): # suffix is eg. data or coords From 54ee41767e9b1d3b2fefab5b1cf9c11de79d1fa1 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Tue, 4 Nov 2025 15:30:32 +0100 Subject: [PATCH 18/26] fix deserialization of OutputDataset from json/zarr --- packages/common/src/weathergen/common/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 9f6e1bc05..eca6c0f06 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -192,7 +192,7 @@ def create(cls, name, key, arrays: dict[str, ArrayType], attrs: dict[str, typing """ assert "source_interval" in attrs, "missing expected attribute 'source_interval'" - source_interval = TimeRange(**attrs["source_interval"]) + source_interval = TimeRange(**attrs.pop("source_interval")) return cls(name, key, source_interval, **arrays, **attrs) @functools.cached_property From 21b310a60adab19ef768a01f9f5e5b48e4b8d19c Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Tue, 4 Nov 2025 16:40:22 +0100 Subject: [PATCH 19/26] fix: errors in xarray conversion --- packages/common/src/weathergen/common/io.py | 49 ++++++++++++--------- 1 file changed, 27 
insertions(+), 22 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index eca6c0f06..d4516ba37 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -64,8 +64,9 @@ def get_lead_time( if isinstance(abs_time, np.datetime64): abs_time = np.array([abs_time]) - abs_time.astype("np.datetiem64[ns]") - assert all(abs_time > self.end) + abs_time.astype("datetime64[ns]") + # this fails for forecast offset = 0 / fstep 0 + # assert all(abs_time >= self.end) return abs_time - self.end @@ -225,25 +226,24 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: geoinfo = {name: ("ipoint", geoinfo[:, i]) for i, name in enumerate(self.geoinfo_channels)} # TODO: make sample, stream, forecast_step DataArray attribute, test how it # interacts with concatenating - return xr.DataArray( - expanded_data, - dims=["sample", "stream", "forecast_step", "ipoint", "channel", "ens"], - coords={ - "sample": [self.item_key.sample], - "source_interval_start": ("sample", [self.source_interval.start]), - "source_interval_end": ("sample", self.source_interval.end), - "stream": [self.item_key.stream], - "forecast_step": [self.item_key.forecast_step], - "ipoint": self.datapoints, - "channel": self.channels, # TODO: make sure channel names align with data - "valid_time": ("ipoint", times), - "lead_time": ("ipoint", self.source_interval.get_lead_time(times)), - "lat": ("ipoint", coords[..., 0]), - "lon": ("ipoint", coords[..., 1]), - **geoinfo, - }, - name=self.name, - ) + dims = ["sample", "stream", "forecast_step", "ipoint", "channel", "ens"] + breakpoint() + ds_coords = { + "sample": [self.item_key.sample], + "source_interval_start": ("sample", [self.source_interval.start]), + "source_interval_end": ("sample", [self.source_interval.end]), + "stream": [self.item_key.stream], + "forecast_step": [self.item_key.forecast_step], + "ipoint": self.datapoints, + "channel": 
self.channels, # TODO: make sure channel names align with data + "valid_time": ("ipoint", times), + "lead_time": ("ipoint", self.source_interval.get_lead_time(times)), + "lat": ("ipoint", coords[..., 0]), + "lon": ("ipoint", coords[..., 1]), + **geoinfo, + } + breakpoint() + return xr.DataArray(expanded_data, dims=dims, coords=ds_coords, name=self.name) class OutputItem: @@ -303,7 +303,12 @@ def load_zarr(self, key: ItemKey) -> OutputItem: """Get datasets for a output item.""" group = self._get_group(key) datasets = { - name: OutputDataset.create(name, key, dict(dataset.arrays()), dataset.attrs) + name: OutputDataset.create( + name, + key, + dict(dataset.arrays()), + dict(dataset.attrs).copy() + ) for name, dataset in group.groups() } datasets["key"] = key From 24eb659a4a63e15d6dbccbaf7f90b0e41fbdae42 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Tue, 4 Nov 2025 19:59:13 +0100 Subject: [PATCH 20/26] fix types --- packages/common/src/weathergen/common/io.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index d4516ba37..be2e016c7 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -60,11 +60,12 @@ def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange": def get_lead_time( self, abs_time: np.datetime64 | NDArray[np.datetime64] - ) -> NDArray[np.datetime64]: + ) -> NDArray[np.timedelta64]: if isinstance(abs_time, np.datetime64): abs_time = np.array([abs_time]) - abs_time.astype("datetime64[ns]") + abs_time = abs_time.astype("datetime64[ns]") + # end = self.end.astype("datetime64[ns]") # this fails for forecast offset = 0 / fstep 0 # assert all(abs_time >= self.end) return abs_time - self.end @@ -227,7 +228,6 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: # TODO: make sample, stream, forecast_step DataArray attribute, test how it # interacts 
with concatenating dims = ["sample", "stream", "forecast_step", "ipoint", "channel", "ens"] - breakpoint() ds_coords = { "sample": [self.item_key.sample], "source_interval_start": ("sample", [self.source_interval.start]), @@ -242,7 +242,6 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: "lon": ("ipoint", coords[..., 1]), **geoinfo, } - breakpoint() return xr.DataArray(expanded_data, dims=dims, coords=ds_coords, name=self.name) @@ -304,10 +303,7 @@ def load_zarr(self, key: ItemKey) -> OutputItem: group = self._get_group(key) datasets = { name: OutputDataset.create( - name, - key, - dict(dataset.arrays()), - dict(dataset.attrs).copy() + name, key, dict(dataset.arrays()), dict(dataset.attrs).copy() ) for name, dataset in group.groups() } From c8986d6cdc4ea04123a0568dcbea610edacff6e2 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 5 Nov 2025 08:45:50 +0100 Subject: [PATCH 21/26] Improve documentation --- packages/common/src/weathergen/common/io.py | 48 +++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index be2e016c7..bfcb5e362 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -37,6 +37,18 @@ def is_ndarray(obj: typing.Any) -> bool: class TimeRange: + """ + Holds information about a time interval used in forecasting. + + Time interval is left-closed, right-open. TimeRange can be instantiated from + numpy datetime64 objects or strings as output by TimeRange.as_dict. + Both will be converted to datetime64 with nanosecond precision. + + Attributes: + start: Start of the time range in nanoseconds. 
+ end: End of the time range in nanoseconds + """ + def __init__(self, start: NPDT64 | str, end: NPDT64 | str): # ensure consistent type => convert serialized strings self.start = np.datetime64(start, "ns") @@ -44,16 +56,31 @@ def __init__(self, start: NPDT64 | str, end: NPDT64 | str): assert self.start < self.end - def as_dict(self): - """Convert instance to a JSON-serializable dict.""" + def as_dict(self) -> dict[str, str]: + """ + Convert instance to a JSON-serializable dict. + + Will convert datetime objects as "YYYY-MM-DDThh:mm:ss.sssssssss" - # will output as "YYYY-MM-DDThh:mm:s.sssssssss" + Returns: + JSON-serializable dict, where datetime objects were converted to strings. + """ return { "start": str(self.start), "end": str(self.end), } def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange": + """ + Infer the interval considered at forecast step `fstep`. + + Args: + forecast_dt_hours: number of hours the source TimeRange is shifted per forecast step. + fstep: current forecast step. + + Returns: + New TimeRange shifted by `forecast_dt_hours * fstep` hours. + """ assert forecast_dt_hours > 0 and fstep >= 0 offset = np.timedelta64(forecast_dt_hours * fstep, "h") return TimeRange(self.start + offset, self.end + offset) @@ -61,6 +88,15 @@ def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange": def get_lead_time( self, abs_time: np.datetime64 | NDArray[np.datetime64] ) -> NDArray[np.timedelta64]: + """ + Calculate lead times based on the end of the TimeRange. + + Args: + abs_time: Single timestamp or array of timestamps. + + Returns: + Array of time differences (lead times) for each input timestamp. + """ if isinstance(abs_time, np.datetime64): abs_time = np.array([abs_time]) @@ -191,6 +227,12 @@ class OutputDataset: def create(cls, name, key, arrays: dict[str, ArrayType], attrs: dict[str, typing.Any]): """ Create Output dataset from dictonaries. 
+ + Args: + name: Name of dataset (target/prediction/source) + item_key: ItemKey to associated with the parent OutputItem. + arrays: Data and Coordinate arrays. + attrs: Additional metadata. """ assert "source_interval" in attrs, "missing expected attribute 'source_interval'" From ae3af6619572ed978eebd61582b6571b868e2735 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Wed, 5 Nov 2025 15:57:05 +0100 Subject: [PATCH 22/26] remove lead_time from OutputDataset --- packages/common/src/weathergen/common/io.py | 23 --------------------- 1 file changed, 23 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index bfcb5e362..1e0a3a21d 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -85,27 +85,6 @@ def forecast_interval(self, forecast_dt_hours: int, fstep: int) -> "TimeRange": offset = np.timedelta64(forecast_dt_hours * fstep, "h") return TimeRange(self.start + offset, self.end + offset) - def get_lead_time( - self, abs_time: np.datetime64 | NDArray[np.datetime64] - ) -> NDArray[np.timedelta64]: - """ - Calculate lead times based on the end of the TimeRange. - - Args: - abs_time: Single timestamp or array of timestamps. - - Returns: - Array of time differences (lead times) for each input timestamp. 
- """ - if isinstance(abs_time, np.datetime64): - abs_time = np.array([abs_time]) - - abs_time = abs_time.astype("datetime64[ns]") - # end = self.end.astype("datetime64[ns]") - # this fails for forecast offset = 0 / fstep 0 - # assert all(abs_time >= self.end) - return abs_time - self.end - @dataclasses.dataclass class IOReaderData: @@ -221,7 +200,6 @@ class OutputDataset: channels: list[str] geoinfo_channels: list[str] - # lead time in hours defined as forecast step * length of forecast step (len_hours) @classmethod def create(cls, name, key, arrays: dict[str, ArrayType], attrs: dict[str, typing.Any]): @@ -279,7 +257,6 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray: "ipoint": self.datapoints, "channel": self.channels, # TODO: make sure channel names align with data "valid_time": ("ipoint", times), - "lead_time": ("ipoint", self.source_interval.get_lead_time(times)), "lat": ("ipoint", coords[..., 0]), "lon": ("ipoint", coords[..., 1]), **geoinfo, From 8e913befe2dfe975a1bda12ddf649ac5a199c5ed Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 6 Nov 2025 12:54:42 +0100 Subject: [PATCH 23/26] implement with_target function for ItemKey class --- packages/common/src/weathergen/common/io.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 1e0a3a21d..3717d10b3 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -166,17 +166,20 @@ class ItemKey: stream: str @property - def path(self): + def path(self) -> str: """Unique path within a hierarchy for one output item.""" return f"{self.sample}/{self.stream}/{self.forecast_step}" @property - def with_source(self): + def with_source(self) -> bool: """Decide if output item should contain source dataset.""" - # TODO: is this valid for the adjusted (offsetted) forecast steps? 
- # => if config.forecast_offset > 0 source will be never written return self.forecast_step == 0 + def with_target(self, forecast_offset: typing.Literal[0, 1]) -> bool: + """Decide if output item should contain target and predictions.""" + assert forecast_offset in (0, 1) + return (not self.with_source) or (forecast_offset == 0) + @dataclasses.dataclass class OutputDataset: From e4ee2341ff0f5e9bd533f38c48d32044d2507c3e Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 6 Nov 2025 12:56:19 +0100 Subject: [PATCH 24/26] handle potentially missing target/prediction data in OutputItem --- packages/common/src/weathergen/common/io.py | 26 +++++++++++++++------ 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 3717d10b3..2635bddbb 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -208,7 +208,7 @@ class OutputDataset: def create(cls, name, key, arrays: dict[str, ArrayType], attrs: dict[str, typing.Any]): """ Create Output dataset from dictonaries. - + Args: name: Name of dataset (target/prediction/source) item_key: ItemKey to associated with the parent OutputItem. 
@@ -281,14 +281,26 @@ def __init__( self.prediction = prediction self.source = source - self.datasets = [self.target, self.prediction] + self.datasets = [] + forecast_offset = 0 if self.key.with_source: - if self.source: - self.datasets.append(self.source) - else: - msg = f"Missing source dataset for item: {self.key.path}" - raise ValueError(msg) + self._append_dataset(self.source, "source") + + # forecast offset=1 should produce no targets + if not self.target: + forecast_offset = 1 + + if self.key.with_target(forecast_offset): + self._append_dataset(self.target, "target") + self._append_dataset(self.prediction, "prediction") + + def _append_dataset(self, dataset: OutputDataset | None, name: str): + if dataset: + self.datasets.append(dataset) + else: + msg = f"Missing {name} dataset for item: {self.key.path}" + raise ValueError(msg) class ZarrIO: From e2f7f822c417c662ab6bbcd2fcfe193068e17976 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 6 Nov 2025 13:03:53 +0100 Subject: [PATCH 25/26] separate extraction of targets/predictions into method --- packages/common/src/weathergen/common/io.py | 98 +++++++++++---------- 1 file changed, 53 insertions(+), 45 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index 2635bddbb..a5fff84c4 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -486,16 +486,52 @@ def extract(self, key: ItemKey) -> OutputItem: _logger.debug(f"extracting subset: {key}") offset_key = self._offset_key(key) stream_idx = self.streams[key.stream] - datapoints = self._get_datapoints_per_sample(offset_key, stream_idx) source_interval = self.source_intervals[offset_key.sample] - _logger.debug( f"forecast_step: {key.forecast_step} = {offset_key.forecast_step} (rel_step) + " + f"{self.forecast_offset} (forecast_offset)" ) _logger.debug(f"stream: {key.stream} with index: {stream_idx}") + assert self.forecast_offset in (0, 1) + if 
key.with_source: + source_dataset = self._extract_sources( + offset_key.sample, stream_idx, key, source_interval + ) + else: + source_dataset = None + + target_dataset, prediction_dataset = self._extract_targets_predictions( + stream_idx, offset_key, key, source_interval + ) + + return OutputItem( + key=key, + source=source_dataset, + target=target_dataset, + prediction=prediction_dataset, + ) + + def _offset_key(self, key: ItemKey): + """ + Correct indices in key to be usable for data extraction. + + `key` contains indices that are adjusted to have better output semantics. + To be usable in extraction these have to be adjusted to bridge the differences + compared to the semantics of the data. + - `sample` is adjusted from a global continuous index to a per batch index + - `forecast_step` is adjusted from including `forecast_offset` to indexing + the data (always starts at 0) + """ + return ItemKey( + key.sample - self.sample_start, key.forecast_step - self.forecast_offset, key.stream + ) + + def _extract_targets_predictions(self, stream_idx, offset_key, key, source_interval): + datapoints = self._get_datapoints_per_sample(offset_key, stream_idx) + data_coords = self._extract_coordinates(stream_idx, offset_key, datapoints) + if (datapoints.stop - datapoints.start) == 0: target_data = np.zeros((0, len(self.target_channels[stream_idx])), dtype=np.float32) preds_data = np.zeros((0, len(self.target_channels[stream_idx])), dtype=np.float32) @@ -505,8 +541,6 @@ def extract(self, key: ItemKey) -> OutputItem: 1, 2, 0 )[datapoints] - data_coords = self._extract_coordinates(stream_idx, offset_key, datapoints) - assert len(data_coords.channels) == target_data.shape[1], ( "Number of channel names does not align with target data." ) @@ -514,34 +548,23 @@ def extract(self, key: ItemKey) -> OutputItem: "Number of channel names does not align with prediction data." 
) - if key.with_source: - source_dataset = self._extract_sources( - offset_key.sample, stream_idx, key, source_interval - ) - else: - source_dataset = None - - assert is_ndarray(target_data), f"Expected ndarray but got: {type(target_data)}" - assert is_ndarray(preds_data), f"Expected ndarray but got: {type(preds_data)}" - return OutputItem( - key=key, - source=source_dataset, - target=OutputDataset( - "target", - key, - source_interval, - target_data, - **dataclasses.asdict(data_coords), - ), - prediction=OutputDataset( - "prediction", - key, - source_interval, - preds_data, - **dataclasses.asdict(data_coords), - ), + target_dataset = OutputDataset( + "target", + key, + source_interval, + target_data, + **dataclasses.asdict(data_coords), + ) + prediction_dataset = OutputDataset( + "prediction", + key, + source_interval, + preds_data, + **dataclasses.asdict(data_coords), ) + return target_dataset, prediction_dataset + def _get_datapoints_per_sample(self, offset_key, stream_idx): lens = self.targets_lens[offset_key.forecast_step][stream_idx] @@ -560,21 +583,6 @@ def _get_datapoints_per_sample(self, offset_key, stream_idx): return slice(start, start + n_samples) - def _offset_key(self, key: ItemKey): - """ - Correct indices in key to be useable for data extraction. - - `key` contains indices that are adjusted to have better output semantics. - To be useable in extraction these have to be adjusted to bridge the differences - compared to the semantics of the data. 
- - `sample` is adjusted from a global continous index to a per batch index - - `forecast_step` is adjusted from including `forecast_offset` to indexing - the data (always starts at 0) - """ - return ItemKey( - key.sample - self.sample_start, key.forecast_step - self.forecast_offset, key.stream - ) - - def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordinates: _coords = self.targets_coords[offset_key.forecast_step][stream_idx][datapoints].numpy() From 54f8f6470c22882619fcd534cbd464bea20ded99 Mon Sep 17 00:00:00 2001 From: Simon Grasse Date: Thu, 6 Nov 2025 13:06:04 +0100 Subject: [PATCH 26/26] include fstep 0 when forecast offset is 1 --- packages/common/src/weathergen/common/io.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/packages/common/src/weathergen/common/io.py b/packages/common/src/weathergen/common/io.py index a5fff84c4..2b6310935 100644 --- a/packages/common/src/weathergen/common/io.py +++ b/packages/common/src/weathergen/common/io.py @@ -466,12 +466,16 @@ class OutputBatchData: @functools.cached_property def samples(self): """Continous indices of all samples accross all batches.""" + + # TODO associate samples with the sample idx used for the time window return np.arange(len(self.sources)) + self.sample_start @functools.cached_property def forecast_steps(self): """Indices of all forecast steps adjusted by the forecast offset""" - return np.arange(len(self.targets)) + self.forecast_offset + # forecast offset should be either 1 for forecasting or 0 for MTM + assert self.forecast_offset in (0, 1) + return np.arange(len(self.targets) + self.forecast_offset) def items(self) -> typing.Generator[OutputItem, None, None]: """Iterate over possible output items""" @@ -502,9 +506,12 @@ def extract(self, key: ItemKey) -> OutputItem: else: source_dataset = None - target_dataset, prediction_dataset = self._extract_targets_predictions( - stream_idx, offset_key, key, source_interval - ) + if
key.with_target(self.forecast_offset): + target_dataset, prediction_dataset = self._extract_targets_predictions( + stream_idx, offset_key, key, source_interval + ) + else: + target_dataset, prediction_dataset = (None, None) return OutputItem( key=key,