11"""Pick locations from a dataset"""
22import logging
3+ from typing import Optional
34
45import numpy as np
56from torchdata .datapipes import functional_datapipe
1415class LocationPickerIterDataPipe (IterDataPipe ):
1516 """Picks locations from a dataset and returns them"""
1617
17- def __init__ (self , source_datapipe : IterDataPipe , return_all_locations : bool = False ):
18+ def __init__ (
19+ self ,
20+ source_datapipe : IterDataPipe ,
21+ return_all_locations : bool = False ,
22+ x_dim_name : Optional [str ] = "x_osgb" ,
23+ y_dim_name : Optional [str ] = "y_osgb" ,
24+ ):
1825 """
1926 Picks locations from a dataset and returns them
2027
2128 Args:
2229 source_datapipe: Datapipe emitting Xarray Dataset
2330 return_all_locations: Whether to return all locations,
24- if True, also returns them in order
31+ if True, also returns them in order
32+ x_dim_name: x dimension name, defaulted to 'x_osgb'
33+ y_dim_name: y dimension name, defaulted to 'y_osgb'
2534 """
2635 super ().__init__ ()
2736 self .source_datapipe = source_datapipe
2837 self .return_all_locations = return_all_locations
38+ self .x_dim_name = x_dim_name
39+ self .y_dim_name = y_dim_name
2940
3041 def __iter__ (self ) -> Location :
3142 """Returns locations from the inputs datapipe"""
@@ -35,10 +46,10 @@ def __iter__(self) -> Location:
3546
3647 if self .return_all_locations :
3748 # Iterate through all locations in dataset
38- for location_idx in range (len (xr_dataset ["x_osgb" ])):
49+ for location_idx in range (len (xr_dataset [self . x_dim_name ])):
3950 location = Location (
40- x = xr_dataset ["x_osgb" ][location_idx ].values ,
41- y = xr_dataset ["y_osgb" ][location_idx ].values ,
51+ x = xr_dataset [self . x_dim_name ][location_idx ].values ,
52+ y = xr_dataset [self . y_dim_name ][location_idx ].values ,
4253 )
4354 if "pv_system_id" in xr_dataset .coords .keys ():
4455
@@ -48,12 +59,13 @@ def __iter__(self) -> Location:
4859 else :
4960 # Assumes all datasets have osgb coordinates for selecting locations
5061 # Pick 1 random location from the input dataset
51- location_idx = np .random .randint (0 , len (xr_dataset ["x_osgb" ]))
62+ location_idx = np .random .randint (0 , len (xr_dataset [self . x_dim_name ]))
5263 location = Location (
53- x = xr_dataset ["x_osgb" ][location_idx ].values ,
54- y = xr_dataset ["y_osgb" ][location_idx ].values ,
64+ x = xr_dataset [self . x_dim_name ][location_idx ].values ,
65+ y = xr_dataset [self . y_dim_name ][location_idx ].values ,
5566 )
5667 if "pv_system_id" in xr_dataset .coords .keys ():
5768 location .id = int (xr_dataset ["pv_system_id" ][location_idx ].values )
5869 logger .debug (f"Have selected location.id { location .id } " )
59- yield location
70+ for i in range (0 , 10 ):
71+ yield location
0 commit comments