smtk residuals code refactor - improving read_flatfile: bool dtype
rizac committed Aug 28, 2023
1 parent 7500b2d commit 6136f86
Showing 2 changed files with 98 additions and 44 deletions.
92 changes: 48 additions & 44 deletions egsim/smtk/flatfile.py
@@ -270,80 +270,84 @@ def read_csv(filepath_or_buffer: str,
         if issubclass(col_dtype, ColumnDtype[expected_col_dtype_name].value):
             continue
 
+        old_series = dfr[col]
+        # if old_series contains value that can be converted to the proper dtype, put it
+        # in new_series (None means data type is not convertible, i.e. invalid column):
+        new_series = None
+
         # date-times: try to convert columns of type str.
         # Note: Remember that we do not use pandas parse_dates arg as
         # it raises if we put columns that eventually are not in the flatfile):
         if expected_col_dtype_name == ColumnDtype.datetime.name:
-            datetime_parsed = False
             if issubclass(col_dtype, ColumnDtype.str.value):
                 try:
-                    dfr[col] = pd.to_datetime(dfr[col])
-                    datetime_parsed = True
+                    new_series = pd.to_datetime(old_series)
                 except ValueError:
                     # be relaxed on missing values, if any, and retry:
-                    missing = dfr[col].isin(missing_values)
+                    missing = old_series.isin(missing_values)
                     if missing.any():
-                        values = dfr[col].copy()
-                        values[missing] = pd.NaT
+                        old_series = old_series.copy()
+                        old_series[missing] = pd.NaT
                         try:
-                            dfr[col] = pd.to_datetime(values)
-                            datetime_parsed = True
+                            new_series = pd.to_datetime(old_series)
                         except ValueError:
                             pass
-            if datetime_parsed:
-                continue
 
         # float: convert columns of type int:
-        if expected_col_dtype_name == ColumnDtype.float.name:  # "float"
+        elif expected_col_dtype_name == ColumnDtype.float.name:  # "float"
             if issubclass(col_dtype, ColumnDtype.int.value):
-                dfr[col] = dfr[col].astype(float)
-                continue
+                new_series = old_series.astype(float)
 
         # int: convert columns of type float as long as they are float with either int
         # or missing values (replace the latter with the column default, or 0), e.g.
-        # [NaN, 1] becomes [<default_or_zero>, 1], [NaN, 1.2] raises:
-        if expected_col_dtype_name == ColumnDtype.int.name:  # "int"
+        # [NaN, 1] becomes [<default_or_zero>, 1], [NaN, 1.2] is invalid:
+        elif expected_col_dtype_name == ColumnDtype.int.name:  # "int"
             if issubclass(col_dtype, ColumnDtype.float.value):
-                series = dfr[col].copy()
-                na_values = pd.isna(series)
-                series.loc[na_values] = defaults.pop(col, 0)
+                na_values = pd.isna(old_series)
+                old_series = old_series.copy()
+                old_series.loc[na_values] = defaults.pop(col, 0)
                 try:
-                    series = series.astype(int)
+                    old_series = old_series.astype(int)
                     # check all non missing elements are int:
-                    if (dfr[col][~na_values] == series[~na_values]).all():  # noqa
-                        dfr[col] = series
-                        continue
+                    if (dfr[col][~na_values] == old_series[~na_values]).all():  # noqa
+                        new_series = old_series
                 except Exception:  # noqa
                     pass
 
-        # bool: try to convert str/int/float columns:
-        if expected_col_dtype_name == ColumnDtype.bool.name:  # "bool"
-            new_values = None
-            if issubclass(col_dtype, ColumnDtype.int.value):
-                new_values = dfr[col]
-                if sorted(pd.unique(new_values)) != [0, 1]:
-                    new_values = None
-            elif issubclass(col_dtype, ColumnDtype.float.value):
-                na_values = pd.isna(dfr[col])
-                new_values = dfr[col]
-                new_values.loc[na_values] = defaults.pop(col, False)
-                if sorted(pd.unique(new_values)) != [0, 1]:
-                    new_values = None
+        # bool: try to convert str/int/float columns. Try to convert to numeric
+        # and then cast to bool if the values are either 0 or 1
+        elif expected_col_dtype_name == ColumnDtype.bool.name:  # "bool"
+            if issubclass(col_dtype, ColumnDtype.float.value):
+                # float: convert missing values to 0
+                old_series = old_series.copy()
+                na_values = pd.isna(old_series)
+                old_series.loc[na_values] = defaults.pop(col, 0)
             elif issubclass(col_dtype, ColumnDtype.str.value):
-                na_values = pd.isna(dfr[col])
-                new_values = dfr[col].astype(str).str.lower()
-                new_values.loc[na_values] = defaults.pop(col, False)
+                # str/object: convert missing values to False, true_values to True
+                # and false_values to False
+                na_values = pd.isna(old_series)
+                old_series = old_series.astype(str).str.lower()
+                old_series.loc[na_values] = defaults.pop(col, False)
                 true_values = ['true'] + list(kwargs.get('true_values', []))
-                new_values.loc[new_values.isin(true_values)] = True
+                old_series.loc[old_series.isin(true_values)] = True
                 false_values = ['false'] + list(kwargs.get('false_values', []))
-                new_values.loc[new_values.isin(false_values)] = False
-            if new_values is not None:
+                old_series.loc[old_series.isin(false_values)] = False
+            elif not issubclass(col_dtype, ColumnDtype.int.value):
+                old_series = None
+
+            if old_series is not None and \
+                    pd.unique(old_series).tolist() in ([0, 1], [1, 0]):
+                # note: the above holds also if old_series values are bool (True == 1,
+                # np.array(True) == 1, the same for False and 0)
                 try:
-                    dfr[col] = new_values.astype(bool)
-                    continue
+                    new_series = old_series.astype(bool)
                 except Exception:  # noqa
                     pass
-        invalid_columns.append(col)
+
+        if new_series is None:
+            invalid_columns.append(col)
+        else:
+            dfr[col] = new_series
 
     # set categorical columns:
     for col in set(categorical_dtypes) & dfr_columns:
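For reference, below is a minimal standalone sketch of the 0/1 rule this commit applies to bool columns: fill missing values with a default, map "true"/"false" strings to True/False, and accept the column as bool only if its unique values are exactly 0 and 1 (or their bool equivalents). The helper name to_bool_series and its default argument are illustrative assumptions, not part of the egsim API, and only the literal "true"/"false" strings are handled here, whereas read_csv also honours pandas' true_values/false_values arguments.

import pandas as pd

def to_bool_series(series: pd.Series, default=False) -> pd.Series:
    # hypothetical helper mirroring the commit's rule, not the egsim implementation
    s = series.copy()
    na = pd.isna(s)
    if s.dtype == object or pd.api.types.is_string_dtype(s):
        # strings: lowercase, fill missing values, then map "true"/"false"
        s = s.astype(str).str.lower()
        s[na] = default
        s[s == 'true'] = True
        s[s == 'false'] = False
    else:
        # numeric: only fill missing values
        s[na] = default
    if pd.unique(s).tolist() not in ([0, 1], [1, 0]):
        raise ValueError('column not interpretable as bool')
    return s.astype(bool)

print(to_bool_series(pd.Series(['True', 'false', None])).tolist())  # [True, False, False]

As in the hunk above, the check only passes when both values actually occur in the column.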
50 changes: 50 additions & 0 deletions tests/smtk/flatfile/test_flatfile_io.py
@@ -83,6 +83,56 @@ def test_read_csv():
     assert header in str(verr.value)
 
 
+def test_read_csv_bool():
+    args = {
+        'sep': ';',
+        'skip_blank_lines': False,
+        # the above arg is needed (default is True) because
+        # we have a single column
+        # and missing values are input as blank lines
+        'dtype': {"str": "str", "float": "float",
+                  "int": "int", "datetime": "datetime", "bool": "bool"},
+    }
+    expected = [True, True, True, False, False, False]
+    csv_str = "bool\n1\nTrue\ntrue\n0\nFalse\nfalse"
+    d = read_csv(StringIO(csv_str), **args)  # noqa
+    assert (d['bool'] == expected).all()
+
+    # Insert a missing value at the beginning (defaults to False).
+    # NOTE: appending a missing value (empty line) is skipped even if skip_blank_lines
+    # is False (it is probably interpreted as the ending newline of the previous csv row)
+    csv_str = csv_str.replace("bool\n", "bool\n\n")
+    d = read_csv(StringIO(csv_str), **args)  # noqa
+    assert (d['bool'] == [False] + expected).all()
+
+    # Append an invalid value (float not in [0, 1]):
+    with pytest.raises(ValueError):
+        d = read_csv(StringIO("bool\n1\nTrue\ntrue\nFalse\nfalse\n1.1"), **args)  # noqa
+
+    # Append an invalid value ("X"):
+    with pytest.raises(ValueError):
+        d = read_csv(StringIO("bool\n1\nTrue\ntrue\nFalse\nfalse\nX"), **args)  # noqa
+
+    # int series is ok
+    csv_str = "bool\n1\n1\n1\n0\n0\n0"
+    d = read_csv(StringIO(csv_str), **args)  # noqa
+    assert (d['bool'] == expected).all()
+    with pytest.raises(ValueError):
+        # int series must have only 0 and 1:
+        csv_str += "\n2"
+        d = read_csv(StringIO(csv_str), **args)  # noqa
+
+    # float series is ok
+    csv_str = "bool\n1.0\n1.0\n1.0\n0.0\n0.0\n0.0"
+    d = read_csv(StringIO(csv_str), **args)  # noqa
+    assert (d['bool'] == expected).all()
+    with pytest.raises(ValueError):
+        # float series must have only 0 and 1:
+        csv_str += "\n0.1"
+        d = read_csv(StringIO(csv_str), **args)  # noqa
+
+
+
 def test_read_csv_categorical():
     defaults = {
         "str": "a",
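For context, the skip_blank_lines behaviour the test relies on can be reproduced with plain pandas.read_csv (standard pandas, not the egsim wrapper): with skip_blank_lines=False, an empty line after the header is read as a missing-value row, which is why the test inserts the blank right after the header rather than appending it at the end.

from io import StringIO
import pandas as pd

# skip_blank_lines=False keeps the empty line as a row of missing values
df = pd.read_csv(StringIO("bool\n\n1\n0\n"), skip_blank_lines=False)
print(len(df))                     # 3
print(df['bool'].isna().tolist())  # [True, False, False]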
