Skip to content

Commit

Permalink
refactor flatfile io WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
rizac committed Sep 23, 2023
1 parent 35656bf commit aa512cc
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 44 deletions.
8 changes: 6 additions & 2 deletions egsim/smtk/flatfile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_dtypes_and_defaults, get_all_names_of,
get_intensity_measures, MissingColumn,
InvalidDataInColumn, InvalidColumnName, ConflictingColumns,
check_dtypes_defaults_and_bounds)
cast_dtype, cast_value)
from .. import get_SA_period
from ...smtk.trellis.configure import vs30_to_z1pt0_cy14, vs30_to_z2pt5_cb14

Expand Down Expand Up @@ -85,8 +85,12 @@ def read_csv(filepath_or_buffer: Union[str, IO],
dtype = {}
if not defaults:
defaults = {}
# hidden arg defaulting True: check and cast dtype and defaults:
if kwargs.pop('_check_dtypes_and_defaults', True):
dtype, defaults = check_dtypes_defaults_and_bounds(dtype, defaults)
for col, _dtype in dtype.items():
dtype[col] = cast_dtype(_dtype)
if col in defaults:
defaults[col] = cast_value(defaults[col], dtype[col])
dfr = _read_csv(filepath_or_buffer, sep=sep, dtype=dtype, usecols=usecols, **kwargs)
_adjust_dtypes_and_defaults(dfr, dtype, defaults)
if not isinstance(dfr.index, pd.RangeIndex):
Expand Down
50 changes: 13 additions & 37 deletions egsim/smtk/flatfile/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import re
from datetime import datetime, date
from enum import Enum, StrEnum, ReprEnum
from enum import Enum, ReprEnum
from os.path import join, dirname

from pandas.core.arrays import PandasArray
Expand Down Expand Up @@ -223,16 +223,14 @@ def _extract_from_columns(columns: dict[str, dict[str, Any]], *,
"""
check_type = rupture_params is not None or sites_params is not None \
or distances is not None or imts is not None
get_dtype = dtype is not None or default is not None or bounds is not None
if get_dtype and dtype is None:
dtype = {}
for c_name, props in columns.items():
aliases = props.get('alias', [])
if isinstance(aliases, str):
aliases = {aliases}
else:
aliases = set(aliases)
aliases.add(c_name)
_dtype = None # cache value, see below
for name in aliases:
if check_type and 'type' in props:
ctype = ColumnType[props['type']]
Expand All @@ -247,44 +245,22 @@ def _extract_from_columns(columns: dict[str, dict[str, Any]], *,
if alias is not None:
alias[name] = aliases
if dtype is not None and 'dtype' in props:
dtype[name] = props['dtype']
dtype[name] = _dtype = cast_dtype(props['dtype'])
if default is not None and 'default' in props:
default[name] = props['default']
default[name] = props['default'] if dtype is None else \
cast_value(props['default'], _dtype)
if bounds is not None:
_bounds = {k: props[k]
for k in ["<", "<=", ">", ">="]
if k in props}
if _bounds:
bounds[name] = _bounds
keys = [k for k in ["<", "<=", ">", ">="] if k in props]
if keys:
bounds[name] = {}
for k in keys:
bounds[name][k] = props[k] if _dtype is None else \
cast_value(props[k], _dtype)
if help is not None and props.get('help', ''):
help[name] = props['help']

if dtype is not None and (default is not None or bounds is not None):
check_dtypes_defaults_and_bounds(dtype, default, bounds)


def check_dtypes_defaults_and_bounds(
dtype: dict[str, Union[str, list, tuple, pd.CategoricalDtype, ColumnDtype]],
defaults: dict[str, Any] = None,
bounds: dict[str, dict[str, Any]] = None) \
-> tuple[dict[str, Union[pd.CategoricalDtype, ColumnDtype]], dict[str, Any]]:

for col, col_dtype in dtype.items():
dtype[col] = _cast_dtype(col_dtype)

if defaults:
for col in defaults:
defaults[col] = _cast_value(defaults[col], dtype[col])

if bounds:
for col, col_bounds in bounds.items():
for key, val in col_bounds.items():
col_bounds[key] = _cast_value(val, dtype[col])

return dtype, defaults


def _cast_dtype(dtype: Union[list, tuple, str, pd.CategoricalDtype, ColumnDtype]) \
def cast_dtype(dtype: Union[list, tuple, str, pd.CategoricalDtype, ColumnDtype]) \
-> Union[ColumnDtype, pd.CategoricalDtype]:
"""Return a value from the given argument that is suitable to be used as data type in
pandas, i.e., either a `ColumnDtype` (str-like enum) or a pandas `CategoricalDtype`
Expand All @@ -308,7 +284,7 @@ def _cast_dtype(dtype: Union[list, tuple, str, pd.CategoricalDtype, ColumnDtype]
raise ValueError(f'Invalid data type: {str(dtype)}')


def _cast_value(val: Any, pd_dtype: Union[ColumnDtype, pd.CategoricalDtype]) -> Any:
def cast_value(val: Any, pd_dtype: Union[ColumnDtype, pd.CategoricalDtype]) -> Any:
"""cast `val` to the given pandas dtype, raise ValueError if unsuccessful
:param val: any Python object
Expand Down
9 changes: 4 additions & 5 deletions tests/smtk/flatfile/test_flatfile_columns_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_sites_params_required_by, get_distances_required_by
from egsim.smtk.flatfile.columns import (ColumnType, ColumnDtype,
_extract_from_columns,
_ff_metadata_path, _cast_value)
_ff_metadata_path, cast_value)


def test_flatfile_extract_from_yaml():
Expand Down Expand Up @@ -139,7 +139,7 @@ def check_column_metadata(*, name: str, ctype: Union[ColumnType, None],
raise ValueError(f"{prefix} bounds cannot be provided with "
f"categorical data type")
if default_is_given:
_cast_value(default, dtype) # raise if not in categories
assert default is cast_value(default, dtype) # raise if not in categories
return

assert isinstance(dtype, ColumnDtype)
Expand All @@ -159,15 +159,14 @@ def check_column_metadata(*, name: str, ctype: Union[ColumnType, None],
min_val = bounds.get(">", bounds.get(">=", None))
for val in [max_val, min_val]:
if val is not None:
_cast_value(val, dtype)
assert val is cast_value(val, dtype)
if max_val is not None and min_val is not None and max_val <= min_val:
raise ValueError(f'{prefix} min. bound must be lower than '
f'max. bound')

# check default value:
if default_is_given:
# this should already been done but for dafety:
_cast_value(default, dtype)
assert default is cast_value(default, dtype)


def check_with_openquake(rupture_params: dict[str, set[str]],
Expand Down

0 comments on commit aa512cc

Please sign in to comment.