Skip to content

Commit

Permalink
loosen versions, correct warnings, xarray datasets unfiltered by river number
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyhales committed May 1, 2024
1 parent 0d264ca commit 5be2a01
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 30 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
author = 'Riley Hales, PhD'

# The full version, including alpha/beta/rc tags
release = '1.2.0'
release = '1.4.0'
master_doc = 'index'

# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion geoglows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
'bias', 'plots', 'data', 'analyze', 'streams', 'tables', 'streamflow',
'METADATA_TABLE_PATH'
]
__version__ = '1.3.0'
__version__ = '1.4.0'
__author__ = 'Riley Hales'
__license__ = 'BSD 3-Clause Clear License'
35 changes: 20 additions & 15 deletions geoglows/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def from_aws(*args, **kwargs):
s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
date = kwargs.get('date', False)
if not date:
dates = sorted([x.split('/')[-1] for x in s3.ls(ODP_FORECAST_S3_BUCKET_URI)], reverse=True)
dates = [x.split('.')[0] for x in dates if x.endswith('.zarr')] # ignore the index.html file
dates = [x.replace('00.zarr', '') for x in dates]
zarr_vars = ['rivid', 'Qout', 'time', 'ensemble']
dates = [s3.glob(os.path.join(ODP_FORECAST_S3_BUCKET_URI, f'*.zarr/{var}')) for var in zarr_vars]
dates = [set([d.split('/')[1].replace('.zarr', '') for d in date]) for date in dates]
dates = sorted(set.intersection(*dates), reverse=True)
if product_name == 'dates':
return pd.DataFrame(dict(dates=dates))
date = dates[0]
Expand Down Expand Up @@ -121,7 +122,7 @@ def from_rest(*args, **kwargs):
endpoint = f'https://{endpoint}' if not endpoint.startswith(('https://', 'http://')) else endpoint

version = kwargs.get('version', DEFAULT_REST_ENDPOINT_VERSION)
assert version in ('v2', ), ValueError(f'Unrecognized model version parameter: {version}')
assert version in ('v2',), ValueError(f'Unrecognized model version parameter: {version}')

product_name = function.__name__.replace("_", "").lower()

Expand Down Expand Up @@ -180,6 +181,7 @@ def main(*args, **kwargs):
if source == 'rest':
return from_rest(*args, **kwargs)
return from_aws(*args, **kwargs)

main.__doc__ = function.__doc__ # necessary for code documentation auto generators
return main

Expand Down Expand Up @@ -290,15 +292,10 @@ def retrospective(river_id: int or list, format: str = 'df') -> pd.DataFrame or
"""
s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
s3store = s3fs.S3Map(root=f'{ODP_RETROSPECTIVE_S3_BUCKET_URI}/retrospective.zarr', s3=s3, check=False)
ds = xr.open_zarr(s3store).sel(rivid=river_id)
ds = xr.open_zarr(s3store)
if format == 'xarray':
return ds
return ds.to_dataframe().reset_index().set_index('time').pivot(columns='rivid', values='Qout')


def historical(*args, **kwargs):
    """Backwards-compatible alias that forwards all arguments to :func:`retrospective`."""
    return retrospective(*args, **kwargs)
return ds.sel(rivid=river_id).to_dataframe().reset_index().set_index('time').pivot(columns='rivid', values='Qout')


def daily_averages(river_id: int or list) -> pd.DataFrame:
Expand Down Expand Up @@ -343,24 +340,32 @@ def annual_averages(river_id: int or list) -> pd.DataFrame:
return calc_annual_averages(df)


def return_periods(river_id: int or list, format: str = 'df', method: str = 'gumbel1') -> pd.DataFrame or xr.Dataset:
    """
    Retrieves the return period thresholds based on a specified historic simulation forcing on a certain river_id.

    Args:
        river_id (int): the ID of a stream, should be a 9 digit integer
        format (str): the format to return the data, either 'df' or 'xarray'. default is 'df'
        method (str): the method to use to estimate the return period thresholds. default is 'gumbel1'

    Changelog:
        v1.4.0: adds method parameter for future expansion of multiple return period methods

    Returns:
        pd.DataFrame
    """
    # maps the public method name to the variable name stored in the return-periods zarr dataset
    rp_methods = {
        'gumbel1': 'gumbel1_return_period',
    }
    # NOTE(review): assert is stripped when Python runs with -O; consider raising ValueError instead
    assert method in rp_methods, f'Unrecognized return period estimation method given: {method}'

    s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(region_name=ODP_S3_BUCKET_REGION))
    s3store = s3fs.S3Map(root=f'{ODP_RETROSPECTIVE_S3_BUCKET_URI}/return-periods.zarr', s3=s3, check=False)
    # open the dataset unfiltered by river number so xarray callers can subset lazily themselves
    ds = xr.open_zarr(s3store)
    if format == 'xarray':
        return ds
    # subset to the requested river(s) only for the dataframe path, then pivot to one column per return period
    return (ds.sel(rivid=river_id)[rp_methods[method]].to_dataframe().reset_index()
            .pivot(index='rivid', columns='return_period', values=rp_methods[method]))


# model config and supplementary data
Expand Down
18 changes: 9 additions & 9 deletions geoglows/streamflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def forecast_stats(reach_id: int, return_format: str = 'csv', forecast_date: str
data = streamflow.rst.forecast_stats(12341234)
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'ForecastStats/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -113,7 +113,7 @@ def forecast_warnings(region: str = 'all', return_format='csv',
data = streamflow.rst.forecast_warnings('australia-geoglows')
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'ForecastWarnings/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -148,7 +148,7 @@ def forecast_records(reach_id: int, start_date: str = None, end_date: str = None
data = streamflow.rst.forecast_warnings('australia-geoglows')
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'ForecastRecords/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -188,7 +188,7 @@ def historic_simulation(reach_id: int, return_format='csv', forcing='era_5',
data = streamflow.rst.historic_simulation(12341234)
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'HistoricSimulation/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -223,7 +223,7 @@ def daily_averages(reach_id: int, return_format='csv', forcing='era_5',
data = streamflow.rst.seasonal_average(12341234)
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'DailyAverages/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -258,7 +258,7 @@ def monthly_averages(reach_id: int, return_format='csv', forcing='era_5',
data = streamflow.rst.seasonal_average(12341234)
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'MonthlyAverages/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -293,7 +293,7 @@ def return_periods(reach_id: int, return_format='csv', forcing='era_5',
data = streamflow.rst.return_periods(12341234)
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'ReturnPeriods/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -324,7 +324,7 @@ def available_data(endpoint: str = ENDPOINT, return_format='json', s: requests.S
data = streamflow.rst.available_data()
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'AvailableData/'

# if you only wanted the url, quit here
Expand Down Expand Up @@ -356,7 +356,7 @@ def available_dates(reach_id: int = None, region: str = None, return_format: str
data = streamflow.rst.available_dates(12341234)
"""
warnings.warn(DEPRECATIONWARNING, DeprecationWarning)
warnings.warn(DEPRECATIONWARNING, DeprecationWarning, stacklevel=2)
method = 'AvailableDates/'

# you need a region for the api call, so the user needs to provide one or a valid reach_id to get it from
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
dask>=2024
dask>=2022
fastparquet
requests
pandas>=1
plotly>=5
scipy>=1
s3fs>=2024
s3fs>=2022
numpy>=1
hydrostats
HydroErr
xarray>=2024
xarray>=2022
zarr
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
DESCRIPTION = 'Package for accessing data from the GEOGLOWS Hydrological Model'
URL = 'https://data.geoglows.org'
AUTHOR = 'Riley Hales PhD'
REQUIRES_PYTHON = '>=3.10.0'
REQUIRES_PYTHON = '>=3.7.0'
LICENSE = 'BSD 3-Clause Clear License'

with open("README.md", "r") as readme:
Expand Down

0 comments on commit 5be2a01

Please sign in to comment.