Skip to content

Commit

Permalink
Support for .isna() (#30)
Browse files Browse the repository at this point in the history
* xxx.isna()

* isna for unicode/obj

* fix test for py3.7
  • Loading branch information
jpn-- authored Sep 20, 2022
1 parent 20f27c6 commit 22e5a08
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 9 deletions.
14 changes: 14 additions & 0 deletions sharrow/aster.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,20 @@ def visit_Call(self, node):
)
result = ast.BinOp(left=left, op=ast.BitAnd(), right=right)

# change XXX.isna() [with no arguments] to np.isnan(x)
if (
isinstance(node.func, ast.Attribute)
and node.func.attr == "isna"
and len(node.args) == 0
and len(node.keywords) == 0
):
apply_args = [self.visit(node.func.value)] # convert XXX into argument
result = ast.Call(
func=ast.Name("isnan_fast_safe"),
args=apply_args,
keywords=[],
)

# if no other changes
if result is None:
args = [self.visit(i) for i in node.args]
Expand Down
24 changes: 20 additions & 4 deletions sharrow/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ def __initialize_2(
"from contextlib import suppress",
"from numpy import log, exp, log1p, expm1",
"from sharrow.maths import piece, hard_sigmoid, transpose_leading, clip, digital_decode",
"from sharrow.sparse import get_blended_2",
"from sharrow.sparse import get_blended_2, isnan_fast_safe",
}

func_code = self._func_code
Expand Down Expand Up @@ -1342,7 +1342,16 @@ def load_raw(self, rg, args, runner=None, dtype=None, dot=None):
if arg.startswith("__aux_var"):
arguments.append(arg_value)
else:
arguments.append(np.asarray(arg_value))
arg_value_array = np.asarray(arg_value)
if arg_value_array.dtype.kind == "O":
# convert object arrays to unicode str
# and replace missing values with NAK='\u0015'
# that can be found by `isnan_fast_safe`
# This is done for compatability and likely ruins performance
arg_value_array_ = arg_value_array.astype("unicode")
arg_value_array_[pd.isnull(arg_value_array)] = "\u0015"
arg_value_array = arg_value_array_
arguments.append(arg_value_array)
kwargs = {}
if dtype is not None:
kwargs["dtype"] = dtype
Expand Down Expand Up @@ -1421,8 +1430,15 @@ def iload_raw(
arguments.append(argument)
else:
if argument.dtype.kind == "O":
argument = argument.astype("unicode")
arguments.append(np.asarray(argument))
# convert object arrays to unicode str
# and replace missing values with NAK='\u0015'
# that can be found by `isnan_fast_safe`
# This is done for compatability and likely ruins performance
argument_ = argument.astype("unicode")
argument_[pd.isnull(argument)] = "\u0015"
arguments.append(np.asarray(argument_))
else:
arguments.append(np.asarray(argument))
kwargs = {}
if dtype is not None:
kwargs["dtype"] = dtype
Expand Down
17 changes: 13 additions & 4 deletions sharrow/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,21 @@ def blenders(self):
return b


@nb.njit
@nb.generated_jit(nopython=True)
def isnan_fast_safe(x):
if int(x) == -9223372036854775808:
return True
if isinstance(x, nb.types.Float):

def func(x):
if int(x) == -9223372036854775808:
return True
else:
return False

return func
elif isinstance(x, (nb.types.UnicodeType, nb.types.UnicodeCharSeq)):
return lambda x: x == "\u0015"
else:
return False
return lambda x: False


@nb.njit
Expand Down
33 changes: 32 additions & 1 deletion sharrow/tests/test_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
from numpy.random import SeedSequence, default_rng
from pytest import mark, raises
from pytest import approx, mark, raises

import sharrow
from sharrow import Dataset, DataTree, example_data
Expand Down Expand Up @@ -773,3 +773,34 @@ def test_nested_where(dataframe_regression):
check_names=False,
)
dataframe_regression.check(result)


def test_isna():
data = example_data.get_data()
data["hhs"].loc[data["hhs"].income > 200000, "income"] = np.nan
tree = DataTree(
base=data["hhs"],
)
ss = tree.setup_flow(
{
"missing_income": "((income < 0) | income.isna())",
"income_is_na": "income.isna()",
}
)
result = ss.load()
assert result[0, 0] == 1
assert result[0, 1] == 1
assert result[:, 0].sum() == 188
assert result[:, 1].sum() == 188

qf = pd.DataFrame({"MixedVals": ["a", "", None, np.nan]})
tree2 = DataTree(
base=qf,
)
qf = tree2.setup_flow(
{
"MixedVals_is_na": "MixedVals.isna()",
}
)
result = qf.load()
assert result == approx(np.asarray([[0, 0, 1, 1]]).T)

0 comments on commit 22e5a08

Please sign in to comment.