55from ocf_datapipes .utils .utils import datetime64_to_float
66
77
def _convert_satellite_to_numpy_batch(xr_data):
    """Build the NumpyBatch entries for one non-HRV satellite example.

    Args:
        xr_data: Xarray satellite object. This function reads its ``.values``,
            ``attrs["t0_idx"]``, and the ``time_utc``, ``y_geostationary`` and
            ``x_geostationary`` entries.

    Returns:
        A dict keyed by the ``BatchKey.satellite_*`` members.
    """
    batch: NumpyBatch = {
        BatchKey.satellite_actual: xr_data.values,
        BatchKey.satellite_t0_idx: xr_data.attrs["t0_idx"],
        # Timestamps go through datetime64_to_float before entering the batch.
        BatchKey.satellite_time_utc: datetime64_to_float(xr_data["time_utc"].values),
        # Spatial coords are copied through with no dtype cast; the original
        # note says they are already float32 (not checked here).
        BatchKey.satellite_y_geostationary: xr_data["y_geostationary"].values,
        BatchKey.satellite_x_geostationary: xr_data["x_geostationary"].values,
    }
    return batch
23+
24+
def _convert_hrvsatellite_to_numpy_batch(xr_data):
    """Build the NumpyBatch entries for one HRV satellite example.

    Args:
        xr_data: Xarray HRV satellite object. This function reads its
            ``.values``, ``attrs["t0_idx"]``, and the ``time_utc``,
            ``y_geostationary`` and ``x_geostationary`` entries.

    Returns:
        A dict keyed by the ``BatchKey.hrvsatellite_*`` members.
    """
    batch: NumpyBatch = {
        BatchKey.hrvsatellite_actual: xr_data.values,
        BatchKey.hrvsatellite_t0_idx: xr_data.attrs["t0_idx"],
        # Timestamps go through datetime64_to_float before entering the batch.
        BatchKey.hrvsatellite_time_utc: datetime64_to_float(xr_data["time_utc"].values),
        # Spatial coords are copied through with no dtype cast; the original
        # note says they are already float32 (not checked here).
        BatchKey.hrvsatellite_y_geostationary: xr_data["y_geostationary"].values,
        BatchKey.hrvsatellite_x_geostationary: xr_data["x_geostationary"].values,
    }
    return batch
40+
41+
def convert_satellite_to_numpy_batch(xr_data, is_hrv=False):
    """Convert an Xarray satellite object into a NumpyBatch.

    Args:
        xr_data: Xarray satellite data for a single example.
        is_hrv: When truthy, the HRV converter is used and the result is keyed
            by the ``hrvsatellite_*`` batch keys; otherwise the non-HRV
            converter is used.

    Returns:
        The dict produced by the selected converter.
    """
    # Pick the converter first, then call it once.
    converter = (
        _convert_hrvsatellite_to_numpy_batch if is_hrv else _convert_satellite_to_numpy_batch
    )
    return converter(xr_data)
49+
50+
851@functional_datapipe ("convert_satellite_to_numpy_batch" )
952class ConvertSatelliteToNumpyBatchIterDataPipe (IterDataPipe ):
1053 """Converts Xarray Satellite to NumpyBatch object"""
@@ -24,31 +67,4 @@ def __init__(self, source_datapipe: IterDataPipe, is_hrv: bool = False):
def __iter__(self) -> NumpyBatch:
    """Yield one NumpyBatch for every item produced by the source datapipe."""
    # The conversion itself lives in convert_satellite_to_numpy_batch;
    # self.is_hrv is read for each item and selects the HRV or non-HRV keys.
    for sample in self.source_datapipe:
        batch = convert_satellite_to_numpy_batch(sample, self.is_hrv)
        yield batch
0 commit comments