Skip to content

Commit

Permalink
refactor flatfile io WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
rizac committed Sep 22, 2023
1 parent 8a74dee commit 6083f0b
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 101 deletions.
37 changes: 18 additions & 19 deletions egsim/smtk/flatfile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from .columns import (ColumnDtype, get_rupture_params,
get_dtypes_and_defaults, get_all_names_of,
get_intensity_measures, MissingColumn,
InvalidDataInColumn, InvalidColumnName, ConflictingColumns)
InvalidDataInColumn, InvalidColumnName, ConflictingColumns,
to_pandas_dtype, cast_to_dtype)
from .. import get_SA_period
from ...smtk.trellis.configure import vs30_to_z1pt0_cy14, vs30_to_z2pt5_cb14

Expand Down Expand Up @@ -99,9 +100,9 @@ def _check_dtypes_and_defaults(
-> tuple[dict[str, ColumnDtype], dict[str, Any]]:

for col, col_dtype in dtype.items():
dtype[col] = ColumnDtype.get(col_dtype)
dtype[col] = to_pandas_dtype(col_dtype)
if col in defaults:
defaults[col] = ColumnDtype.cast(defaults[col], dtype[col])
defaults[col] = cast_to_dtype(defaults[col], dtype[col])

if set(defaults) - set(dtype):
raise ValueError('Invalid defaults for columns without an explicit '
Expand Down Expand Up @@ -173,14 +174,14 @@ def _read_csv(filepath_or_buffer: Union[str, IO], **kwargs) -> pd.DataFrame:
# int will be parsed as floats, so that NaN are still possible
# (e.g. '' in CSV) whilst raising in case of invalid values ('x').
# The conversion to int will be handled later
kwargs['dtype'][col] = ColumnDtype.float.type_str
kwargs['dtype'][col] = ColumnDtype.float
elif col_dtype == ColumnDtype.datetime:
# date times in pandas csv must be given in a separate arg. Note that
# read_csv does not raise for invalid dates but returns the column
# with an inferred data type (usually object)
parse_dates.add(col)
else:
kwargs['dtype'][col] = col_dtype.type_str
kwargs['dtype'][col] = col_dtype

if parse_dates:
kwargs['parse_dates'] = list(parse_dates)
Expand All @@ -194,8 +195,7 @@ def _read_csv(filepath_or_buffer: Union[str, IO], **kwargs) -> pd.DataFrame:
# explicitly passed and int dtypes are converted to float (see above). So
# just check for floats. Note: get dtypes from kwargs['dtype'] because we
# want to check the real dtype passed, all given as numpy str:
cols2check = [c for c, v in kwargs['dtype'].items()
if v == ColumnDtype.float.type_str]
cols2check = [c for c, v in kwargs['dtype'].items() if v == ColumnDtype.float]
invalid_columns = []
for c in cols2check:
try:
Expand Down Expand Up @@ -255,17 +255,16 @@ def _adjust_dtypes_and_defaults(dfr: pd.DataFrame,
elif expected_dtype == ColumnDtype.bool:
not_na = pd.notna(dfr[col])
unique_vals = pd.unique(dfr[col][not_na])
if actual_dtype == ColumnDtype.str:
mapping = {}
for val in unique_vals:
if isinstance(val, str):
if val.lower() in {'0', 'false'}:
mapping[val] = False
elif val.lower() in {'1', 'true'}:
mapping[val] = True
if mapping:
dfr[col].replace(mapping, inplace=True)
unique_vals = pd.unique(dfr[col][not_na])
mapping = {}
for val in unique_vals:
if isinstance(val, str):
if val.lower() in {'0', 'false'}:
mapping[val] = False
elif val.lower() in {'1', 'true'}:
mapping[val] = True
if mapping:
dfr[col].replace(mapping, inplace=True)
unique_vals = pd.unique(dfr[col][not_na])
if set(unique_vals).issubset({0, 1}):
dtype_ok = True
do_type_cast = True
Expand All @@ -284,7 +283,7 @@ def _adjust_dtypes_and_defaults(dfr: pd.DataFrame,
dtype_ok = False
elif do_type_cast:
try:
dfr[col] = dfr[col].astype(expected_dtype.type_str)
dfr[col] = dfr[col].astype(expected_dtype)
except (ValueError, TypeError):
dtype_ok = False

Expand Down
156 changes: 85 additions & 71 deletions egsim/smtk/flatfile/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import re
from datetime import datetime, date
from enum import Enum
from enum import Enum, StrEnum
from os.path import join, dirname

from pandas.core.arrays import PandasArray
from pandas.core.dtypes.base import ExtensionDtype
from typing import Union, Any

# try to speed up yaml.safe_load (https://pyyaml.org/wiki/PyYAMLDocumentation):
Expand All @@ -31,29 +33,28 @@ class ColumnType(Enum):
intensity = 'Intensity measure'


class ColumnDtype(Enum):
"""Flatfile column data type. Names are used as dtype values in
the YAML file (note that categorical dtypes have to be given as list),
enum values are the relative Python/numpy classes to be used in Python code.
E.g., to get if the dtype of flatfile column `c` (pandas Series) is supported:
```
isinstance(c.dtype, pd.CategoricalDtype) or \
any(issubclass(c.dtype.type, e.value) for e in ColumnDtype)
```
class ColumnDtype(StrEnum):
"""Enum denoting the supported data types for flatfile columns:
each member behaves exactly as a string compatible to pandas `astype`,
e.g.: `[series|array].astype(ColumnDtype.datetime)`
"""
# NOTE: the FIRST VALUE MUST BE THE PYTHON TYPE (e.g. int, not np.int64), AS
# IT WILL BE USED TO CHECK THE CONSISTENCY OF DEFAULT / BOUNDS IN THE YAML
float = float, np.floating
int = int, np.integer
bool = bool, np.bool_
datetime = datetime, np.datetime64
str = str, np.str_, np.object_

@property
def type_str(self):
"""Return the Python class denoting this enum item. The value can be used
in numpy `astype` functions to cast values"""
return 'datetime64' if self == ColumnDtype.datetime else self.name
def __new__(cls, value, *classes:type):
    """Construct a new ColumnDtype member.

    :param value: the string content of the member (a pandas/numpy dtype
        name such as "float" or "datetime64"); because the enum is
        str-based, the member itself *is* this string and can be passed
        directly to pandas/numpy `astype`
    :param classes: the Python/numpy classes associated with this member,
        stored as the enum `_value_` (presumably consulted by
        `ColumnDtype.of` for isinstance-style checks — confirm there)
    """
    # Build the member as a plain str carrying `value` ...
    obj = str.__new__(cls, value)
    # ... while the enum "value" holds the associated classes instead:
    obj._value_ = classes
    return obj

# each member below must be mapped to the numpy name (see `numpy.sctypeDict.keys()`
# for a list of supported names. Exception is 'string' that is pandas only) and
# one or more Python classes that will be used
# to get if an object (Python, numpy or pandas) is instance of a particular
# `ColumnDtype` (see `ColumnDtype.of`)

float = "float", float, np.floating
int = "int", int, np.integer
bool = "bool", bool, np.bool_
datetime = "datetime64", datetime, np.datetime64
str = "string", str, np.str_ # , np.object_

@classmethod
def of(cls, obj: Union[int, float, datetime, bool, str,
Expand All @@ -63,9 +64,9 @@ def of(cls, obj: Union[int, float, datetime, bool, str,
-> Union[ColumnDtype, None]:
"""Return the ColumnDtype of the given argument
:param obj: a Python object (e.g. 4.5), a Python class (`float`),
a numpy array or pandas Series, a numpy dtype
(e.g., as returned from a pandas dataframe `dataframe[column].dtype`)
or a pandas CategoricalDtype. In this last case, return the ColumnDtype
of all categories, if they are of the same type. E.g.:
`ColumnDtype.of(pd.CategoricalDtype([1,2]) = ColumnDtype.int`
Expand All @@ -76,9 +77,9 @@ def of(cls, obj: Union[int, float, datetime, bool, str,
if len(dtypes) == 1:
return next(iter(dtypes))
else:
if isinstance(obj, (pd.Series, np.ndarray)):
if isinstance(obj, (pd.Series, np.ndarray, PandasArray)):
obj = obj.dtype # will fall back in the next "if"
if isinstance(obj, np.dtype):
if isinstance(obj, (np.dtype, ExtensionDtype)):
obj = obj.type # will NOT fall back into the next "if"
if not isinstance(obj, type):
obj = type(obj)
Expand All @@ -91,46 +92,59 @@ def of(cls, obj: Union[int, float, datetime, bool, str,
return c_dtype
return None

@classmethod
def get(cls, dtype: Union[list, tuple, str, pd.CategoricalDtype, ColumnDtype]) \
-> Union[ColumnDtype, pd.CategoricalDtype]:
"""Return the ColumnDtype or the pandas CategoricalDtype from the given
argument, converting it if necessary
"""
try:
if isinstance(dtype, ColumnDtype):
return dtype
if isinstance(dtype, (list, tuple)):
dtype = pd.CategoricalDtype(dtype)
if isinstance(dtype, pd.CategoricalDtype):
# check that the categories are all of the same supported data type:
if ColumnDtype.of(dtype) is not None:
return dtype
else:
return ColumnDtype[dtype]
except (KeyError, ValueError, TypeError):
pass
raise ValueError(f'Invalid flatfile data type: {str(dtype)}')
def __repr__(self):
    # Enum's default __repr__ renders `self._value_`, which here is a
    # tuple of classes (see __new__), not a string: delegate to __str__
    # so repr(member) shows the member's plain string content instead.
    return self.__str__()

@classmethod
def cast(cls, val: Any, dtype: Union[ColumnDtype, pd.CategoricalDtype]) -> Any:
"""cast `val` to the given dtype or pandas CategoricalDtype, raise
ValueError if unsuccessful
"""



def to_pandas_dtype(dtype: Union[list, tuple, str, pd.CategoricalDtype, ColumnDtype]) \
        -> Union[ColumnDtype, pd.CategoricalDtype]:
    """Return a value from the given argument that is suitable to be used as
    data type in pandas (e.g. as argument of `Series.astype` or `read_csv`):
    either a `ColumnDtype` (str-like enum) or a pandas `CategoricalDtype`.

    :param dtype: a `ColumnDtype` (returned as-is), a string matching either a
        `ColumnDtype` member name (e.g. "datetime") or its string value
        (e.g. "datetime64"), a pandas `CategoricalDtype`, or a list/tuple of
        allowed categories (converted to a `CategoricalDtype`)
    :raise ValueError: if no supported pandas data type can be inferred
    """
    try:
        if isinstance(dtype, ColumnDtype):
            return dtype
        if isinstance(dtype, str):
            # accept both the member name and its str content:
            for member in ColumnDtype:
                if member == dtype or member.name == dtype:
                    return member
        if isinstance(dtype, (list, tuple)):
            dtype = pd.CategoricalDtype(dtype)
        if isinstance(dtype, pd.CategoricalDtype):
            # check that the categories are all of the same supported data type:
            if ColumnDtype.of(dtype) is not None:
                return dtype
    except (KeyError, ValueError, TypeError):
        pass
    raise ValueError(f'Invalid data type: {str(dtype)}')


def cast_to_dtype(val: Any, pd_dtype: Union[ColumnDtype, pd.CategoricalDtype]) -> Any:
    """Cast `val` to the given pandas dtype, raising ValueError if unsuccessful.

    :param val: any Python object
    :param pd_dtype: the result of `to_pandas_dtype`: either a `ColumnDtype`
        or a pandas `CategoricalDtype` object
    :raise ValueError: if `val` is not compatible with `pd_dtype`
    """
    if pd_dtype is None:
        raise ValueError(f'Invalid dtype: {str(pd_dtype)}')
    if isinstance(pd_dtype, pd.CategoricalDtype):
        # categorical data: `val` must simply be one of the allowed categories
        if val in pd_dtype.categories:
            return val
        dtype_name = 'categorical'
    else:
        actual_dtype = ColumnDtype.of(val)
        if actual_dtype == pd_dtype:
            return val
        # the only implicit conversions allowed: int -> float, date -> datetime
        if pd_dtype == ColumnDtype.float and actual_dtype == ColumnDtype.int:
            return float(val)
        elif pd_dtype == ColumnDtype.datetime and isinstance(val, date):
            return datetime(val.year, val.month, val.day)
        dtype_name = pd_dtype.name
    raise ValueError(f'Invalid value for type {dtype_name}: {str(val)}')


def get_rupture_params() -> set[str]:
Expand Down Expand Up @@ -327,10 +341,10 @@ def _extract_from_columns(columns: dict[str, dict[str, Any]], *,
else:
aliases = set(aliases)
aliases.add(c_name)
cdtype = None
pd_dtype = None
if ((dtype is not None or default is not None or bounds is not None)
and 'dtype' in props):
cdtype = ColumnDtype.get(props['dtype'])
pd_dtype = to_pandas_dtype(props['dtype'])
for name in aliases:
if check_type and 'type' in props:
ctype = ColumnType[props['type']]
Expand All @@ -344,12 +358,12 @@ def _extract_from_columns(columns: dict[str, dict[str, Any]], *,
imts.add(name)
if alias is not None:
alias[name] = aliases
if dtype is not None and cdtype is not None:
dtype[name] = cdtype
if dtype is not None and pd_dtype is not None:
dtype[name] = pd_dtype
if default is not None and 'default' in props:
default[name] = ColumnDtype.cast(props['default'], cdtype)
default[name] = cast_to_dtype(props['default'], pd_dtype)
if bounds is not None:
_bounds = {k: ColumnDtype.cast(props[k], cdtype)
_bounds = {k: cast_to_dtype(props[k], pd_dtype)
for k in ["<", "<=", ">", ">="]
if k in props}
if _bounds:
Expand Down
15 changes: 9 additions & 6 deletions tests/smtk/flatfile/test_flatfile_columns_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

from egsim.smtk import get_gsim_names, get_rupture_params_required_by, \
get_sites_params_required_by, get_distances_required_by
from egsim.smtk.flatfile.columns import (ColumnType, ColumnDtype, _extract_from_columns,
from egsim.smtk.flatfile import cast_to_dtype
from egsim.smtk.flatfile.columns import (ColumnType, ColumnDtype,
_extract_from_columns,
_ff_metadata_path)


Expand Down Expand Up @@ -138,7 +140,7 @@ def check_column_metadata(*, name: str, ctype: Union[ColumnType, None],
raise ValueError(f"{prefix} bounds cannot be provided with "
f"categorical data type")
if default_is_given:
ColumnDtype.cast(default, dtype) # raise if not in categories
cast_to_dtype(default, dtype) # raise if not in categories
return

assert isinstance(dtype, ColumnDtype)
Expand All @@ -158,15 +160,15 @@ def check_column_metadata(*, name: str, ctype: Union[ColumnType, None],
min_val = bounds.get(">", bounds.get(">=", None))
for val in [max_val, min_val]:
if val is not None:
ColumnDtype.cast(val, dtype)
cast_to_dtype(val, dtype)
if max_val is not None and min_val is not None and max_val <= min_val:
raise ValueError(f'{prefix} min. bound must be lower than '
f'max. bound')

# check default value:
if default_is_given:
# this should already have been done, but repeat for safety:
ColumnDtype.cast(default, dtype)
cast_to_dtype(default, dtype)


def check_with_openquake(rupture_params: dict[str, set[str]],
Expand Down Expand Up @@ -224,12 +226,13 @@ def test_Column_dtype():
'bool': [True, False],
'str': [None, "x"]
})
d.str = d.str.astype('string')
for c in d.columns:
dtyp = d[c].dtype
assert ColumnDtype.of(dtyp) == ColumnDtype[c]
assert ColumnDtype.of(d[c]) == ColumnDtype[c]
assert ColumnDtype.of(d[c].values) == ColumnDtype[c]
assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c] if _ is not None)
assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c].values if _ is not None)
assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c] if pd.notna(_))
assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c].values if pd.notna(_))

assert ColumnDtype.of(None) is None
Loading

0 comments on commit 6083f0b

Please sign in to comment.