@@ -148,7 +148,8 @@ class OutputDataset:
148148
149149 channels : list [str ]
150150 geoinfo_channels : list [str ]
151- lead_time : int
151+ # lead time in hours defined as forecast step * length of forecast step (len_hours)
152+ lead_time_hrs : int
152153
153154 @functools .cached_property
154155 def arrays (self ) -> dict [str , zarr .Array | NDArray ]:
@@ -187,7 +188,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray:
187188 "sample" : [self .item_key .sample ],
188189 "stream" : [self .item_key .stream ],
189190 "forecast_step" : [self .item_key .forecast_step ],
190- "lead_time " : ("forecast_step" , [self .lead_time ]),
191+ "lead_time_hrs " : ("forecast_step" , [self .lead_time_hrs ]),
191192 "ipoint" : self .datapoints ,
192193 "channel" : self .channels , # TODO: make sure channel names align with data
193194 "valid_time" : ("ipoint" , times .astype ("datetime64[ns]" )),
@@ -287,7 +288,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset):
287288 def _write_metadata (self , dataset_group : zarr .Group , dataset : OutputDataset ):
288289 dataset_group .attrs ["channels" ] = dataset .channels
289290 dataset_group .attrs ["geoinfo_channels" ] = dataset .geoinfo_channels
290- dataset_group .attrs ["lead_time " ] = dataset .lead_time
291+ dataset_group .attrs ["lead_time_hrs " ] = dataset .lead_time_hrs
291292
292293 def _write_arrays (self , dataset_group : zarr .Group , dataset : OutputDataset ):
293294 for array_name , array in dataset .arrays .items (): # suffix is eg. data or coords
@@ -338,7 +339,7 @@ def forecast_steps(self) -> list[int]:
338339 def lead_times (self ) -> list [int ]:
339340 """Calculate available lead times from available forecast steps and len_hrs."""
340341 example_prediction = self .load_zarr (self .example_key ).prediction
341- len_hrs = example_prediction .lead_time // self .example_key .forecast_step
342+ len_hrs = example_prediction .lead_time_hrs // self .example_key .forecast_step
342343
343344 return [step * len_hrs for step in self .forecast_steps ]
344345
@@ -385,7 +386,7 @@ class OutputBatchData:
385386
386387 sample_start : int
387388 forecast_offset : int
388- len_hrs : int
389+ t_window_len_hours : int
389390
390391 @functools .cached_property
391392 def samples (self ):
@@ -436,7 +437,7 @@ def extract(self, key: ItemKey) -> OutputItem:
436437 "Number of channel names does not align with prediction data."
437438 )
438439
439- lead_time = self .len_hrs * key
440+ lead_time = self .t_window_len_hours * key . forecast_step
440441
441442 if key .with_source :
442443 source_dataset = self ._extract_sources (offset_key .sample , stream_idx , key , lead_time )
@@ -449,13 +450,17 @@ def extract(self, key: ItemKey) -> OutputItem:
449450 key = key ,
450451 source = source_dataset ,
451452 target = OutputDataset (
452- "target" , key , target_data , lead_time = lead_time , ** dataclasses .asdict (data_coords )
453+ "target" ,
454+ key ,
455+ target_data ,
456+ lead_time_hrs = lead_time ,
457+ ** dataclasses .asdict (data_coords ),
453458 ),
454459 prediction = OutputDataset (
455460 "prediction" ,
456461 key ,
457462 preds_data ,
458- lead_time = lead_time ,
463+ lead_time_hrs = lead_time ,
459464 ** dataclasses .asdict (data_coords ),
460465 ),
461466 )
@@ -516,7 +521,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi
516521
517522 return DataCoordinates (times , coords , geoinfo , channels , geoinfo_channels )
518523
519- def _extract_sources (self , sample , stream_idx , key , lead_time ):
524+ def _extract_sources (self , sample , stream_idx , key , lead_time : int ):
520525 channels = self .source_channels [stream_idx ]
521526 geoinfo_channels = self .geoinfo_channels [stream_idx ]
522527
0 commit comments