Aux Vars (#28)
* fallback to dynamic version

* version infer

* fix dataset from_omx for arbitrary zone ids

* aux_vars

* add tests for aux_vars

* nb.typed.Dict in aux

* pypi only from root repo

* hash on used extra funcs

* import extra_funcs not pickle when possible
jpn-- authored Sep 13, 2022
1 parent 7c500f5 commit 771f1b1
Showing 5 changed files with 198 additions and 29 deletions.
1 change: 1 addition & 0 deletions .github/workflows/run-tests.yml
@@ -111,6 +111,7 @@ jobs:
# now send to PyPI
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
if: github.repository_owner == 'ActivitySim'
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
93 changes: 76 additions & 17 deletions docs/walkthrough/one-dim.ipynb
@@ -17,6 +17,7 @@
"metadata": {},
"outputs": [],
"source": [
"import numba as nb\n",
"import numpy as np\n",
"import pandas as pd\n",
"import xarray as xr\n",
@@ -26,17 +27,6 @@
"sh.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa0564dd",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.version_info < (3,8)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -258,8 +248,8 @@
"Drive Time,odt_skims['SOV_TIME'] + dot_skims['SOV_TIME'],-0.0134,,\n",
"Transit IVT,(odt_skims['WLK_LOC_WLK_TOTIVT']/100 + dot_skims['WLK_LOC_WLK_TOTIVT']/100),,,-0.0134\n",
"Transit Wait Time,short_i_wait_mult * ((odt_skims['WLK_LOC_WLK_IWAIT']/100).clip(upper=shortwait) + (dot_skims['WLK_LOC_WLK_IWAIT']/100).clip(upper=shortwait)),,,-0.0134\n",
"Income,hh.income > 60000,,-0.2,\n",
"Constant,1,,-0.4,-0.55\n",
"Income,hh.income > income_breakpoints[2],,-0.2,\n",
"Constant,one,,-0.4,-0.55\n",
"\"\"\""
]
},
@@ -293,6 +283,7 @@
},
"outputs": [],
"source": [
"# TEST check spec\n",
"assert spec.index.name == \"Label\"\n",
"assert all(spec.columns == ['Expression', 'DRIVE', 'WALK', 'TRANSIT'])"
]
@@ -317,6 +308,11 @@
"metadata": {},
"outputs": [],
"source": [
"income_breakpoints = nb.typed.Dict.empty(nb.types.int32,nb.types.int32)\n",
"income_breakpoints[0] = 15000\n",
"income_breakpoints[1] = 30000\n",
"income_breakpoints[2] = 60000\n",
"\n",
"tree = sh.DataTree(\n",
" tour=tours,\n",
" person=persons,\n",
@@ -334,8 +330,12 @@
" \"tour.in_time_period @ dot_skims.time_period\",\n",
" ),\n",
" extra_vars={\n",
" 'short_i_wait_mult': 0.75,\n",
" 'shortwait': 3,\n",
" 'one': 1,\n",
" },\n",
" aux_vars={\n",
" 'short_i_wait_mult': 0.75,\n",
" 'income_breakpoints': income_breakpoints,\n",
" },\n",
")"
]
@@ -365,9 +365,15 @@
"this manner, as the `dest_taz_idx` variable in the `tours` dataset contains positional references\n",
"instead of labels.\n",
"\n",
"Lastly, out tree definition includes a few named constants, that are just fixed values defined\n",
"in a separate dictionary. These values get hard-coded into the compiled results, effectively the \n",
"same as if their values were expanded and written into exprssions in the `spec` directly.\n",
"Lastly, our tree definition includes a few named constants, that are just fixed values defined\n",
"in a separate dictionary. These are shown in two groups, `extra_vars` and `aux_vars`. The values \n",
"in `extra_vars` get hard-coded into the compiled results, effectively the \n",
"same as if their values were expanded and written into exprssions in the `spec` directly. This is\n",
"generally most efficient if the values will never change. On the other hand, `aux_vars` will be \n",
"passed by reference into the compiled results. These values need to be numba-safe objects, so\n",
"for instance a regular Python dictionary can't be used, but a numba typed Dict is acceptable.\n",
"So long as the data type and dimensionality of the values in `aux_vars` remains constant, the \n",
"actual values can be changed later (i.e. after compilation).\n",
"\n",
"Once we have defined our data tree, we can use it along with the `spec`, to compute the utility\n",
"for various alternatives in the choice model. Sharrow allows us to compile this utility function\n",
@@ -395,6 +401,19 @@
"and that compiled code is cached to disk."
]
},
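For clarity, here is a minimal sketch (with hypothetical values) of what the pass-by-reference behavior of `aux_vars` allows: mutate the typed Dict in place after the flow is compiled, and the next load sees the new value without a recompile.

```python
import numba as nb

cutoffs = nb.typed.Dict.empty(nb.types.int32, nb.types.int32)
cutoffs[2] = 60000  # breakpoint in effect when the flow is first compiled

# ... build the DataTree with aux_vars={'income_breakpoints': cutoffs},
# then compile and load the flow once ...

cutoffs[2] = 75000  # same key and value types, so no recompilation is needed
# a subsequent flow.load(tree) evaluates `hh.income > income_breakpoints[2]`
# against the updated 75000 threshold
```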
{
"cell_type": "code",
"execution_count": null,
"id": "5607980e",
"metadata": {},
"outputs": [],
"source": [
"# TEST\n",
"from pytest import approx\n",
"assert flow.tree.aux_vars['short_i_wait_mult'] == 0.75\n",
"assert flow.tree.aux_vars['income_breakpoints'][2] == 60000"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -514,6 +533,19 @@
"tree_2 = tree.replace_datasets(tour=tours_2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "509aa5b1",
"metadata": {},
"outputs": [],
"source": [
"# TEST\n",
"from pytest import approx\n",
"assert tree_2.aux_vars['short_i_wait_mult'] == 0.75\n",
"assert tree_2.aux_vars['income_breakpoints'][2] == approx(60000)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -524,6 +556,33 @@
"%time flow.load(tree_2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa37a5a1",
"metadata": {},
"outputs": [],
"source": [
"# TEST that aux_vars also work with arrays\n",
"tree_a = tree_2.replace_datasets(tour=tours)\n",
"tree_a.aux_vars['income_breakpoints'] = np.asarray([1,2,60000])\n",
"actual = flow.load(tree_a)\n",
"expected = np.array([[ 9.4 , 16.9572 , 4.5 , 0. , 1. ],\n",
" [ 9.32 , 14.3628 , 4.5 , 1. , 1. ],\n",
" [ 7.62 , 11.0129 , 4.5 , 1. , 1. ],\n",
" [ 4.25 , 7.6692 , 2.50065 , 0. , 1. ],\n",
" [ 6.16 , 8.2186 , 3.387825, 0. , 1. ],\n",
" [ 4.86 , 4.9288 , 4.5 , 0. , 1. ],\n",
" [ 1.07 , 0. , 0. , 0. , 1. ],\n",
" [ 8.52 , 11.615499, 3.260325, 0. , 1. ],\n",
" [ 11.74 , 16.2798 , 3.440325, 0. , 1. ],\n",
" [ 10.48 , 13.3974 , 3.942825, 0. , 1. ]], dtype=np.float32)\n",
"\n",
"np.testing.assert_array_almost_equal(actual[:5], expected[:5])\n",
"np.testing.assert_array_almost_equal(actual[-5:], expected[-5:])\n",
"assert actual.shape == (len(tours), len(spec))"
]
},
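As a usage note, the test above shows that an aux_var is not locked to one container type: a numba-safe ndarray can stand in for the typed Dict, since integer subscripting works on both. A hedged sketch of that swap, reusing the walkthrough's names:

```python
import numpy as np

# positions 0-2 of the array play the role of the typed Dict's integer keys
tree_a.aux_vars['income_breakpoints'] = np.asarray([15000, 30000, 60000])
result = flow.load(tree_a)  # numba can specialize for the new argument type
```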
{
"cell_type": "markdown",
"id": "a78a1be9",
61 changes: 56 additions & 5 deletions sharrow/aster.py
@@ -617,6 +617,23 @@ def visit_Subscript(self, node):
node,
missing_dim_value=_b,
)
# for XXX[...], there is no space name and XXX is the name of an aux_var
if (
node.value.id in self.spacevars
and isinstance(self.spacevars[node.value.id], ast.Name)
and self.spacename == ""
):
result = ast.Subscript(
value=self.spacevars[node.value.id],
slice=self.visit(node.slice),
ctx=node.ctx,
)
self.log_event(
f"visit_Subscript(AuxVar {node.value.id})",
node,
result,
)
return result
self.log_event("visit_Subscript(no change)", node)
return node

@@ -632,6 +649,14 @@ def visit_Attribute(self, node):
)
self.log_event(f"visit_Attribute(Raw {node.attr})", node, result)
return result
if self.spacename == "" and node.value.id in self.spacevars:
result = ast.Attribute(
value=self.visit(node.value),
attr=node.attr,
ctx=node.ctx,
)
self.log_event("visit_Attribute(lead change)", node, result)
return result
return node
else:
result = ast.Attribute(
@@ -648,11 +673,15 @@ def visit_Name(self, node):
self.log_event("visit_Name(no change)", node)
return node
if self.spacename == "":
result = ast.Subscript(
value=ast.Name(id=self.rawname, ctx=ast.Load()),
slice=ast.Constant(self.spacevars[attr]),
ctx=node.ctx,
)
if isinstance(self.spacevars[attr], ast.Name):
# when a spacevars value is an ast.Name, use it directly; it's probably an aux_var
result = self.spacevars[attr]
else:
result = ast.Subscript(
value=ast.Name(id=self.rawname, ctx=ast.Load()),
slice=ast.Constant(self.spacevars[attr]),
ctx=node.ctx,
)
self.log_event(f"visit_Name(Constant {attr})", node, result)
return result
else:
@@ -882,6 +911,28 @@ def expression_for_numba(
extra_vars=None,
blenders=None,
):
"""
Rewrite an expression so numba can compile it.
Parameters
----------
expr : str
The expression being rewritten
spacename : str
A namespace of variables that might be in the expression.
dim_slots : tuple or Any
spacevars : Mapping, optional
rawname : str
rawalias : str
digital_encodings : Mapping, optional
prefer_name : str, optional
extra_vars : Mapping, optional
blenders : Mapping, optional
Returns
-------
str
"""
return unparse_(
RewriteForNumba(
spacename,
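To make the aster.py changes concrete, here is an illustrative sketch of the rewrite they enable; the `aux_tokens` construction mirrors the one added in flows.py below, and the exact rewritten text is approximate:

```python
import ast

# each aux_var is pre-parsed into a `__aux_var__*` name token ...
aux_tokens = {
    "income_breakpoints": ast.parse(
        "__aux_var__income_breakpoints", mode="eval"
    ).body,
}

# ... so that expression_for_numba, called with spacename="" and these tokens
# as spacevars, can rewrite a spec expression like
#     hh.income > income_breakpoints[2]
# into (roughly)
#     hh.income > __aux_var__income_breakpoints[2]
# leaving the subscript to run against the typed Dict supplied at call time.
```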
64 changes: 57 additions & 7 deletions sharrow/flows.py
@@ -1,3 +1,4 @@
import ast
import base64
import hashlib
import importlib
@@ -598,6 +599,10 @@ def __new__(
hashing_level=hashing_level,
dim_order=dim_order,
dim_exclude=dim_exclude,
error_model=error_model,
boundscheck=boundscheck,
nopython=nopython,
fastmath=fastmath,
)
# return from library if available
if flow_library is not None and self.flow_hash in flow_library:
@@ -677,6 +682,17 @@ def __initialize_1(
if k in all_raw_names:
self._used_extra_vars[k] = v

self._used_extra_funcs = set()
if self.tree.extra_funcs:
for f in self.tree.extra_funcs:
if f.__name__ in all_raw_names:
self._used_extra_funcs.add(f.__name__)

self._used_aux_vars = []
for aux_var in self.tree.aux_vars:
if aux_var in all_raw_names:
self._used_aux_vars.append(aux_var)

self._hashing_level = hashing_level
if self._hashing_level > 1:
func_code, all_name_tokens = self.init_sub_funcs(
@@ -713,6 +729,10 @@ def _flow_hash_push(x):
v = self._used_extra_vars[k]
_flow_hash_push(k)
_flow_hash_push(v)
for k in sorted(self._used_aux_vars):
_flow_hash_push(f"aux_var:{k}")
for k in sorted(self._used_extra_funcs):
_flow_hash_push(f"func:{k}")
_flow_hash_push("---DataTree---")
for k in self.arg_names:
_flow_hash_push(f"arg:{k}")
@@ -786,6 +806,7 @@ def init_sub_funcs(
}
self.arg_name_positions = index_slots
candidate_names = self.tree.namespace_names()
candidate_names |= set(f"__aux_var__{i}" for i in self.tree.aux_vars.keys())

meta_data = {}

@@ -901,6 +922,22 @@ def init_sub_funcs(
"_outputs",
extra_vars=self.tree.extra_vars,
)

aux_tokens = {
k: ast.parse(f"__aux_var__{k}", mode="eval").body
for k in self.tree.aux_vars.keys()
}

# now handle aux vars
expr = expression_for_numba(
expr,
"",
(),
spacevars=aux_tokens,
prefer_name="aux_var",
extra_vars=self.tree.extra_vars,
)

if (k == init_expr) and (init_expr == expr) and k.isidentifier():
logger.error(f"unable to rewrite '{k}' to itself")
raise ValueError(f"unable to rewrite '{k}' to itself")
@@ -1047,10 +1084,14 @@ def __initialize_2(
import cloudpickle as pickle
except ModuleNotFoundError:
import pickle
dependencies.add("import pickle")
func_code += "\n\n# extra_funcs\n"
for x_func in self.tree.extra_funcs:
func_code += f"\n\n{x_func.__name__} = pickle.loads({repr(pickle.dumps(x_func))})\n"
if x_func.__name__ in self._used_extra_funcs:
if x_func.__module__ == "__main__":
dependencies.add("import pickle")
func_code += f"\n\n{x_func.__name__} = pickle.loads({repr(pickle.dumps(x_func))})\n"
else:
func_code += f"\n\nfrom {x_func.__module__} import {x_func.__name__}\n"

# write extra_vars file, if there are any used extra_vars
if self._used_extra_vars:
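In short, the generated module now prefers a plain import over an inlined pickle. A sketch of that decision logic (not the real sharrow API; `render_extra_func` is a hypothetical helper):

```python
try:
    import cloudpickle as pickle
except ModuleNotFoundError:
    import pickle

def render_extra_func(x_func, used_names):
    """Return the generated-module snippet for one extra_func."""
    if x_func.__name__ not in used_names:
        return ""  # unused extra_funcs no longer appear in the module at all
    if x_func.__module__ == "__main__":
        # interactively defined functions cannot be imported later: inline them
        return f"\n{x_func.__name__} = pickle.loads({pickle.dumps(x_func)!r})\n"
    # otherwise a plain import keeps the generated module small and readable
    return f"\nfrom {x_func.__module__} import {x_func.__name__}\n"
```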
@@ -1296,7 +1337,12 @@ def load_raw(self, rg, args, runner=None, dtype=None, dot=None):
continue
if arg.startswith("_arg"):
continue
arguments.append(np.asarray(rg.get_named_array(arg)))
arg_value = rg.get_named_array(arg)
# aux_vars get passed through as is, not forced to be arrays
if arg.startswith("__aux_var"):
arguments.append(arg_value)
else:
arguments.append(np.asarray(arg_value))
kwargs = {}
if dtype is not None:
kwargs["dtype"] = dtype
@@ -1369,10 +1415,14 @@ def iload_raw(
"logsums",
}:
continue
argument = np.asarray(rg.get_named_array(arg))
if argument.dtype.kind == "O":
argument = argument.astype("unicode")
arguments.append(argument)
argument = rg.get_named_array(arg)
# aux_vars get passed through as is, not forced to be arrays
if arg.startswith("__aux_var"):
arguments.append(argument)
else:
if argument.dtype.kind == "O":
argument = argument.astype("unicode")
arguments.append(np.asarray(argument))
kwargs = {}
if dtype is not None:
kwargs["dtype"] = dtype