1111from pathlib import Path
1212from typing import override
1313
14- import dask .array as da
1514import numpy as np
1615import xarray as xr
1716from numpy .typing import NDArray
1817
1918from weathergen .datasets .data_reader_base import (
2019 DataReaderTimestep ,
21- DTRange ,
2220 ReaderData ,
2321 TimeWindowHandler ,
2422 TIndex ,
2826
2927_logger = logging .getLogger (__name__ )
3028
29+
3130# TODO make this datareader works with multiple datasets in ZARR format
3231class DataReaderEObs (DataReaderTimestep ):
3332 """
3433 Data reader for gridded Zarr datasets with regular lat/lon structure.
35-
34+
3635 This reader handles datasets stored as Zarr with dimensions (time, latitude, longitude)
3736 and converts the gridded data to point-wise format required by the framework.
38-
37+
3938 The reader implements lazy initialization to work efficiently with multiple dataloader workers.
4039 """
4140
@@ -76,7 +75,7 @@ def __init__(
7675 self .geoinfo_channels = []
7776 self .geoinfo_idx = []
7877 self .properties = {}
79-
78+
8079 # Grid properties
8180 self .latitudes : NDArray | None = None
8281 self .longitudes : NDArray | None = None
@@ -104,12 +103,7 @@ def _lazy_init(self) -> None:
104103
105104 try :
106105 # Open the Zarr dataset with xarray
107- self .ds = xr .open_zarr (
108- self ._filename ,
109- consolidated = True ,
110- chunks = None ,
111- zarr_format = 2
112- )
106+ self .ds = xr .open_zarr (self ._filename , consolidated = True , chunks = None , zarr_format = 2 )
113107 except Exception as e :
114108 name = self ._stream_info ["name" ]
115109 _logger .error (f"Failed to open { name } at { self ._filename } : { e } " )
@@ -123,14 +117,9 @@ def _lazy_init(self) -> None:
123117 data_end_time = np .datetime64 (time_coord [- 1 ])
124118
125119 # Check if dataset overlaps with requested time window
126- if (
127- self ._tw_handler .t_start >= data_end_time
128- or self ._tw_handler .t_end <= data_start_time
129- ):
120+ if self ._tw_handler .t_start >= data_end_time or self ._tw_handler .t_end <= data_start_time :
130121 name = self ._stream_info ["name" ]
131- _logger .warning (
132- f"{ name } is not supported over data loader window. Stream is skipped."
133- )
122+ _logger .warning (f"{ name } is not supported over data loader window. Stream is skipped." )
134123 self .init_empty ()
135124 self ._initialized = True
136125 return
@@ -156,9 +145,7 @@ def _lazy_init(self) -> None:
156145 )
157146
158147 # Calculate valid time range indices
159- time_mask = (time_coord >= self ._tw_handler .t_start ) & (
160- time_coord < self ._tw_handler .t_end
161- )
148+ time_mask = (time_coord >= self ._tw_handler .t_start ) & (time_coord < self ._tw_handler .t_end )
162149 self .len = int (np .sum (time_mask ))
163150
164151 if self .len <= 0 :
@@ -183,9 +170,7 @@ def _lazy_init(self) -> None:
183170 f"Longitude values outside valid range [-180, 180] in stream "
184171 f"'{ self ._stream_info ['name' ]} '. Converting from [0, 360] format."
185172 )
186- self .longitudes = ((self .longitudes + 180.0 ) % 360.0 - 180.0 ).astype (
187- np .float32
188- )
173+ self .longitudes = ((self .longitudes + 180.0 ) % 360.0 - 180.0 ).astype (np .float32 )
189174
190175 self .n_lat = len (self .latitudes )
191176 self .n_lon = len (self .longitudes )
@@ -405,9 +390,9 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
405390
406391 # Create coordinate grid
407392 lon_grid , lat_grid = np .meshgrid (self .longitudes , self .latitudes )
408- coords_single = np .stack (
409- [ lat_grid . flatten (), lon_grid . flatten ()], axis = 1
410- ). astype ( np . float32 )
393+ coords_single = np .stack ([ lat_grid . flatten (), lon_grid . flatten ()], axis = 1 ). astype (
394+ np . float32
395+ )
411396
412397 # Repeat coordinates for each timestep
413398 coords = np .tile (coords_single , (len (t_idxs ), 1 ))
@@ -427,4 +412,4 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
427412
428413 check_reader_data (rd , dtr )
429414
430- return rd
415+ return rd
0 commit comments