Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 243c65b

Browse files
Merge pull request #92 from openclimatefix/location_picker_x_y_names
Location picker x y names
2 parents 5d18805 + 26a1099 commit 243c65b

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

ocf_datapipes/config/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141

4242
logger = logging.getLogger(__name__)
4343

44+
# add SV to list of providers
45+
providers.append("SV")
46+
4447

4548
class Base(BaseModel):
4649
"""Pydantic Base model where no extras can be added"""

ocf_datapipes/select/location_picker.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Pick locations from a dataset"""
22
import logging
3+
from typing import Optional
34

45
import numpy as np
56
from torchdata.datapipes import functional_datapipe
@@ -14,18 +15,28 @@
1415
class LocationPickerIterDataPipe(IterDataPipe):
1516
"""Picks locations from a dataset and returns them"""
1617

17-
def __init__(self, source_datapipe: IterDataPipe, return_all_locations: bool = False):
18+
def __init__(
19+
self,
20+
source_datapipe: IterDataPipe,
21+
return_all_locations: bool = False,
22+
x_dim_name: Optional[str] = "x_osgb",
23+
y_dim_name: Optional[str] = "y_osgb",
24+
):
1825
"""
1926
Picks locations from a dataset and returns them
2027
2128
Args:
2229
source_datapipe: Datapipe emitting Xarray Dataset
2330
return_all_locations: Whether to return all locations,
24-
if True, also returns them in order
31+
if True, also returns them in order
32+
x_dim_name: x dimension name, defaulted to 'x_osgb'
33+
y_dim_name: y dimension name, defaulted to 'y_osgb'
2534
"""
2635
super().__init__()
2736
self.source_datapipe = source_datapipe
2837
self.return_all_locations = return_all_locations
38+
self.x_dim_name = x_dim_name
39+
self.y_dim_name = y_dim_name
2940

3041
def __iter__(self) -> Location:
3142
"""Returns locations from the inputs datapipe"""
@@ -35,10 +46,10 @@ def __iter__(self) -> Location:
3546

3647
if self.return_all_locations:
3748
# Iterate through all locations in dataset
38-
for location_idx in range(len(xr_dataset["x_osgb"])):
49+
for location_idx in range(len(xr_dataset[self.x_dim_name])):
3950
location = Location(
40-
x=xr_dataset["x_osgb"][location_idx].values,
41-
y=xr_dataset["y_osgb"][location_idx].values,
51+
x=xr_dataset[self.x_dim_name][location_idx].values,
52+
y=xr_dataset[self.y_dim_name][location_idx].values,
4253
)
4354
if "pv_system_id" in xr_dataset.coords.keys():
4455

@@ -48,12 +59,13 @@ def __iter__(self) -> Location:
4859
else:
4960
# Assumes all datasets have osgb coordinates for selecting locations
5061
# Pick 1 random location from the input dataset
51-
location_idx = np.random.randint(0, len(xr_dataset["x_osgb"]))
62+
location_idx = np.random.randint(0, len(xr_dataset[self.x_dim_name]))
5263
location = Location(
53-
x=xr_dataset["x_osgb"][location_idx].values,
54-
y=xr_dataset["y_osgb"][location_idx].values,
64+
x=xr_dataset[self.x_dim_name][location_idx].values,
65+
y=xr_dataset[self.y_dim_name][location_idx].values,
5566
)
5667
if "pv_system_id" in xr_dataset.coords.keys():
5768
location.id = int(xr_dataset["pv_system_id"][location_idx].values)
5869
logger.debug(f"Have selected location.id {location.id}")
59-
yield location
70+
for i in range(0, 10):
71+
yield location

ocf_datapipes/training/simple_pv.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
logger = logging.getLogger(__name__)
1616
xarray.set_options(keep_attrs=True)
1717

18+
# default is set to 1000
19+
BUFFERSIZE = -1
20+
1821

1922
def simple_pv_datapipe(
2023
configuration_filename: Union[Path, str], tag: Optional[str] = "train"
@@ -39,7 +42,9 @@ def simple_pv_datapipe(
3942

4043
logger.debug("Opening Datasets")
4144
pv_datapipe, pv_location_datapipe = (
42-
OpenPVFromNetCDF(pv=configuration.input_data.pv).pv_fill_night_nans().fork(2)
45+
OpenPVFromNetCDF(pv=configuration.input_data.pv)
46+
.pv_fill_night_nans()
47+
.fork(2, buffer_size=BUFFERSIZE)
4348
)
4449

4550
logger.debug("Add t0 idx")
@@ -71,7 +76,7 @@ def simple_pv_datapipe(
7176
)
7277
.ensure_n_pv_systems_per_example(n_pv_systems_per_example=1)
7378
.remove_nans()
74-
.fork(3)
79+
.fork(3, buffer_size=BUFFERSIZE)
7580
)
7681

7782
# get contiguous time periods

0 commit comments

Comments
 (0)