Skip to content

Commit

Permalink
bugfix - check input dates within available zarr data
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshCu committed Aug 30, 2024
1 parent 35ad58c commit ad63070
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion modules/data_processing/zarr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,29 @@ def load_zarr_datasets() -> xr.Dataset:
dataset = xr.open_mfdataset(s3_stores, parallel=True, engine="zarr")
return dataset


def validate_time_range(dataset: xr.Dataset, start_time: str, end_time: str) -> Tuple[str, str]:
end_time_in_dataset = dataset.time[-1].values
start_time_in_dataset = dataset.time[0].values
if np.datetime64(start_time) < start_time_in_dataset:
logger.warning(
f"provided start {start_time} is before the start of the dataset {start_time_in_dataset}, selecting from {start_time_in_dataset}"
)
start_time = start_time_in_dataset
if np.datetime64(end_time) > end_time_in_dataset:
logger.warning(
f"provided end {end_time} is after the end of the dataset {end_time_in_dataset}, selecting until {end_time_in_dataset}"
)
end_time = end_time_in_dataset
return start_time, end_time


def clip_dataset_to_bounds(
dataset: xr.Dataset, bounds: Tuple[float, float, float, float], start_time: str, end_time: str
) -> xr.Dataset:
"""Clip the dataset to specified geographical bounds."""
# check time range here in case just this function is imported and not the whole module
start_time, end_time = validate_time_range(dataset, start_time, end_time)
dataset = dataset.sel(
x=slice(bounds[0], bounds[2]),
y=slice(bounds[1], bounds[3]),
Expand Down Expand Up @@ -67,10 +86,11 @@ def get_forcing_data(
cached_data = xr.open_mfdataset(
forcing_paths.cached_nc_file(), parallel=True, engine="h5netcdf"
)
start_time, end_time = validate_time_range(cached_data, start_time, end_time)
if cached_data.time[0].values <= np.datetime64(start_time) and cached_data.time[
-1
].values >= np.datetime64(end_time):
logger.info("Time range is correct")
logger.info("Time range is within cached data")
logger.debug(f"Opened cached nc file: [{forcing_paths.cached_nc_file()}]")
merged_data = clip_dataset_to_bounds(
cached_data, gdf.total_bounds, start_time, end_time
Expand Down

0 comments on commit ad63070

Please sign in to comment.