diff --git a/docs/conf.py b/docs/conf.py index 41da386..6f9a35d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,7 +15,7 @@ author = 'Riley Hales, PhD' # The full version, including alpha/beta/rc tags -release = '1.6.1' +release = '1.6.2' master_doc = 'index' # -- General configuration --------------------------------------------------- diff --git a/environment.yaml b/environment.yaml index 6908fc6..0ab021c 100644 --- a/environment.yaml +++ b/environment.yaml @@ -3,6 +3,7 @@ channels: - conda-forge dependencies: - python>=3 + - cftime - dask >=2024 - fastparquet - HydroErr diff --git a/geoglows/__init__.py b/geoglows/__init__.py index e6a6d92..a0ee1f5 100644 --- a/geoglows/__init__.py +++ b/geoglows/__init__.py @@ -12,6 +12,6 @@ 'bias', 'plots', 'data', 'analyze', 'streams', 'tables', 'streamflow', 'get_metadata_table_path', 'set_metadata_table_path', ] -__version__ = '1.6.1' +__version__ = '1.6.2' __author__ = 'Riley Hales' __license__ = 'BSD 3-Clause Clear License' diff --git a/geoglows/_download_decorators.py b/geoglows/_download_decorators.py index 6e01e0f..501ea70 100644 --- a/geoglows/_download_decorators.py +++ b/geoglows/_download_decorators.py @@ -1,23 +1,21 @@ -import os import warnings from io import StringIO +import numpy as np import pandas as pd import requests import s3fs import xarray as xr -import numpy as np - -from .analyze import ( - simple_forecast as calc_simple_forecast, - forecast_stats as calc_forecast_stats, -) from ._constants import ( ODP_FORECAST_S3_BUCKET_URI, ODP_RETROSPECTIVE_S3_BUCKET_URI, ODP_S3_BUCKET_REGION, ) +from .analyze import ( + simple_forecast as calc_simple_forecast, + forecast_stats as calc_forecast_stats, +) DEFAULT_REST_ENDPOINT = 'https://geoglows.ecmwf.int/api/' DEFAULT_REST_ENDPOINT_VERSION = 'v2' # 'v1, v2, latest' @@ -29,14 +27,22 @@ def _forecast(function): + def _river_id_is_iterable(river_id): + return bool( + isinstance(river_id, list) or + isinstance(river_id, tuple) or + isinstance(river_id, set) or + isinstance(river_id, np.ndarray) + ) + def from_aws(*args, **kwargs): product_name = function.__name__.replace("_", "").lower() if product_name == 'forecastrecords': warnings.warn('forecast_records are not available from the AWS Open Data Program.') return from_rest(*args, **kwargs) - river_id = args[0] if len(args) > 0 else kwargs.get('river_id', '') - if river_id is None or river_id == '': + river_id = args[0] if len(args) > 0 else kwargs.get('river_id', None) + if river_id is None: raise ValueError('River ID must be provided to retrieve forecast data.') return_format = kwargs.get('format', 'df') @@ -51,7 +57,7 @@ def from_aws(*args, **kwargs): date = kwargs.get('date', False) if not date: zarr_vars = ['rivid', 'Qout', 'time', 'ensemble'] - dates = [s3.glob(os.path.join(ODP_FORECAST_S3_BUCKET_URI, f'*.zarr/{var}')) for var in zarr_vars] + dates = [s3.glob(ODP_FORECAST_S3_BUCKET_URI + '/' + f'*.zarr/{var}') for var in zarr_vars] dates = [set([d.split('/')[1].replace('.zarr', '') for d in date]) for date in dates] dates = sorted(set.intersection(*dates), reverse=True) if product_name == 'dates': @@ -119,11 +125,11 @@ def from_rest(*args, **kwargs): product_name = function.__name__.replace("_", "").lower() - river_id = args[0] if len(args) > 0 else kwargs.get('river_id', '') - if isinstance(river_id, list): - raise ValueError('Multiple river_ids are not available via REST API or on v1. ' - 'Use data_source="aws" for multiple river_ids.') - river_id = int(river_id) if river_id else None + river_id = args[0] if len(args) > 0 else kwargs.get('river_id', None) + if river_id is None: + raise ValueError('River ID must be provided to retrieve forecast data.') + if not isinstance(river_id, (int, np.int64, )): + raise ValueError('Multiple river_ids are not available via REST API. Provide a single 9 digit integer.') if river_id and version == 'v2': assert 1_000_000_000 > river_id >= 110_000_000, ValueError('River ID must be a 9 digit integer') @@ -182,8 +188,8 @@ def _retrospective(function): def main(*args, **kwargs): product_name = function.__name__.replace("_", "-").lower() - river_id = args[0] if len(args) > 0 else kwargs.get('river_id', '') - if river_id is None or river_id == '': + river_id = args[0] if len(args) > 0 else kwargs.get('river_id', None) + if river_id is None: raise ValueError('River ID must be provided to retrieve retrospective data.') return_format = kwargs.get('format', 'df')