v1.5.0 (#35)
* move decorators to file, add decorator for retrospective, log aws requests

* increment docs version number
rileyhales authored May 2, 2024
1 parent 12502d7 commit ee0f8c5
Showing 5 changed files with 256 additions and 190 deletions.
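For context on how the new decorators are meant to be used: each one wraps a thin public function and replaces its body entirely, reading only the wrapped function's __name__, __kwdefaults__, and __doc__. A minimal sketch, assuming public wrappers live in geoglows.data (the wrapper names and signatures here are illustrative, not part of this diff):

from geoglows._download_decorators import _forecast, _retrospective

@_forecast
def forecast_stats(*, river_id: int = None, format: str = 'df', data_source: str = 'aws', **kwargs):
    """Forecast statistics for a river; the body is never executed, the decorator does all the work."""
    pass

@_retrospective
def retrospective(*, river_id: int = None, format: str = 'df', **kwargs):
    """Retrospective simulation for a river."""
    pass

The decorators derive the product name from the wrapper's name (forecast_stats -> 'forecaststats', retrospective -> 'retrospective'), so each wrapper must be named after the product it fetches, and keyword-only defaults are what populate __kwdefaults__ for the REST path.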
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -15,7 +15,7 @@
author = 'Riley Hales, PhD'

# The full version, including alpha/beta/rc tags
-release = '1.4.0'
+release = '1.5.0'
master_doc = 'index'

# -- General configuration ---------------------------------------------------
2 changes: 1 addition & 1 deletion geoglows/__init__.py
@@ -12,6 +12,6 @@
'bias', 'plots', 'data', 'analyze', 'streams', 'tables', 'streamflow',
'METADATA_TABLE_PATH'
]
-__version__ = '1.4.0'
+__version__ = '1.5.0'
__author__ = 'Riley Hales'
__license__ = 'BSD 3-Clause Clear License'
5 changes: 5 additions & 0 deletions geoglows/_constants.py
@@ -1,5 +1,10 @@
import os

+ODP_CORE_S3_BUCKET_URI = 's3://geoglows-v2'
+ODP_FORECAST_S3_BUCKET_URI = 's3://geoglows-v2-forecasts'
+ODP_RETROSPECTIVE_S3_BUCKET_URI = 's3://geoglows-v2-retrospective'
+ODP_S3_BUCKET_REGION = 'us-west-2'
+
METADATA_TABLE_PATH = os.getenv(
    'PYGEOGLOWS_METADATA_TABLE_PATH',
    os.path.join(os.path.dirname(__file__), 'data', 'metadata-tables.parquet')
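Because METADATA_TABLE_PATH is resolved when _constants.py is first imported, the table location can be redirected through the environment variable before importing the package. A small sketch (the path here is illustrative):

import os

# point pygeoglows at a locally cached copy of the metadata table before import
os.environ['PYGEOGLOWS_METADATA_TABLE_PATH'] = '/tmp/metadata-tables.parquet'

import geoglows  # _constants.py reads the env var at import time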
230 changes: 230 additions & 0 deletions geoglows/_download_decorators.py
@@ -0,0 +1,230 @@
import os
import warnings
from io import StringIO

import pandas as pd
import requests
import s3fs
import xarray as xr
import numpy as np

from .analyze import (
    simple_forecast as calc_simple_forecast,
    forecast_stats as calc_forecast_stats,
)

from ._constants import (
    ODP_FORECAST_S3_BUCKET_URI,
    ODP_RETROSPECTIVE_S3_BUCKET_URI,
    ODP_S3_BUCKET_REGION,
)

DEFAULT_REST_ENDPOINT = 'https://geoglows.ecmwf.int/api/'
DEFAULT_REST_ENDPOINT_VERSION = 'v2'  # 'v1, v2, latest'

__all__ = [
    '_forecast',
    '_retrospective',
]


def _forecast(function):
    def from_aws(*args, **kwargs):
        product_name = function.__name__.replace("_", "").lower()
        if product_name == 'forecastrecords':
            warnings.warn('forecast_records are not available from the AWS Open Data Program.')
            return from_rest(*args, **kwargs)

        river_id = kwargs.get('river_id', '')
        river_id = args[0] if len(args) > 0 else river_id

        return_format = kwargs.get('format', 'df')
        assert return_format in ('df', 'xarray'), f'Unsupported return format requested: {return_format}'

        if not kwargs.get('skip_log', False):
            requests.post(f'{DEFAULT_REST_ENDPOINT}{DEFAULT_REST_ENDPOINT_VERSION}/log',
                          json={'river_id': river_id, 'product': product_name, 'format': return_format},
                          timeout=1, )  # short timeout- don't need the response, post only needs to be received

        s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
        date = kwargs.get('date', False)
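        # no date given: find the dates for which every required zarr variable exists in the bucket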
        if not date:
            zarr_vars = ['rivid', 'Qout', 'time', 'ensemble']
            dates = [s3.glob(os.path.join(ODP_FORECAST_S3_BUCKET_URI, f'*.zarr/{var}')) for var in zarr_vars]
            dates = [set([d.split('/')[1].replace('.zarr', '') for d in date]) for date in dates]
            dates = sorted(set.intersection(*dates), reverse=True)
            if product_name == 'dates':
                return pd.DataFrame(dict(dates=dates))
            date = dates[0]
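        # bucket keys are named YYYYMMDDHH.zarr; a date given as YYYYMMDD is assumed to be the 00 hour run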
        if len(date) == 8:
            date = f'{date}00.zarr'
        elif len(date) == 10:
            date = f'{date}.zarr'
        else:
            raise ValueError('Date must be YYYYMMDD or YYYYMMDDHH format. Use dates() to view available data.')

        s3store = s3fs.S3Map(root=f'{ODP_FORECAST_S3_BUCKET_URI}/{date}', s3=s3, check=False)

        attrs = {
            'source': 'geoglows',
            'forecast_date': date[:8],
            'retrieval_date': pd.Timestamp.now().strftime('%Y%m%d'),
            'units': 'cubic meters per second',
        }
        ds = xr.open_zarr(s3store).sel(rivid=river_id)
        if return_format == 'xarray' and product_name == 'forecastensembles':
            ds = ds.rename({'time': 'datetime', 'rivid': 'river_id'})
            ds.attrs = attrs
            return ds
        df = ds.to_dataframe().round(2).reset_index()

        # rename columns to match the REST API
        if isinstance(river_id, int) or isinstance(river_id, np.int64):
            df = df.pivot(index='time', columns='ensemble', values='Qout')
        else:
            df = df.pivot(index=['time', 'rivid'], columns='ensemble', values='Qout')
            df.index.names = ['time', 'river_id']
        df = df[sorted(df.columns)]
        df.columns = [f'ensemble_{str(x).zfill(2)}' for x in df.columns]

        if product_name == 'forecast':
            df = calc_simple_forecast(df)
        elif product_name == 'forecaststats':
            df = calc_forecast_stats(df)

        if return_format == 'df':
            return df
        ds = df.to_xarray()
        ds.attrs = attrs
        return ds

    def from_rest(*args, **kwargs):
        # update the default values set by the function unless the user has already specified them
        for key, value in function.__kwdefaults__.items() if function.__kwdefaults__ else []:
            if key not in kwargs:
                kwargs[key] = value

        return_format = kwargs.get('format', 'csv')
        assert return_format in ('csv', 'json', 'url'), f'Unsupported format requested: {return_format}'

        # parse out the information necessary to build a request url
        endpoint = kwargs.get('endpoint', DEFAULT_REST_ENDPOINT)
        endpoint = endpoint[:-1] if endpoint[-1] == '/' else endpoint
        endpoint = endpoint + '/api' if not endpoint.endswith('/api') else endpoint
        endpoint = f'https://{endpoint}' if not endpoint.startswith(('https://', 'http://')) else endpoint

        version = kwargs.get('version', DEFAULT_REST_ENDPOINT_VERSION)
        assert version in ('v2',), ValueError(f'Unrecognized model version parameter: {version}')

        product_name = function.__name__.replace("_", "").lower()

        river_id = args[0] if len(args) > 0 else None
        river_id = kwargs.get('river_id', '') if not river_id else river_id
        if isinstance(river_id, list):
            raise ValueError('Multiple river_ids are not available via REST API or on v1. '
                             'Use data_source="aws" for multiple river_ids.')
        river_id = int(river_id) if river_id else None
        if river_id and version == 'v2':
            assert 1_000_000_000 > river_id >= 110_000_000, ValueError('River ID must be a 9 digit integer')

        # request parameter validation before submitting
        for key in ('endpoint', 'version', 'river_id'):
            if key in kwargs:
                del kwargs[key]
        for key, value in list(kwargs.items()):  # iterate over a copy so entries can be deleted while looping
            if value is None:
                del kwargs[key]
        for date in ('date', 'start_date', 'end_date'):
            if date in kwargs:
                assert len(str(kwargs[date])) == 8 or len(
                    str(kwargs[date])) == 10, f'Invalid date format: {kwargs[date]}'
        if 'format' in kwargs and kwargs['format'] != 'json':
            del kwargs['format']
        kwargs['source'] = kwargs.get('source', 'pygeoglows')  # allow using default for specific apps which override
        params = '&'.join([f'{key}={value}' for key, value in kwargs.items()])

        # piece together the request url
        request_url = f'{endpoint}/{version}/{product_name}'  # build the base url
        request_url = f'{request_url}/{river_id}' if river_id else request_url  # add the river_id if it exists
        request_url = f'{request_url}?{params}'  # add the query parameters

        if return_format == 'url':
            return request_url.replace(f'source={kwargs["source"]}', '')

        response = requests.get(request_url)

        if response.status_code != 200:
            raise RuntimeError('Received an error from the REST API: ' + response.text)

        if return_format == 'csv':
            df = pd.read_csv(StringIO(response.text))
            if 'datetime' in df.columns:
                df['datetime'] = pd.to_datetime(df['datetime'])
                df = df.set_index('datetime')
            return df
        elif return_format == 'json':
            return response.json()
        else:
            raise ValueError(f'Unsupported return format requested: {return_format}')

    def main(*args, **kwargs):
        source = kwargs.get('data_source', 'aws')
        assert source in ('rest', 'aws'), ValueError(f'Unrecognized data source requested: {source}')
        if source == 'rest':
            return from_rest(*args, **kwargs)
        return from_aws(*args, **kwargs)

    main.__doc__ = function.__doc__  # necessary for code documentation auto generators
    return main


def _retrospective(function):
    def main(*args, **kwargs):
        product_name = function.__name__.replace("_", "-").lower()

        river_id = args[0] if len(args) > 0 else None
        river_id = kwargs.get('river_id', '') if not river_id else river_id

        return_format = kwargs.get('format', 'df')
        assert return_format in ('df', 'xarray'), f'Unsupported return format requested: {return_format}'

        method = kwargs.get('method', 'gumbel1')

        if not kwargs.get('skip_log', False):
            requests.post(f'{DEFAULT_REST_ENDPOINT}{DEFAULT_REST_ENDPOINT_VERSION}/log',
                          timeout=1,  # short timeout because we don't need the response, post just needs to be received
                          json={'river_id': river_id, 'product': product_name, 'format': return_format})

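        # each product ('retrospective', 'return-periods') maps to its own zarr store in the bucket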
        s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
        s3store = s3fs.S3Map(root=f'{ODP_RETROSPECTIVE_S3_BUCKET_URI}/{product_name}.zarr', s3=s3, check=False)
        ds = xr.open_zarr(s3store)
        try:
            ds = ds.sel(rivid=river_id)
        except Exception:
            raise ValueError(f'River ID(s) not found in the retrospective dataset: {river_id}')
        if return_format == 'xarray':
            return ds
        if product_name == 'retrospective':
            return (
                ds
                .to_dataframe()
                .reset_index()
                .set_index('time')
                .pivot(columns='rivid', values='Qout')
            )
        if product_name == 'return-periods':
            rp_methods = {
                'gumbel1': 'gumbel1_return_period',
            }
            assert method in rp_methods, f'Unrecognized return period estimation method given: {method}'
            return (
                ds
                [rp_methods[method]]
                .to_dataframe()
                .reset_index()
                .pivot(index='rivid', columns='return_period', values=rp_methods[method])
            )
        raise ValueError(f'Unsupported product requested: {product_name}')

    main.__doc__ = function.__doc__  # necessary for code documentation auto generators
    return main
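Taken together, a hedged sketch of what calls through the decorated functions look like, assuming geoglows.data exposes wrappers named forecast and retrospective (a 9-digit v2 river id is required by the REST path; the id below is illustrative):

import geoglows

# default path: read directly from the AWS Open Data zarr stores
df = geoglows.data.forecast(river_id=710000001)

# same product via the REST API, returned as a parsed CSV DataFrame
df = geoglows.data.forecast(river_id=710000001, data_source='rest')

# retrospective simulation as xarray, skipping the usage log request
ds = geoglows.data.retrospective(river_id=710000001, format='xarray', skip_log=True)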
