diff --git a/sharrow/aster.py b/sharrow/aster.py index 806553a..f0e655b 100755 --- a/sharrow/aster.py +++ b/sharrow/aster.py @@ -886,6 +886,20 @@ def visit_Call(self, node): ) result = ast.BinOp(left=left, op=ast.BitAnd(), right=right) + # change XXX.isna() [with no arguments] to np.isnan(x) + if ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "isna" + and len(node.args) == 0 + and len(node.keywords) == 0 + ): + apply_args = [self.visit(node.func.value)] # convert XXX into argument + result = ast.Call( + func=ast.Name("isnan_fast_safe"), + args=apply_args, + keywords=[], + ) + # if no other changes if result is None: args = [self.visit(i) for i in node.args] diff --git a/sharrow/flows.py b/sharrow/flows.py index 3a23437..35c2881 100644 --- a/sharrow/flows.py +++ b/sharrow/flows.py @@ -1073,7 +1073,7 @@ def __initialize_2( "from contextlib import suppress", "from numpy import log, exp, log1p, expm1", "from sharrow.maths import piece, hard_sigmoid, transpose_leading, clip, digital_decode", - "from sharrow.sparse import get_blended_2", + "from sharrow.sparse import get_blended_2, isnan_fast_safe", } func_code = self._func_code @@ -1342,7 +1342,16 @@ def load_raw(self, rg, args, runner=None, dtype=None, dot=None): if arg.startswith("__aux_var"): arguments.append(arg_value) else: - arguments.append(np.asarray(arg_value)) + arg_value_array = np.asarray(arg_value) + if arg_value_array.dtype.kind == "O": + # convert object arrays to unicode str + # and replace missing values with NAK='\u0015' + # that can be found by `isnan_fast_safe` + # This is done for compatability and likely ruins performance + arg_value_array_ = arg_value_array.astype("unicode") + arg_value_array_[pd.isnull(arg_value_array)] = "\u0015" + arg_value_array = arg_value_array_ + arguments.append(arg_value_array) kwargs = {} if dtype is not None: kwargs["dtype"] = dtype @@ -1421,8 +1430,15 @@ def iload_raw( arguments.append(argument) else: if argument.dtype.kind == "O": - argument = argument.astype("unicode") - arguments.append(np.asarray(argument)) + # convert object arrays to unicode str + # and replace missing values with NAK='\u0015' + # that can be found by `isnan_fast_safe` + # This is done for compatability and likely ruins performance + argument_ = argument.astype("unicode") + argument_[pd.isnull(argument)] = "\u0015" + arguments.append(np.asarray(argument_)) + else: + arguments.append(np.asarray(argument)) kwargs = {} if dtype is not None: kwargs["dtype"] = dtype diff --git a/sharrow/sparse.py b/sharrow/sparse.py index 444e526..2adc8f0 100644 --- a/sharrow/sparse.py +++ b/sharrow/sparse.py @@ -197,12 +197,21 @@ def blenders(self): return b -@nb.njit +@nb.generated_jit(nopython=True) def isnan_fast_safe(x): - if int(x) == -9223372036854775808: - return True + if isinstance(x, nb.types.Float): + + def func(x): + if int(x) == -9223372036854775808: + return True + else: + return False + + return func + elif isinstance(x, (nb.types.UnicodeType, nb.types.UnicodeCharSeq)): + return lambda x: x == "\u0015" else: - return False + return lambda x: False @nb.njit diff --git a/sharrow/tests/test_relationships.py b/sharrow/tests/test_relationships.py index 2a771fa..216fc7b 100644 --- a/sharrow/tests/test_relationships.py +++ b/sharrow/tests/test_relationships.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from numpy.random import SeedSequence, default_rng -from pytest import mark, raises +from pytest import approx, mark, raises import sharrow from sharrow import Dataset, DataTree, example_data @@ -773,3 +773,34 @@ def test_nested_where(dataframe_regression): check_names=False, ) dataframe_regression.check(result) + + +def test_isna(): + data = example_data.get_data() + data["hhs"].loc[data["hhs"].income > 200000, "income"] = np.nan + tree = DataTree( + base=data["hhs"], + ) + ss = tree.setup_flow( + { + "missing_income": "((income < 0) | income.isna())", + "income_is_na": "income.isna()", + } + ) + result = ss.load() + assert result[0, 0] == 1 + assert result[0, 1] == 1 + assert result[:, 0].sum() == 188 + assert result[:, 1].sum() == 188 + + qf = pd.DataFrame({"MixedVals": ["a", "", None, np.nan]}) + tree2 = DataTree( + base=qf, + ) + qf = tree2.setup_flow( + { + "MixedVals_is_na": "MixedVals.isna()", + } + ) + result = qf.load() + assert result == approx(np.asarray([[0, 0, 1, 1]]).T)