diff --git a/egsim/smtk/flatfile/__init__.py b/egsim/smtk/flatfile/__init__.py
index 72737f15..85548d8f 100644
--- a/egsim/smtk/flatfile/__init__.py
+++ b/egsim/smtk/flatfile/__init__.py
@@ -18,7 +18,8 @@
 from .columns import (ColumnDtype, get_rupture_params,
                       get_dtypes_and_defaults, get_all_names_of,
                       get_intensity_measures, MissingColumn,
-                      InvalidDataInColumn, InvalidColumnName, ConflictingColumns)
+                      InvalidDataInColumn, InvalidColumnName, ConflictingColumns,
+                      to_pandas_dtype, cast_to_dtype)
 from .. import get_SA_period
 from ...smtk.trellis.configure import vs30_to_z1pt0_cy14, vs30_to_z2pt5_cb14
 
@@ -99,9 +100,9 @@ def _check_dtypes_and_defaults(
         -> tuple[dict[str, ColumnDtype], dict[str, Any]]:
 
     for col, col_dtype in dtype.items():
-        dtype[col] = ColumnDtype.get(col_dtype)
+        dtype[col] = to_pandas_dtype(col_dtype)
         if col in defaults:
-            defaults[col] = ColumnDtype.cast(defaults[col], dtype[col])
+            defaults[col] = cast_to_dtype(defaults[col], dtype[col])
 
     if set(defaults) - set(dtype):
         raise ValueError('Invalid defaults for columns without an explicit '
@@ -173,14 +174,14 @@ def _read_csv(filepath_or_buffer: Union[str, IO], **kwargs) -> pd.DataFrame:
             # int will be parsed as floats, so that NaN are still possible
             # (e.g. '' in CSV) whilst raising in case of invalid values ('x').
             # The conversion to int will be handled later
-            kwargs['dtype'][col] = ColumnDtype.float.type_str
+            kwargs['dtype'][col] = ColumnDtype.float
         elif col_dtype == ColumnDtype.datetime:
             # date times in pandas csv must be given in a separate arg. Note that
             # read_csv does not raise for invalid dates but returns the column
             # with an inferred data type (usually object)
             parse_dates.add(col)
         else:
-            kwargs['dtype'][col] = col_dtype.type_str
+            kwargs['dtype'][col] = col_dtype
 
     if parse_dates:
         kwargs['parse_dates'] = list(parse_dates)
@@ -194,8 +195,7 @@
         # explicitly passed and int dtypes are converted to float (see above). So
         # just check for floats. Note: get dtypes from kwargs['dtype'] because we
-        # want to check the real dtype passed, all given as numpy str:
-        cols2check = [c for c, v in kwargs['dtype'].items()
-                      if v == ColumnDtype.float.type_str]
+        # want to check the real dtype passed (now a `ColumnDtype` str-enum):
+        cols2check = [c for c, v in kwargs['dtype'].items() if v == ColumnDtype.float]
         invalid_columns = []
         for c in cols2check:
             try:
@@ -255,17 +255,16 @@ def _adjust_dtypes_and_defaults(dfr: pd.DataFrame,
         elif expected_dtype == ColumnDtype.bool:
             not_na = pd.notna(dfr[col])
             unique_vals = pd.unique(dfr[col][not_na])
-            if actual_dtype == ColumnDtype.str:
-                mapping = {}
-                for val in unique_vals:
-                    if isinstance(val, str):
-                        if val.lower() in {'0', 'false'}:
-                            mapping[val] = False
-                        elif val.lower() in {'1', 'true'}:
-                            mapping[val] = True
-                if mapping:
-                    dfr[col].replace(mapping, inplace=True)
-                    unique_vals = pd.unique(dfr[col][not_na])
+            mapping = {}
+            for val in unique_vals:
+                if isinstance(val, str):
+                    if val.lower() in {'0', 'false'}:
+                        mapping[val] = False
+                    elif val.lower() in {'1', 'true'}:
+                        mapping[val] = True
+            if mapping:
+                dfr[col].replace(mapping, inplace=True)
+                unique_vals = pd.unique(dfr[col][not_na])
             if set(unique_vals).issubset({0, 1}):
                 dtype_ok = True
                 do_type_cast = True
@@ -284,7 +283,7 @@
                 dtype_ok = False
         elif do_type_cast:
             try:
-                dfr[col] = dfr[col].astype(expected_dtype.type_str)
+                dfr[col] = dfr[col].astype(expected_dtype)
             except (ValueError, TypeError):
                 dtype_ok = False
 
diff --git a/egsim/smtk/flatfile/columns.py b/egsim/smtk/flatfile/columns.py
index d40ad9d9..2aa69e62 100644
--- a/egsim/smtk/flatfile/columns.py
+++ b/egsim/smtk/flatfile/columns.py
@@ -6,9 +6,11 @@
 import re
 from datetime import datetime, date
-from enum import Enum
+from enum import Enum, StrEnum
 from os.path import join, dirname
+from pandas.core.arrays import PandasArray
+from pandas.core.dtypes.base import ExtensionDtype
 from typing import Union, Any
 
 # try to speed up yaml.safe_load (https://pyyaml.org/wiki/PyYAMLDocumentation):
@@ -31,29 +33,28 @@ class ColumnType(Enum):
     intensity = 'Intensity measure'
 
 
-class ColumnDtype(Enum):
-    """Flatfile column data type. Names are used as dtype values in
-    the YAML file (note that categorical dtypes have to be given as list),
-    enum values are the relative Python/numpy classes to be used in Python code.
-    E.g., to get if the dtype of flatfile column `c` (pandas Series) is supported:
-    ```
-    isinstance(c.dtype, pd.CategoricalDtype) or \
-    any(issubclass(c.dtype.type, e.value) for e in ColumnDtype)
-    ```
+class ColumnDtype(StrEnum):
+    """Enum denoting the supported data types for flatfile columns:
+    each member behaves exactly like a string compatible with pandas `astype`,
+    e.g.: `[series|array].astype(ColumnDtype.datetime)`
     """
-    # NOTE: the FIRST VALUE MUST BE THE PYTHON TYPE (e.g. int, not np.int64), AS
-    # IT WILL BE USED TO CHECK THE CONSISTENCY OF DEFAULT / BOUNDS IN THE YAML
-    float = float, np.floating
-    int = int, np.integer
-    bool = bool, np.bool_
-    datetime = datetime, np.datetime64
-    str = str, np.str_, np.object_
-
-    @property
-    def type_str(self):
-        """Return the Python class denoting this enum item. The value can be used
-        in numpy `astype` functions to cast values"""
-        return 'datetime64' if self == ColumnDtype.datetime else self.name
+    def __new__(cls, value, *classes: type):
+        """Construct a new `ColumnDtype` member"""
+        obj = str.__new__(cls, value)
+        obj._value_ = classes
+        return obj
+
+    # Each member below is mapped to its numpy name (see `numpy.sctypeDict.keys()`
+    # for a list of supported names; the exception is 'string', which is
+    # pandas-only) and to one or more Python classes, used to check whether an
+    # object (Python, numpy or pandas) is an instance of a particular
+    # `ColumnDtype` (see `ColumnDtype.of`)
+    float = "float", float, np.floating
+    int = "int", int, np.integer
+    bool = "bool", bool, np.bool_
+    datetime = "datetime64", datetime, np.datetime64
+    str = "string", str, np.str_  # , np.object_
 
     @classmethod
     def of(cls, obj: Union[int, float, datetime, bool, str,
@@ -63,9 +64,9 @@ def of(cls, obj: Union[int, float, datetime, bool, str,
             -> Union[ColumnDtype, None]:
         """Return the ColumnDtype of the given argument
 
-        :param obj: a Python object( e.g. 4.5), a Python class (`float`),
+        :param obj: a Python object (e.g. 4.5), a Python class (`float`),
             a numpy array or pandas Series, a numpy dtype
-            (e.g as returned from a pandas dataframe `dataframe[column].dtype`)
+            (e.g., as returned from a pandas dataframe `dataframe[column].dtype`)
             or a pandas CategoricalDtype. In this last case, return the
             ColumnDtype of all categories, if they are of the same type.
             E.g.: `ColumnDtype.of(pd.CategoricalDtype([1,2]) = ColumnDtype.int`
@@ -76,9 +77,9 @@
             if len(dtypes) == 1:
                 return next(iter(dtypes))
         else:
-            if isinstance(obj, (pd.Series, np.ndarray)):
+            if isinstance(obj, (pd.Series, np.ndarray, PandasArray)):
                 obj = obj.dtype  # will fall back in the next "if"
-            if isinstance(obj, np.dtype):
+            if isinstance(obj, (np.dtype, ExtensionDtype)):
                 obj = obj.type  # will NOT fall back into the next "if"
             if not isinstance(obj, type):
                 obj = type(obj)
@@ -91,46 +92,59 @@
                 return c_dtype
         return None
 
-    @classmethod
-    def get(cls, dtype: Union[list, tuple, str, pd.CategoricalDtype, ColumnDtype]) \
-            -> Union[ColumnDtype, pd.CategoricalDtype]:
-        """Return the ColumnDtype or the pandas CategoricalDtype from the given
-        argument, converting it if necessary
-        """
-        try:
-            if isinstance(dtype, ColumnDtype):
-                return dtype
-            if isinstance(dtype, (list, tuple)):
-                dtype = pd.CategoricalDtype(dtype)
-            if isinstance(dtype, pd.CategoricalDtype):
-                # check that the categories are all of the same supported data type:
-                if ColumnDtype.of(dtype) is not None:
-                    return dtype
-            else:
-                return ColumnDtype[dtype]
-        except (KeyError, ValueError, TypeError):
-            pass
-        raise ValueError(f'Invalid flatfile data type: {str(dtype)}')
+    def __repr__(self):
+        # the default Enum repr expects `self._value_` to be a string, but here
+        # it is a tuple of classes: return the string form instead
+        return self.__str__()
 
-    @classmethod
-    def cast(cls, val: Any, dtype: Union[ColumnDtype, pd.CategoricalDtype]) -> Any:
-        """cast `val` to the given dtype or pandas CategoricalDtype, raise
-        ValueError if unsuccessful
-        """
+
+def to_pandas_dtype(dtype: Union[list, tuple, str, pd.CategoricalDtype, ColumnDtype]) \
+        -> Union[ColumnDtype, pd.CategoricalDtype]:
+    """Return a data type usable in pandas from the given argument, i.e. either
+    a `ColumnDtype` (str-like enum) or a pandas `CategoricalDtype`
+    """
+    try:
+        if isinstance(dtype, ColumnDtype):
+            return dtype
+        if isinstance(dtype, str):
+            for val in ColumnDtype:
+                if val == dtype or val.name == dtype:
+                    return val
+        if isinstance(dtype, (list, tuple)):
+            dtype = pd.CategoricalDtype(dtype)
         if isinstance(dtype, pd.CategoricalDtype):
-            if val in dtype.categories:
-                return val
-            dtype_name = 'categorical'
-        else:
-            actual_dtype = ColumnDtype.of(val)
-            if actual_dtype == dtype:
-                return val
-            if dtype == ColumnDtype.float and actual_dtype == ColumnDtype.int:
-                return float(val)
-            elif dtype == ColumnDtype.datetime and isinstance(val, date):
-                return datetime(val.year, val.month, val.day)
-            dtype_name = dtype.name
-        raise ValueError(f'Invalid value for type {dtype_name}: {str(val)}')
+            # check that the categories are all of the same supported data type:
+            if ColumnDtype.of(dtype) is not None:
+                return dtype
+    except (KeyError, ValueError, TypeError):
+        pass
+    raise ValueError(f'Invalid data type: {str(dtype)}')
+
+
+def cast_to_dtype(val: Any, pd_dtype: Union[ColumnDtype, pd.CategoricalDtype]) -> Any:
+    """Cast `val` to the given pandas dtype, raising ValueError if unsuccessful
+
+    :param val: any Python object
+    :param pd_dtype: the result of `to_pandas_dtype`: either a `ColumnDtype` or a
+        pandas `CategoricalDtype` object
+    """
+    if pd_dtype is None:
+        raise ValueError(f'Invalid dtype: {str(pd_dtype)}')
+    if isinstance(pd_dtype, pd.CategoricalDtype):
+        if val in pd_dtype.categories:
+            return val
+        dtype_name = 'categorical'
+    else:
+        actual_dtype = ColumnDtype.of(val)
+        if actual_dtype == pd_dtype:
+            return val
+        if pd_dtype == ColumnDtype.float and actual_dtype == ColumnDtype.int:
+            return float(val)
+        elif pd_dtype == ColumnDtype.datetime and isinstance(val, date):
+            return datetime(val.year, val.month, val.day)
+        dtype_name = pd_dtype.name
+    raise ValueError(f'Invalid value for type {dtype_name}: {str(val)}')
 
 
 def get_rupture_params() -> set[str]:
@@ -327,10 +341,10 @@ def _extract_from_columns(columns: dict[str, dict[str, Any]], *,
         else:
             aliases = set(aliases)
         aliases.add(c_name)
-        cdtype = None
+        pd_dtype = None
         if ((dtype is not None or default is not None or bounds is not None)
                 and 'dtype' in props):
-            cdtype = ColumnDtype.get(props['dtype'])
+            pd_dtype = to_pandas_dtype(props['dtype'])
         for name in aliases:
             if check_type and 'type' in props:
                 ctype = ColumnType[props['type']]
@@ -344,12 +358,12 @@
                 imts.add(name)
             if alias is not None:
                 alias[name] = aliases
-            if dtype is not None and cdtype is not None:
-                dtype[name] = cdtype
+            if dtype is not None and pd_dtype is not None:
+                dtype[name] = pd_dtype
             if default is not None and 'default' in props:
-                default[name] = ColumnDtype.cast(props['default'], cdtype)
+                default[name] = cast_to_dtype(props['default'], pd_dtype)
             if bounds is not None:
-                _bounds = {k: ColumnDtype.cast(props[k], cdtype)
+                _bounds = {k: cast_to_dtype(props[k], pd_dtype)
                            for k in ["<", "<=", ">", ">="]
                            if k in props}
                 if _bounds:
diff --git a/tests/smtk/flatfile/test_flatfile_columns_yaml.py b/tests/smtk/flatfile/test_flatfile_columns_yaml.py
index 2757db58..1a0fb89d 100644
--- a/tests/smtk/flatfile/test_flatfile_columns_yaml.py
+++ b/tests/smtk/flatfile/test_flatfile_columns_yaml.py
@@ -13,7 +13,9 @@
 from egsim.smtk import get_gsim_names, get_rupture_params_required_by, \
     get_sites_params_required_by, get_distances_required_by
 
-from egsim.smtk.flatfile.columns import (ColumnType, ColumnDtype, _extract_from_columns,
+from egsim.smtk.flatfile import cast_to_dtype
+from egsim.smtk.flatfile.columns import (ColumnType, ColumnDtype,
+                                         _extract_from_columns,
                                          _ff_metadata_path)
 
@@ -138,7 +140,7 @@ def check_column_metadata(*, name: str, ctype: Union[ColumnType, None],
             raise ValueError(f"{prefix} bounds cannot be provided with "
                              f"categorical data type")
         if default_is_given:
-            ColumnDtype.cast(default, dtype)  # raise if not in categories
+            cast_to_dtype(default, dtype)  # raise if not in categories
         return
 
     assert isinstance(dtype, ColumnDtype)
@@ -158,7 +160,7 @@
         min_val = bounds.get(">", bounds.get(">=", None))
         for val in [max_val, min_val]:
             if val is not None:
-                ColumnDtype.cast(val, dtype)
+                cast_to_dtype(val, dtype)
         if max_val is not None and min_val is not None and max_val <= min_val:
             raise ValueError(f'{prefix} min. bound must be lower than '
                              f'max. bound')
@@ -166,7 +168,7 @@
     # check default value:
     if default_is_given:
-        # this should already been done but for dafety:
-        ColumnDtype.cast(default, dtype)
+        # this should already have been done, but repeat for safety:
+        cast_to_dtype(default, dtype)
 
 
 def check_with_openquake(rupture_params: dict[str, set[str]],
@@ -224,12 +226,13 @@ def test_Column_dtype():
         'bool': [True, False],
         'str': [None, "x"]
     })
+    d.str = d.str.astype('string')
     for c in d.columns:
         dtyp = d[c].dtype
         assert ColumnDtype.of(dtyp) == ColumnDtype[c]
         assert ColumnDtype.of(d[c]) == ColumnDtype[c]
         assert ColumnDtype.of(d[c].values) == ColumnDtype[c]
-        assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c] if _ is not None)
-        assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c].values if _ is not None)
+        assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c] if pd.notna(_))
+        assert all(ColumnDtype.of(_) == ColumnDtype[c] for _ in d[c].values if pd.notna(_))
     assert ColumnDtype.of(None) is None
 
diff --git a/tests/smtk/flatfile/test_flatfile_io.py b/tests/smtk/flatfile/test_flatfile_io.py
index d41db74e..84501b1b 100644
--- a/tests/smtk/flatfile/test_flatfile_io.py
+++ b/tests/smtk/flatfile/test_flatfile_io.py
@@ -8,6 +8,7 @@
 import pytest
 from datetime import datetime
 import pandas as pd
+from pandas import StringDtype
 
 from egsim.smtk.flatfile import (read_flatfile, query, read_csv)
 from egsim.smtk.flatfile.columns import _extract_from_columns, load_from_yaml, \
@@ -48,7 +49,7 @@ def test_read_csv():
         'float': np.dtype('float64'),
         'bool': np.dtype('bool'),
         'datetime': np.dtype('
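
Reviewer note (not part of the diff): the snippet below is a minimal usage sketch of the refactored API, inferred from the new docstrings above. It assumes Python >= 3.11 (required for `enum.StrEnum`); the values are illustrative only.

    import pandas as pd
    from egsim.smtk.flatfile.columns import (ColumnDtype, to_pandas_dtype,
                                             cast_to_dtype)

    # ColumnDtype members are plain strings, so they work directly with `astype`:
    s = pd.Series(['1.5', '2']).astype(ColumnDtype.float)
    assert ColumnDtype.of(s) == ColumnDtype.float  # works on Series/array/dtype/scalar

    # to_pandas_dtype normalizes YAML-style dtype declarations: a dtype name, or
    # a list/tuple of categories (returned as a pandas CategoricalDtype):
    assert to_pandas_dtype('int') == ColumnDtype.int
    cat = to_pandas_dtype(['PGA', 'PGV', 'SA'])  # -> pd.CategoricalDtype

    # cast_to_dtype validates scalars (e.g. YAML defaults and bounds), raising
    # ValueError on mismatch:
    assert cast_to_dtype(4, ColumnDtype.float) == 4.0  # int -> float is allowed
    assert cast_to_dtype('PGA', cat) == 'PGA'          # value must be a category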