Skip to content

Commit

Permalink
Nested np.where (#29)
Browse files Browse the repository at this point in the history
* doc cleaning
* nested np.where
* boost req numba to 0.54
  • Loading branch information
jpn-- authored Sep 16, 2022
1 parent 771f1b1 commit 20f27c6
Show file tree
Hide file tree
Showing 9 changed files with 8,365 additions and 8 deletions.
5 changes: 5 additions & 0 deletions docs/_static/sharrow-docs.css
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ img[src*="#floatleft"] {
float: left;
margin-right: 20px;
}

/* Preformatted text nested under cell_output > output.stream > highlight:
   render at 80% font size on a light goldenrod-yellow background. */
div.cell_output > div.output.stream > div.highlight > pre {
font-size: 80%;
background-color: lightgoldenrodyellow;
}
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ numpy >= 1.19
pandas >= 1.2
pyarrow >= 3.0.0
xarray >= 0.20.0
numba >= 0.53
numba >= 0.54
numexpr
filelock
sphinx-autosummary-accessors
9 changes: 5 additions & 4 deletions docs/walkthrough/one-dim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
},
"outputs": [],
"source": [
"# test households content\n",
"# TEST households content\n",
"assert len(households) == 5000\n",
"assert \"income\" in households \n",
"assert households.index.name == \"HHID\""
Expand Down Expand Up @@ -200,6 +200,7 @@
},
"outputs": [],
"source": [
"# TEST\n",
"assert tours.index.name == \"TOURIDX\"\n",
"assert 0 in tours.head().dest_taz_idx"
]
Expand Down Expand Up @@ -435,7 +436,7 @@
},
"outputs": [],
"source": [
"# test utility data\n",
"# TEST utility data\n",
"assert flow.compiled_recently == True\n",
"actual = flow.load()\n",
"expected = np.array([[ 9.4 , 16.9572 , 4.5 , 0. , 1. ],\n",
Expand Down Expand Up @@ -687,7 +688,7 @@
},
"outputs": [],
"source": [
"# test utility\n",
"# TEST utility\n",
"np.testing.assert_array_almost_equal(u, np.dot(x, b))"
]
},
Expand Down Expand Up @@ -866,7 +867,7 @@
},
"outputs": [],
"source": [
"# test mnl choices\n",
"# TEST mnl choices\n",
"uz = np.exp(flow.dot(b))\n",
"uz = uz / uz.sum(1)[:,None]\n",
"np.testing.assert_array_almost_equal(\n",
Expand Down
2 changes: 1 addition & 1 deletion envs/development.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:
- nbmake
- networkx
- notebook
- numba>=0.53
- numba>=0.54
- numexpr
- numpy>=1.19
- openmatrix
Expand Down
2 changes: 1 addition & 1 deletion envs/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- xarray
- dask
- networkx
- numba>=0.53
- numba>=0.54
- numexpr
- sparse
- filelock
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ install_requires =
pandas >= 1.2
pyarrow >= 3.0.0
xarray >= 0.20.0
numba >= 0.51.2
numba >= 0.54
sparse
numexpr
filelock
Expand Down
47 changes: 47 additions & 0 deletions sharrow/maths.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numba as nb
import numpy as np
from numba.extending import overload


@nb.njit(cache=True)
Expand Down Expand Up @@ -127,3 +128,49 @@ def digital_decode(encoded_value, scale, offset, missing_value):
return missing_value
else:
return (encoded_value * scale) + offset


@overload(np.where)
def _zero_len_array_where(condition, x=None, y=None):
    """Numba overload of ``np.where`` for a scalar boolean condition.

    Handles the argument-type combinations in which at least one of ``x``
    and ``y`` is typed as a zero-dimensional integer/float array and the
    other is either the same kind of array or a plain integer/float scalar.
    Zero-dimensional arrays are unwrapped with ``.item()`` so that the
    selected branch is returned as a scalar.

    For every other combination of argument types no implementation is
    supplied (``None`` is returned from this typing function).
    """
    # This overload only applies when the condition is a scalar boolean.
    if not isinstance(condition, nb.types.Boolean):
        return None

    numeric = (nb.types.Integer, nb.types.Float)

    def is_zero_dim(typ):
        # True for a 0-d array type whose dtype is an integer or a float.
        return (
            isinstance(typ, nb.types.Array)
            and typ.ndim == 0
            and isinstance(typ.dtype, numeric)
        )

    if is_zero_dim(x):
        if isinstance(y, numeric):
            # x is a 0-d array, y is already a scalar.
            def where_array_scalar(condition, x=None, y=None):
                if condition:
                    return x.item()
                return y

            return where_array_scalar

        if is_zero_dim(y):
            # Both branches are 0-d arrays.
            def where_array_array(condition, x=None, y=None):
                if condition:
                    return x.item()
                return y.item()

            return where_array_array

    elif isinstance(x, numeric) and is_zero_dim(y):
        # x is already a scalar, y is a 0-d array.
        def where_scalar_array(condition, x=None, y=None):
            if condition:
                return x
            return y.item()

        return where_scalar_array

    return None
91 changes: 91 additions & 0 deletions sharrow/tests/test_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,3 +682,94 @@ def test_isin_and_between(dataframe_regression):
check_names=False,
)
dataframe_regression.check(result)


def test_nested_where(dataframe_regression):
    """Nested ``np.where`` expressions in a flow must agree with NumPy.

    Six expressions that nest one ``np.where`` inside another (in either the
    "true" or the "false" position, mixed with column values and literal
    zeros) are set up as a flow on the persons table.  Each resulting column
    is compared against the same expression evaluated directly with NumPy
    and cast to ``float32``; the full result is then passed to the
    regression fixture.
    """
    persons = example_data.get_data()["persons"]

    tree = DataTree(
        base=persons,
        extra_vars={"pt1": 1, "pt5": 5, "pt34": [3, 4]},
    )

    definitions = {
        "pt": "base.ptype",
        "pt_shifted_1": "np.where(base.ptype<3, np.where(base.ptype<2, base.ptype*100, 0), base.ptype)",
        "pt_shifted_2": "np.where(base.ptype<3, np.where(base.ptype<2, base.ptype*100, 0), 0)",
        "pt_shifted_3": "np.where(base.ptype<3, 0, np.where(base.ptype>4, base.ptype*100, 0))",
        "pt_shifted_4": "np.where(base.ptype<3, base.ptype, np.where(base.ptype>4, base.ptype*100, 0))",
        "pt_shifted_5": "np.where(base.ptype<3, base.ptype, np.where(base.ptype>4, base.ptype*100, base.ptype))",
        "pt_shifted_6": "np.where(base.ptype<3, 0, np.where(base.ptype>4, base.ptype*100, base.ptype))",
    }
    flow = tree.setup_flow(definitions)
    result = flow.load_dataframe(tree)

    ptype = persons.ptype
    below_3 = ptype < 3
    # Inner np.where results shared by several of the expected columns.
    low_else_zero = np.where(ptype < 2, ptype * 100, 0)
    high_else_zero = np.where(ptype > 4, ptype * 100, 0)
    high_else_same = np.where(ptype > 4, ptype * 100, ptype)

    # Reference values, listed in the same order as the flow definitions.
    reference = [
        ("pt_shifted_1", np.where(below_3, low_else_zero, ptype)),
        ("pt_shifted_2", np.where(below_3, low_else_zero, 0)),
        ("pt_shifted_3", np.where(below_3, 0, high_else_zero)),
        ("pt_shifted_4", np.where(below_3, ptype, high_else_zero)),
        ("pt_shifted_5", np.where(below_3, ptype, high_else_same)),
        ("pt_shifted_6", np.where(below_3, 0, high_else_same)),
    ]
    for column, values in reference:
        pd.testing.assert_series_equal(
            pd.Series(values.astype(np.float32)),
            result[column],
            check_names=False,
        )

    dataframe_regression.check(result)
Loading

0 comments on commit 20f27c6

Please sign in to comment.