Skip to content

Commit 50fc8ab

Browse files
committed
addressed comments
1 parent 1c56cb0 commit 50fc8ab

File tree

1 file changed

+14
-9
lines changed
  • packages/common/src/weathergen/common

1 file changed

+14
-9
lines changed

packages/common/src/weathergen/common/io.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ class OutputDataset:
148148

149149
channels: list[str]
150150
geoinfo_channels: list[str]
151-
lead_time: int
151+
# lead time in hours defined as forecast step * length of forecast step (len_hours)
152+
lead_time_hrs: int
152153

153154
@functools.cached_property
154155
def arrays(self) -> dict[str, zarr.Array | NDArray]:
@@ -187,7 +188,7 @@ def as_xarray(self, chunk_nsamples=CHUNK_N_SAMPLES) -> xr.DataArray:
187188
"sample": [self.item_key.sample],
188189
"stream": [self.item_key.stream],
189190
"forecast_step": [self.item_key.forecast_step],
190-
"lead_time": ("forecast_step", [self.lead_time]),
191+
"lead_time_hrs": ("forecast_step", [self.lead_time_hrs]),
191192
"ipoint": self.datapoints,
192193
"channel": self.channels, # TODO: make sure channel names align with data
193194
"valid_time": ("ipoint", times.astype("datetime64[ns]")),
@@ -287,7 +288,7 @@ def _write_dataset(self, item_group: zarr.Group, dataset: OutputDataset):
287288
def _write_metadata(self, dataset_group: zarr.Group, dataset: OutputDataset):
288289
dataset_group.attrs["channels"] = dataset.channels
289290
dataset_group.attrs["geoinfo_channels"] = dataset.geoinfo_channels
290-
dataset_group.attrs["lead_time"] = dataset.lead_time
291+
dataset_group.attrs["lead_time_hrs"] = dataset.lead_time_hrs
291292

292293
def _write_arrays(self, dataset_group: zarr.Group, dataset: OutputDataset):
293294
for array_name, array in dataset.arrays.items(): # suffix is eg. data or coords
@@ -338,7 +339,7 @@ def forecast_steps(self) -> list[int]:
338339
def lead_times(self) -> list[int]:
339340
"""Calculate available lead times from available forecast steps and len_hrs."""
340341
example_prediction = self.load_zarr(self.example_key).prediction
341-
len_hrs = example_prediction.lead_time // self.example_key.forecast_step
342+
len_hrs = example_prediction.lead_time_hrs // self.example_key.forecast_step
342343

343344
return [step * len_hrs for step in self.forecast_steps]
344345

@@ -385,7 +386,7 @@ class OutputBatchData:
385386

386387
sample_start: int
387388
forecast_offset: int
388-
len_hrs: int
389+
t_window_len_hours: int
389390

390391
@functools.cached_property
391392
def samples(self):
@@ -436,7 +437,7 @@ def extract(self, key: ItemKey) -> OutputItem:
436437
"Number of channel names does not align with prediction data."
437438
)
438439

439-
lead_time = self.len_hrs * key
440+
lead_time = self.t_window_len_hours * key.forecast_step
440441

441442
if key.with_source:
442443
source_dataset = self._extract_sources(offset_key.sample, stream_idx, key, lead_time)
@@ -449,13 +450,17 @@ def extract(self, key: ItemKey) -> OutputItem:
449450
key=key,
450451
source=source_dataset,
451452
target=OutputDataset(
452-
"target", key, target_data, lead_time=lead_time, **dataclasses.asdict(data_coords)
453+
"target",
454+
key,
455+
target_data,
456+
lead_time_hrs=lead_time,
457+
**dataclasses.asdict(data_coords),
453458
),
454459
prediction=OutputDataset(
455460
"prediction",
456461
key,
457462
preds_data,
458-
lead_time=lead_time,
463+
lead_time_hrs=lead_time,
459464
**dataclasses.asdict(data_coords),
460465
),
461466
)
@@ -516,7 +521,7 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi
516521

517522
return DataCoordinates(times, coords, geoinfo, channels, geoinfo_channels)
518523

519-
def _extract_sources(self, sample, stream_idx, key, lead_time):
524+
def _extract_sources(self, sample, stream_idx, key, lead_time: int):
520525
channels = self.source_channels[stream_idx]
521526
geoinfo_channels = self.geoinfo_channels[stream_idx]
522527

0 commit comments

Comments
 (0)