diff --git a/docs/conf.py b/docs/conf.py
index 1abade2..6c3694b 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -15,7 +15,7 @@
 author = 'Riley Hales, PhD'
 
 # The full version, including alpha/beta/rc tags
-release = '1.4.0'
+release = '1.5.0'
 master_doc = 'index'
 
 # -- General configuration ---------------------------------------------------
diff --git a/geoglows/__init__.py b/geoglows/__init__.py
index ba86d67..0a13e9f 100644
--- a/geoglows/__init__.py
+++ b/geoglows/__init__.py
@@ -12,6 +12,6 @@
     'bias', 'plots', 'data', 'analyze', 'streams', 'tables', 'streamflow',
     'METADATA_TABLE_PATH'
 ]
-__version__ = '1.4.0'
+__version__ = '1.5.0'
 __author__ = 'Riley Hales'
 __license__ = 'BSD 3-Clause Clear License'
diff --git a/geoglows/_constants.py b/geoglows/_constants.py
index a653b47..fd69956 100644
--- a/geoglows/_constants.py
+++ b/geoglows/_constants.py
@@ -1,5 +1,10 @@
 import os
 
+ODP_CORE_S3_BUCKET_URI = 's3://geoglows-v2'
+ODP_FORECAST_S3_BUCKET_URI = 's3://geoglows-v2-forecasts'
+ODP_RETROSPECTIVE_S3_BUCKET_URI = 's3://geoglows-v2-retrospective'
+ODP_S3_BUCKET_REGION = 'us-west-2'
+
 METADATA_TABLE_PATH = os.getenv(
     'PYGEOGLOWS_METADATA_TABLE_PATH',
     os.path.join(os.path.dirname(__file__), 'data', 'metadata-tables.parquet')
diff --git a/geoglows/_download_decorators.py b/geoglows/_download_decorators.py
new file mode 100644
index 0000000..fb4924d
--- /dev/null
+++ b/geoglows/_download_decorators.py
@@ -0,0 +1,230 @@
+import os
+import warnings
+from io import StringIO
+
+import pandas as pd
+import requests
+import s3fs
+import xarray as xr
+import numpy as np
+
+from .analyze import (
+    simple_forecast as calc_simple_forecast,
+    forecast_stats as calc_forecast_stats,
+)
+
+from ._constants import (
+    ODP_FORECAST_S3_BUCKET_URI,
+    ODP_RETROSPECTIVE_S3_BUCKET_URI,
+    ODP_S3_BUCKET_REGION,
+)
+
+DEFAULT_REST_ENDPOINT = 'https://geoglows.ecmwf.int/api/'
+DEFAULT_REST_ENDPOINT_VERSION = 'v2'  # 'v1, v2, latest'
+
+__all__ = [
+    '_forecast',
+    '_retrospective',
+]
+
+
+def _forecast(function):
+    def from_aws(*args, **kwargs):
+        product_name = function.__name__.replace("_", "").lower()
+        if product_name == 'forecastrecords':
+            warnings.warn('forecast_records are not available from the AWS Open Data Program.')
+            return from_rest(*args, **kwargs)
+
+        river_id = kwargs.get('river_id', '')
+        river_id = args[0] if len(args) > 0 else river_id
+
+        return_format = kwargs.get('format', 'df')
+        assert return_format in ('df', 'xarray'), f'Unsupported return format requested: {return_format}'
+
+        if not kwargs.get('skip_log', False):
+            requests.post(f'{DEFAULT_REST_ENDPOINT}{DEFAULT_REST_ENDPOINT_VERSION}/log',
+                          json={'river_id': river_id, 'product': product_name, 'format': return_format},
+                          timeout=1, )  # short timeout - don't need the response, post only needs to be received
+
+        s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
+        date = kwargs.get('date', False)
+        if not date:
+            zarr_vars = ['rivid', 'Qout', 'time', 'ensemble']
+            dates = [s3.glob(os.path.join(ODP_FORECAST_S3_BUCKET_URI, f'*.zarr/{var}')) for var in zarr_vars]
+            dates = [set([d.split('/')[1].replace('.zarr', '') for d in date]) for date in dates]
+            dates = sorted(set.intersection(*dates), reverse=True)
+            if product_name == 'dates':
+                return pd.DataFrame(dict(dates=dates))
+            date = dates[0]
+        if len(date) == 8:
+            date = f'{date}00.zarr'
+        elif len(date) == 10:
+            date = f'{date}.zarr'
+        else:
+            raise ValueError('Date must be YYYYMMDD or YYYYMMDDHH format. Use dates() to view available data.')
+
+        s3store = s3fs.S3Map(root=f'{ODP_FORECAST_S3_BUCKET_URI}/{date}', s3=s3, check=False)
+
+        attrs = {
+            'source': 'geoglows',
+            'forecast_date': date[:8],
+            'retrieval_date': pd.Timestamp.now().strftime('%Y%m%d'),
+            'units': 'cubic meters per second',
+        }
+        ds = xr.open_zarr(s3store).sel(rivid=river_id)
+        if return_format == 'xarray' and product_name == 'forecastensembles':
+            ds = ds.rename({'time': 'datetime', 'rivid': 'river_id'})
+            ds.attrs = attrs
+            return ds
+        df = ds.to_dataframe().round(2).reset_index()
+
+        # rename columns to match the REST API
+        if isinstance(river_id, int) or isinstance(river_id, np.int64):
+            df = df.pivot(index='time', columns='ensemble', values='Qout')
+        else:
+            df = df.pivot(index=['time', 'rivid'], columns='ensemble', values='Qout')
+            df.index.names = ['time', 'river_id']
+        df = df[sorted(df.columns)]
+        df.columns = [f'ensemble_{str(x).zfill(2)}' for x in df.columns]
+
+        if product_name == 'forecast':
+            df = calc_simple_forecast(df)
+        elif product_name == 'forecaststats':
+            df = calc_forecast_stats(df)
+
+        if return_format == 'df':
+            return df
+        ds = df.to_xarray()
+        ds.attrs = attrs
+        return ds
+
+    def from_rest(*args, **kwargs):
+        # update the default values set by the function unless the user has already specified them
+        for key, value in function.__kwdefaults__.items() if function.__kwdefaults__ else []:
+            if key not in kwargs:
+                kwargs[key] = value
+
+        return_format = kwargs.get('format', 'csv')
+        assert return_format in ('csv', 'json', 'url'), f'Unsupported format requested: {return_format}'
+
+        # parse out the information necessary to build a request url
+        endpoint = kwargs.get('endpoint', DEFAULT_REST_ENDPOINT)
+        endpoint = endpoint[:-1] if endpoint[-1] == '/' else endpoint
+        endpoint = endpoint + '/api' if not endpoint.endswith('/api') else endpoint
+        endpoint = f'https://{endpoint}' if not endpoint.startswith(('https://', 'http://')) else endpoint
+
+        version = kwargs.get('version', DEFAULT_REST_ENDPOINT_VERSION)
+        assert version in ('v2',), ValueError(f'Unrecognized model version parameter: {version}')
+
+        product_name = function.__name__.replace("_", "").lower()
+
+        river_id = args[0] if len(args) > 0 else None
+        river_id = kwargs.get('river_id', '') if not river_id else river_id
+        if isinstance(river_id, list):
+            raise ValueError('Multiple river_ids are not available via REST API or on v1. '
+                             'Use data_source="aws" for multiple river_ids.')
+        river_id = int(river_id) if river_id else None
+        if river_id and version == 'v2':
+            assert 1_000_000_000 > river_id >= 110_000_000, ValueError('River ID must be a 9 digit integer')
+
+        # request parameter validation before submitting
+        for key in ('endpoint', 'version', 'river_id'):
+            if key in kwargs:
+                del kwargs[key]
+        for key in [key for key, value in kwargs.items() if value is None]:
+            # iterate a copied list of keys so kwargs is not mutated while being iterated
+            del kwargs[key]
+        for date in ('date', 'start_date', 'end_date'):
+            if date in kwargs:
+                assert len(str(kwargs[date])) == 8 or len(
+                    str(kwargs[date])) == 10, f'Invalid date format: {kwargs[date]}'
+        if 'format' in kwargs and kwargs['format'] != 'json':
+            del kwargs['format']
+        kwargs['source'] = kwargs.get('source', 'pygeoglows')  # allow using default for specific apps which override
+        params = '&'.join([f'{key}={value}' for key, value in kwargs.items()])
+
+        # piece together the request url
+        request_url = f'{endpoint}/{version}/{product_name}'  # build the base url
+        request_url = f'{request_url}/{river_id}' if river_id else request_url  # add the river_id if it exists
+        request_url = f'{request_url}?{params}'  # add the query parameters
+
+        if return_format == 'url':
+            return request_url.replace(f'source={kwargs["source"]}', '')
+
+        response = requests.get(request_url)
+
+        if response.status_code != 200:
+            raise RuntimeError('Received an error from the REST API: ' + response.text)
+
+        if return_format == 'csv':
+            df = pd.read_csv(StringIO(response.text))
+            if 'datetime' in df.columns:
+                df['datetime'] = pd.to_datetime(df['datetime'])
+                df = df.set_index('datetime')
+            return df
+        elif return_format == 'json':
+            return response.json()
+        else:
+            raise ValueError(f'Unsupported return format requested: {return_format}')
+
+    def main(*args, **kwargs):
+        source = kwargs.get('data_source', 'aws')
+        assert source in ('rest', 'aws'), ValueError(f'Unrecognized data source requested: {source}')
+        if source == 'rest':
+            return from_rest(*args, **kwargs)
+        return from_aws(*args, **kwargs)
+
+    main.__doc__ = function.__doc__  # necessary for code documentation auto generators
+    return main
+
+
+def _retrospective(function):
+    def main(*args, **kwargs):
+        product_name = function.__name__.replace("_", "-").lower()
+
+        river_id = args[0] if len(args) > 0 else None
+        river_id = kwargs.get('river_id', '') if not river_id else river_id
+
+        return_format = kwargs.get('format', 'df')
+        assert return_format in ('df', 'xarray'), f'Unsupported return format requested: {return_format}'
+
+        method = kwargs.get('method', 'gumbel1')
+
+        if not kwargs.get('skip_log', False):
+            requests.post(f'{DEFAULT_REST_ENDPOINT}{DEFAULT_REST_ENDPOINT_VERSION}/log',
+                          timeout=1,  # short timeout because we don't need the response, post just needs to be received
+                          json={'river_id': river_id, 'product': product_name, 'format': return_format})
+
+        s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
+        s3store = s3fs.S3Map(root=f'{ODP_RETROSPECTIVE_S3_BUCKET_URI}/{product_name}.zarr', s3=s3, check=False)
+        ds = xr.open_zarr(s3store)
+        try:
+            ds = ds.sel(rivid=river_id)
+        except Exception:
+            raise ValueError(f'River ID(s) not found in the retrospective dataset: {river_id}')
+        if return_format == 'xarray':
+            return ds
+        if product_name == 'retrospective':
+            return (
+                ds
+                .to_dataframe()
+                .reset_index()
+                .set_index('time')
+                .pivot(columns='rivid', values='Qout')
+            )
+        if product_name == 'return-periods':
+            rp_methods = {
+                'gumbel1': 'gumbel1_return_period',
+            }
+            assert method in rp_methods, f'Unrecognized return period estimation method given: {method}'
+            return (
+                ds
+                [rp_methods[method]]
+                .to_dataframe()
+                .reset_index()
+                .pivot(index='rivid', columns='return_period', values=rp_methods[method])
+            )
+        raise ValueError(f'Unsupported product requested: {product_name}')
+
+    main.__doc__ = function.__doc__  # necessary for code documentation auto generators
+    return main
diff --git a/geoglows/data.py b/geoglows/data.py
index dae15b5..dcf8409 100644
--- a/geoglows/data.py
+++ b/geoglows/data.py
@@ -1,17 +1,13 @@
 import os
 import warnings
-from io import StringIO
 
 import pandas as pd
-import requests
-import s3fs
 import xarray as xr
-import numpy as np
 
 from ._constants import METADATA_TABLE_PATH
+from ._download_decorators import _forecast, _retrospective
+
 from .analyze import (
-    simple_forecast as calc_simple_forecast,
-    forecast_stats as calc_forecast_stats,
     daily_averages as calc_daily_averages,
     monthly_averages as calc_monthly_averages,
     annual_averages as calc_annual_averages,
@@ -33,161 +29,9 @@
     'metadata_tables',
 ]
 
-DEFAULT_REST_ENDPOINT = 'https://geoglows.ecmwf.int/api/'
-DEFAULT_REST_ENDPOINT_VERSION = 'v2'  # 'v1, v2, latest'
-ODP_CORE_S3_BUCKET_URI = 's3://geoglows-v2'
-ODP_FORECAST_S3_BUCKET_URI = 's3://geoglows-v2-forecasts'
-ODP_RETROSPECTIVE_S3_BUCKET_URI = 's3://geoglows-v2-retrospective'
-ODP_S3_BUCKET_REGION = 'us-west-2'
-
-
-def _forecast_endpoint_decorator(function):
-    def from_aws(*args, **kwargs):
-        product_name = function.__name__.replace("_", "").lower()
-        if product_name == 'forecastrecords':
-            warnings.warn('forecast_records are not available from the AWS Open Data Program.')
-            return from_rest(*args, **kwargs)
-
-        river_id = kwargs.get('river_id', '')
-        river_id = args[0] if len(args) > 0 else river_id
-
-        return_format = kwargs.get('format', 'df')
-        assert return_format in ('df', 'xarray'), f'Unsupported return format requested: {return_format}'
-
-        s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
-        date = kwargs.get('date', False)
-        if not date:
-            zarr_vars = ['rivid', 'Qout', 'time', 'ensemble']
-            dates = [s3.glob(os.path.join(ODP_FORECAST_S3_BUCKET_URI, f'*.zarr/{var}')) for var in zarr_vars]
-            dates = [set([d.split('/')[1].replace('.zarr', '') for d in date]) for date in dates]
-            dates = sorted(set.intersection(*dates), reverse=True)
-            if product_name == 'dates':
-                return pd.DataFrame(dict(dates=dates))
-            date = dates[0]
-        if len(date) == 8:
-            date = f'{date}00.zarr'
-        elif len(date) == 10:
-            date = f'{date}.zarr'
-        else:
-            raise ValueError('Date must be YYYYMMDD or YYYYMMDDHH format. Use dates() to view available data.')
-
-        s3store = s3fs.S3Map(root=f'{ODP_FORECAST_S3_BUCKET_URI}/{date}', s3=s3, check=False)
-
-        attrs = {
-            'source': 'geoglows',
-            'forecast_date': date[:8],
-            'retrieval_date': pd.Timestamp.now().strftime('%Y%m%d'),
-            'units': 'cubic meters per second',
-        }
-        ds = xr.open_zarr(s3store).sel(rivid=river_id)
-        if return_format == 'xarray' and product_name == 'forecastensembles':
-            ds = ds.rename({'time': 'datetime', 'rivid': 'river_id'})
-            ds.attrs = attrs
-            return ds
-        df = ds.to_dataframe().round(2).reset_index()
-
-        # rename columns to match the REST API
-        if isinstance(river_id, int) or isinstance(river_id, np.int64):
-            df = df.pivot(index='time', columns='ensemble', values='Qout')
-        else:
-            df = df.pivot(index=['time', 'rivid'], columns='ensemble', values='Qout')
-            df.index.names = ['time', 'river_id']
-        df = df[sorted(df.columns)]
-        df.columns = [f'ensemble_{str(x).zfill(2)}' for x in df.columns]
-
-        if product_name == 'forecast':
-            df = calc_simple_forecast(df)
-        elif product_name == 'forecaststats':
-            df = calc_forecast_stats(df)
-
-        if return_format == 'df':
-            return df
-        ds = df.to_xarray()
-        ds.attrs = attrs
-        return ds
-
-    def from_rest(*args, **kwargs):
-        # update the default values set by the function unless the user has already specified them
-        for key, value in function.__kwdefaults__.items() if function.__kwdefaults__ else []:
-            if key not in kwargs:
-                kwargs[key] = value
-
-        return_format = kwargs.get('format', 'csv')
-        assert return_format in ('csv', 'json', 'url'), f'Unsupported format requested: {return_format}'
-
-        # parse out the information necessary to build a request url
-        endpoint = kwargs.get('endpoint', DEFAULT_REST_ENDPOINT)
-        endpoint = endpoint[:-1] if endpoint[-1] == '/' else endpoint
-        endpoint = endpoint + '/api' if not endpoint.endswith('/api') else endpoint
-        endpoint = f'https://{endpoint}' if not endpoint.startswith(('https://', 'http://')) else endpoint
-
-        version = kwargs.get('version', DEFAULT_REST_ENDPOINT_VERSION)
-        assert version in ('v2',), ValueError(f'Unrecognized model version parameter: {version}')
-
-        product_name = function.__name__.replace("_", "").lower()
-
-        river_id = args[0] if len(args) > 0 else None
-        river_id = kwargs.get('river_id', '') if not river_id else river_id
-        if isinstance(river_id, list):
-            raise ValueError('Multiple river_ids are not available via REST API or on v1. '
-                             'Use data_source="aws" and version="v2" for multiple river_ids.')
-        river_id = int(river_id) if river_id else None
-        if river_id and version == 'v2':
-            assert 1_000_000_000 > river_id >= 110_000_000, ValueError('River ID must be a 9 digit integer')
-
-        # request parameter validation before submitting
-        for key in ('endpoint', 'version', 'river_id'):
-            if key in kwargs:
-                del kwargs[key]
-        for key, value in kwargs.items():
-            if value is None:
-                del kwargs[key]
-        for date in ('date', 'start_date', 'end_date'):
-            if date in kwargs:
-                assert len(str(kwargs[date])) == 8 or len(
-                    str(kwargs[date])) == 10, f'Invalid date format: {kwargs[date]}'
-        if 'format' in kwargs and kwargs['format'] != 'json':
-            del kwargs['format']
-        kwargs['source'] = kwargs.get('source', 'pygeoglows')  # allow using default for specific apps which override
-        params = '&'.join([f'{key}={value}' for key, value in kwargs.items()])
-
-        # piece together the request url
-        request_url = f'{endpoint}/{version}/{product_name}'  # build the base url
-        request_url = f'{request_url}/{river_id}' if river_id else request_url  # add the river_id if it exists
-        request_url = f'{request_url}?{params}'  # add the query parameters
-
-        if return_format == 'url':
-            return request_url.replace(f'source={kwargs["source"]}', '')
-
-        response = requests.get(request_url)
-
-        if response.status_code != 200:
-            raise RuntimeError('Received an error from the REST API: ' + response.text)
-
-        if return_format == 'csv':
-            df = pd.read_csv(StringIO(response.text))
-            if 'datetime' in df.columns:
-                df['datetime'] = pd.to_datetime(df['datetime'])
-                df = df.set_index('datetime')
-            return df
-        elif return_format == 'json':
-            return response.json()
-        else:
-            raise ValueError(f'Unsupported return format requested: {return_format}')
-
-    def main(*args, **kwargs):
-        source = kwargs.get('data_source', 'aws')
-        assert source in ('rest', 'aws'), ValueError(f'Unrecognized data source requested: {source}')
-        if source == 'rest':
-            return from_rest(*args, **kwargs)
-        return from_aws(*args, **kwargs)
-
-    main.__doc__ = function.__doc__  # necessary for code documentation auto generators
-    return main
-
 
 # Forecast data and derived products
-@_forecast_endpoint_decorator
+@_forecast
 def dates(**kwargs) -> dict or str:
     """
     Gets a list of available forecast product dates
@@ -204,7 +48,7 @@
     pass
 
 
-@_forecast_endpoint_decorator
+@_forecast
 def forecast(*, river_id: int, date: str, format: str, data_source: str,
              **kwargs) -> pd.DataFrame or xr.Dataset:
     """
@@ -222,7 +66,7 @@
     pass
 
 
-@_forecast_endpoint_decorator
+@_forecast
 def forecast_stats(*, river_id: int, date: str, format: str, data_source: str,
                    **kwargs) -> pd.DataFrame or xr.Dataset:
     """
@@ -241,7 +85,7 @@
     pass
 
 
-@_forecast_endpoint_decorator
+@_forecast
 def forecast_ensembles(*, river_id: int, date: str, format: str, data_source: str,
                        **kwargs) -> pd.DataFrame or xr.Dataset:
     """
@@ -259,7 +103,7 @@ def forecast_ensembles(*, river_id: int, date: str, format: str, data_source: st
     pass
 
 
-@_forecast_endpoint_decorator
+@_forecast
 def forecast_records(*, river_id: int, start_date: str, end_date: str, format: str,
                      **kwargs) -> pd.DataFrame or dict or str:
     """
@@ -278,7 +122,8 @@ def forecast_records(*, river_id: int, start_date: str, end_date: str, format: s
 
 
 # Retrospective simulation and derived products
-def retrospective(river_id: int or list, format: str = 'df') -> pd.DataFrame or xr.Dataset:
+@_retrospective
+def retrospective(river_id: int or list, *, format: str = 'df') -> pd.DataFrame or xr.Dataset:
     """
     Retrieves the retrospective simulation of streamflow for a given river_id from the AWS Open Data Program
     GEOGLOWS V2 S3 bucket
@@ -290,15 +135,10 @@
     Returns:
         pd.DataFrame
     """
-    s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
-    s3store = s3fs.S3Map(root=f'{ODP_RETROSPECTIVE_S3_BUCKET_URI}/retrospective.zarr', s3=s3, check=False)
-    ds = xr.open_zarr(s3store).sel(rivid=river_id)
-    if format == 'xarray':
-        return ds
-    return ds.to_dataframe().reset_index().set_index('time').pivot(columns='rivid', values='Qout')
+    pass
 
 
-def daily_averages(river_id: int or list) -> pd.DataFrame:
+def daily_averages(river_id: int or list, **kwargs) -> pd.DataFrame:
     """
     Retrieves daily average streamflow for a given river_id
 
@@ -308,11 +148,11 @@
     Returns:
         pd.DataFrame
     """
-    df = retrospective(river_id)
+    df = retrospective(river_id, **kwargs)
     return calc_daily_averages(df)
 
 
-def monthly_averages(river_id: int or list) -> pd.DataFrame:
+def monthly_averages(river_id: int or list, **kwargs) -> pd.DataFrame:
     """
     Retrieves monthly average streamflow for a given river_id
 
@@ -322,11 +162,11 @@
     Returns:
         pd.DataFrame
     """
-    df = retrospective(river_id)
+    df = retrospective(river_id, **kwargs)
     return calc_monthly_averages(df)
 
 
-def annual_averages(river_id: int or list) -> pd.DataFrame:
+def annual_averages(river_id: int or list, **kwargs) -> pd.DataFrame:
     """
     Retrieves annual average streamflow for a given river_id
 
@@ -336,11 +176,12 @@
     Returns:
         pd.DataFrame
     """
-    df = retrospective(river_id)
+    df = retrospective(river_id, **kwargs)
     return calc_annual_averages(df)
 
 
-def return_periods(river_id: int or list, format: str = 'df', method: str = 'gumbel1') -> pd.DataFrame or xr.Dataset:
+@_retrospective
+def return_periods(river_id: int or list, *, format: str = 'df', method: str = 'gumbel1') -> pd.DataFrame or xr.Dataset:
     """
     Retrieves the return period thresholds based on a specified historic simulation forcing on a certain river_id.
 
@@ -355,17 +196,7 @@
     Returns:
         pd.DataFrame
     """
-    rp_methods = {
-        'gumbel1': 'gumbel1_return_period',
-    }
-    assert method in rp_methods, f'Unrecognized return period estimation method given: {method}'
-    s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
-    s3store = s3fs.S3Map(root=f'{ODP_RETROSPECTIVE_S3_BUCKET_URI}/return-periods.zarr', s3=s3, check=False)
-    ds = xr.open_zarr(s3store).sel(rivid=river_id)
-    if format == 'xarray':
-        return ds
-    return (ds[rp_methods[method]].to_dataframe().reset_index()
-            .pivot(index='rivid', columns='return_period', values=rp_methods[method]))
+    pass
 
 
 # model config and supplementary data
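
For reference, this refactor keeps the public geoglows.data call signatures intact; only the download logic moves into the decorators. A minimal usage sketch of the decorated functions follows (the river ID is a hypothetical 9 digit example, and network access to the GEOGLOWS S3 buckets or REST API is assumed):

import geoglows

river_id = 710123456  # hypothetical example; v2 river IDs are 9 digit integers

# forecast products read the AWS Open Data zarr stores by default (data_source='aws')
forecast_df = geoglows.data.forecast(river_id=river_id)
stats_ds = geoglows.data.forecast_stats(river_id=river_id, format='xarray')

# routing through the REST API instead is chosen per call; format='url' returns the request url
url = geoglows.data.forecast(river_id=river_id, data_source='rest', format='url')

# retrospective products are wrapped by _retrospective and pivot to one column per river
hist_df = geoglows.data.retrospective(river_id)
rps_df = geoglows.data.return_periods(river_id, method='gumbel1')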