Skip to content

Commit

Permalink
fix residuals tests and add invalid columns exception tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rizac committed Sep 18, 2023
1 parent 30c3b81 commit 03586cc
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 144 deletions.
107 changes: 44 additions & 63 deletions egsim/smtk/flatfile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from openquake.hazardlib.contexts import RuptureContext

from .columns import (ColumnDtype, get_rupture_param_columns,
get_dtypes_and_defaults, get_column_names,
get_intensity_measure_columns)
get_dtypes_and_defaults, get_all_names_of,
get_intensity_measure_columns, MissingColumn,
InvalidDataInColumn, InvalidColumnName, ConflictingColumns)
from .. import get_SA_period
from ...smtk.trellis.configure import vs30_to_z1pt0_cy14, vs30_to_z2pt5_cb14

Expand Down Expand Up @@ -246,7 +247,7 @@ def read_csv(filepath_or_buffer: Union[str, IO],
invalid_columns.append(col)

if invalid_columns:
raise ValueError(f'Invalid values in column(s): {", ".join(invalid_columns)}')
raise InvalidDataInColumn(*invalid_columns)

# set defaults:
invalid_defaults = []
Expand All @@ -260,6 +261,8 @@ def read_csv(filepath_or_buffer: Union[str, IO],
pass
invalid_defaults.append(col)

if not isinstance(dfr, pd.RangeIndex):
dfr.reset_index(drop=True, inplace=True)
return dfr


Expand Down Expand Up @@ -362,7 +365,7 @@ def get_column_name(flatfile:pd.DataFrame, column:str) -> Union[str, None]:
Returns None if no column is found, raise `ConflictingColumns` if more than
a matching column is found"""
ff_cols = set(flatfile.columns)
cols = get_column_names(column) & ff_cols
cols = get_all_names_of(column) & ff_cols
if len(cols) > 1:
raise ConflictingColumns(*cols)
elif len(cols) == 0:
Expand Down Expand Up @@ -396,30 +399,44 @@ def get_station_id_column_names(flatfile: pd.DataFrame) -> list[str, ...]:


def prepare_for_residuals(flatfile: pd.DataFrame,
gsims: Iterable[GMPE], imts: Iterable[str]) -> pd.Dataframe:
gsims: Iterable[GMPE], imts: Iterable[str]) -> pd.DataFrame:
"""Return a new dataframe with all columns required to compute residuals
from the given models (`gsim`) and intensity measures (`imts`) given with
periods, when needed (e.g. "SA(0.2)")
"""
new_flatfile = pd.DataFrame(index=flatfile.index)
new_dataframes = []
# prepare the flatfile for the required ground motion properties:
props_flatfile = pd.DataFrame(index=flatfile.index)
for prop in get_required_ground_motion_properties(gsims):
new_flatfile[prop] = \
props_flatfile[prop] = \
get_ground_motion_property_values(flatfile, prop)
if not props_flatfile.empty:
new_dataframes.append(props_flatfile)
# validate imts:
imts = set(imts)
non_sa = {_ for _ in imts if not get_SA_period(_) is None}
non_sa_imts = {_ for _ in imts if get_SA_period(_) is None}
# get supported imts but does not allow 'SA' alone to be valid:
supported_imts = get_intensity_measure_columns() - {'SA'}
if non_sa - supported_imts:
raise InvalidColumn(*{non_sa - supported_imts})
if non_sa_imts:
supported_imts = get_intensity_measure_columns() - {'SA'}
if non_sa_imts - supported_imts:
raise InvalidColumnName(*list(non_sa_imts - supported_imts))
# raise if some imts are not in the flatfile:
if non_sa_imts - set(flatfile.columns):
raise MissingColumn(*list(non_sa_imts - set(flatfile.columns)))
# add non SA imts:
new_dataframes.append(flatfile[sorted(non_sa_imts)])
# prepare the flatfile for SA (create new columns by interpolation if necessary):
sa = imts - non_sa
sa_dataframe = _prepare_for_sa(flatfile, sa)
if not sa_dataframe.empty:
new_flatfile[list(sa_dataframe.columns)] = sa_dataframe
sa_imts = imts - non_sa_imts
if sa_imts:
sa_dataframe = _prepare_for_sa(flatfile, sa_imts)
if not sa_dataframe.empty:
new_dataframes.append(sa_dataframe)

if not new_dataframes:
return pd.DataFrame(columns=flatfile.columns) # empty dataframe

return pd.concat(new_dataframes, axis=1)

return new_flatfile


def get_required_ground_motion_properties(gsims: Iterable[GMPE]) -> set[str]:
Expand Down Expand Up @@ -528,10 +545,12 @@ def fill_na(flatfile:pd.DataFrame,


def _prepare_for_sa(flatfile: pd.DataFrame, sa_imts: Iterable[str]) -> pd.DataFrame:
"""Modify inplace the flatfile assuring the SA columns in `sa_imts` (e.g. "SA(0.2)")
are present. The SA column of the flatfile will be used to obtain
the target SA via interpolation, and removed if not necessary.
"""Return a new Dataframe with the SA columns defined in `sa_imts`
The returned DataFrame will have all strings supplied in `sa_imts` as columns,
with relative values copied (or inferred via interpolation) from the given flatfile
:param flatfile: the flatfile
:param sa_imts: Iterable of strings denoting SA (e.g. "SA(0.2)")
Return the newly created Sa columns, as tuple of strings
"""
src_sa = []
Expand All @@ -552,16 +571,18 @@ def _prepare_for_sa(flatfile: pd.DataFrame, sa_imts: Iterable[str]) -> pd.DataFr
if p not in source_sa:
tgt_sa.append((p, i))
if invalid_sa:
raise InvalidColumn(*invalid_sa)
raise InvalidDataInColumn(*invalid_sa)

# source_sa: period [float] -> mapped to the relative column:
target_sa: dict[float, str] = {p: c for p, c in sorted(tgt_sa, key=lambda t: t[0])}

if not source_sa or not target_sa:
return pd.DataFrame(index=flatfile.index, data=[])
source_sa_flatfile = flatfile[list(source_sa.values())]

if not target_sa:
return source_sa_flatfile

# Take the log10 of all SA:
source_spectrum = np.log10(flatfile[list(source_sa.values())])
source_spectrum = np.log10(source_sa_flatfile)
# we need to interpolate row wise
# build the interpolation function:
interp = interp1d(list(source_sa), source_spectrum, axis=1)
Expand Down Expand Up @@ -618,46 +639,6 @@ def __getattr__(self, column_name):
return values


class InvalidColumn(Exception):
"""
General flatfile column(s) error. See subclasses for details
"""
def __init__(self, *names, sep=', '):
super().__init__(*names)
self._sep = sep

def __str__(self):
"""Make str(self) more clear"""
prefix = self.__class__.__name__
# replace upper cases with space + lower case letter
prefix = re.sub("([A-Z])", " \\1", prefix).strip().capitalize()
names = self.args
suffix = self._sep.join(repr(_) for _ in names)
return f"{prefix}{'s' if len(names) > 1 else ''} {suffix}"

def __repr__(self):
return self.__str__()


class MissingColumn(InvalidColumn, AttributeError, KeyError):
"""MissingColumnError. It inherits also from AttributeError and
KeyError to be compliant with pandas and OpenQuake"""

def __init__(self, name):
sorted_names = get_column_names(name, sort=True)
suffix_str = repr(sorted_names[0] or name)
if len(sorted_names) > 1:
suffix_str += " (alias" if len(sorted_names) == 2 else " (aliases"
suffix_str += f": {', '.join(repr(_) for _ in sorted_names[1:])})"
super().__init__(suffix_str)


class ConflictingColumns(InvalidColumn):

def __init__(self, *names):
InvalidColumn.__init__(self, *names, sep=" vs. ")


# FIXME REMOVE LEGACY STUFF CHECK WITH GW:

# FIXME: remove columns checks will be done when reading the flatfile and
Expand Down
98 changes: 84 additions & 14 deletions egsim/smtk/flatfile/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@
module containing all column metadata information stored in the associated
YAML file
"""
import re
from datetime import datetime, date
from enum import Enum
from os.path import join, dirname

from typing import Union, Any, Iterable
from typing import Union, Any

# try to speed up yaml.safe_load (https://pyyaml.org/wiki/PyYAMLDocumentation):
from yaml import load as yaml_load

try:
from yaml import CSafeLoader as default_yaml_loader # faster, if available
from yaml import CSafeLoader as SafeLoader # faster, if available
except ImportError:
from yaml import SafeLoader as default_yaml_loader # same as using yaml.safe_load
from yaml import SafeLoader # same as using yaml.safe_load

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -70,21 +71,15 @@ def get_intensity_measure_columns() -> set[str]:
_alias: dict[str, set[str]] = None # noqa


def get_column_names(column, sort=False) -> Union[set[str], list[str]]:
"""Return all possible names of the given column, as set of strings. If sort is
True, a list is returned where the first element is the primary column name (one of
the top-level keys defined in the YAML dict)
def get_all_names_of(column) -> set[str]:
    """Return all possible names of the given column, as set of strings. The set
    will be empty if `column` does not denote a flatfile column

    :param column: a string denoting a flatfile column name or one of its aliases
    """
    global _alias
    # lazily build the name -> {all names} registry from the YAML metadata
    # on first access (cached in the module-level `_alias` afterwards):
    if _alias is None:
        _alias = {}
        _extract_from_columns(load_from_yaml(), alias=_alias)
    return _alias.get(column, set())


def get_dtypes_and_defaults() -> \
Expand All @@ -98,6 +93,81 @@ def get_dtypes_and_defaults() -> \
return _dtype, _default


class InvalidColumn(Exception):
    """
    General flatfile column(s) error. See subclasses for details
    """

    def __init__(self, *names, sep=', ', plural_suffix='s'):
        super().__init__(*names)
        # separator used to join the column names in str(self):
        self._sep = sep
        # suffix appended to the message prefix when not exactly one name:
        self._plural_suffix = plural_suffix

    @property
    def names(self):
        """return the names (usually column names) raising this Exception
        and passed in `__init__`"""
        return [repr(arg) for arg in self.args]

    def __str__(self):
        """Make str(self) more clear"""
        # Derive a readable prefix from the class name by splitting on
        # upper-case letters, e.g. "InvalidColumn" -> "Invalid column":
        spaced = re.sub("([A-Z])", " \\1", type(self).__name__)
        prefix = spaced.strip().capitalize()
        shown_names = self.names
        # pluralize unless exactly one name was given:
        suffix = '' if len(shown_names) == 1 else self._plural_suffix
        return f"{prefix}{suffix} {self._sep.join(shown_names)}"

    def __repr__(self):
        return str(self)


class MissingColumn(InvalidColumn, AttributeError, KeyError):
    """MissingColumnError. It inherits also from AttributeError and
    KeyError to be compliant with pandas and OpenQuake"""

    @property
    def names(self):
        """return the names with their alias(es), if any"""
        result = []
        for arg in self.args:
            all_names = self.get_all_names_of(arg)
            label = repr(all_names[0])
            aliases = all_names[1:]
            if aliases:
                label += f" (or {', '.join(repr(a) for a in aliases)})"
            result.append(label)
        return result

    @classmethod
    def get_all_names_of(cls, col_name) -> list[str]:
        """Return a list of all column names of the argument, with the first element
        being the flatfile primary name. Returns `[col_name]` if the argument does not
        denote any flatfile column"""
        names = get_all_names_of(col_name)
        if len(names) <= 1:
            return [col_name]
        # primary names (top-level keys of the YAML `_columns` dict) first,
        # then the remaining aliases:
        primary = [n for n in names if n in _columns]
        aliases = [n for n in names if n not in _columns]
        return primary + aliases


class ConflictingColumns(InvalidColumn):
    """Error denoting two or more given column names that conflict with each
    other (at least two names are required in `__init__`)"""

    def __init__(self, name1, name2, *other_names):
        # join names with " vs. " and never pluralize the message prefix:
        super().__init__(name1, name2, *other_names,
                         sep=" vs. ", plural_suffix='')


class InvalidDataInColumn(InvalidColumn, ValueError, TypeError):
    """Error denoting flatfile column(s) with invalid data. Inherits also from
    ValueError and TypeError so that callers catching those built-ins still
    intercept it"""
    pass


class InvalidColumnName(InvalidColumn):
    """Error denoting string(s) that do not name any known flatfile column"""
    pass


# YAML file path:
_ff_metadata_path = join(dirname(__file__), 'columns.yaml')
# cache storage of the data in the YAML:
Expand All @@ -116,7 +186,7 @@ def load_from_yaml(cache=True) -> dict[str, dict[str, Any]]:
if cache and _columns:
return _columns
with open(_ff_metadata_path) as fpt:
_cols = yaml_load(fpt, default_yaml_loader)
_cols = yaml_load(fpt, SafeLoader)
if cache:
_columns = _cols
return _cols
Expand Down
Loading

0 comments on commit 03586cc

Please sign in to comment.