Skip to content

Commit

Permalink
Mixed types (#7)
Browse files Browse the repository at this point in the history
* mixed types

* fix bug in np.where
  • Loading branch information
jpn-- authored Mar 11, 2022
1 parent f0c0b92 commit ab7ac06
Show file tree
Hide file tree
Showing 3 changed files with 5,070 additions and 10 deletions.
36 changes: 26 additions & 10 deletions sharrow/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,10 +1132,10 @@ def __initialize_2(
if f_args_j:
f_args_j += ", "
meta_code.append(
f"result[{js}, {n}] = {clean(k)}({f_args_j}result[{js}], {f_name_tokens})"
f"result[{js}, {n}] = ({clean(k)}({f_args_j}result[{js}], {f_name_tokens})).item()"
)
meta_code_dot.append(
f"intermediate[{n}] = {clean(k)}({f_args_j}intermediate, {f_name_tokens})"
f"intermediate[{n}] = ({clean(k)}({f_args_j}intermediate, {f_name_tokens})).item()"
)
meta_code_stack = textwrap.indent(
"\n".join(meta_code), " " * 12
Expand All @@ -1151,15 +1151,31 @@ def __initialize_2(
if not meta_code_stack_dot:
meta_code_stack_dot = "pass"
if n_root_dims == 1:
meta_template = IRUNNER_1D_TEMPLATE.format(**locals())
meta_template_dot = IDOTTER_1D_TEMPLATE.format(**locals())
line_template = ILINER_1D_TEMPLATE.format(**locals())
mnl_template = MNL_1D_TEMPLATE.format(**locals())
meta_template = IRUNNER_1D_TEMPLATE.format(**locals()).format(
**locals()
)
meta_template_dot = IDOTTER_1D_TEMPLATE.format(
**locals()
).format(**locals())
line_template = ILINER_1D_TEMPLATE.format(**locals()).format(
**locals()
)
mnl_template = MNL_1D_TEMPLATE.format(**locals()).format(
**locals()
)
elif n_root_dims == 2:
meta_template = IRUNNER_2D_TEMPLATE.format(**locals())
meta_template_dot = IDOTTER_2D_TEMPLATE.format(**locals())
line_template = ILINER_2D_TEMPLATE.format(**locals())
mnl_template = MNL_2D_TEMPLATE.format(**locals())
meta_template = IRUNNER_2D_TEMPLATE.format(**locals()).format(
**locals()
)
meta_template_dot = IDOTTER_2D_TEMPLATE.format(
**locals()
).format(**locals())
line_template = ILINER_2D_TEMPLATE.format(**locals()).format(
**locals()
)
mnl_template = MNL_2D_TEMPLATE.format(**locals()).format(
**locals()
)
else:
raise ValueError(f"invalid n_root_dims {n_root_dims}")

Expand Down
43 changes: 43 additions & 0 deletions sharrow/tests/test_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,49 @@ def test_with_2d_base(dataframe_regression):
np.testing.assert_array_almost_equal(check_vs, dot_result.to_numpy())


def test_mixed_dtypes(dataframe_regression):
data = example_data.get_data()
skims = data["skims"]
households = data["hhs"]

prng = default_rng(SeedSequence(42))
households["otaz_idx"] = households["TAZ"] - 1
households["dtaz_idx"] = prng.choice(np.arange(25), 5000)
households["timeperiod5"] = prng.choice(np.arange(5), 5000)
households["timeperiod3"] = np.clip(households["timeperiod5"], 1, 3) - 1
households["rownum"] = np.arange(len(households))

tree = DataTree(
base=households,
skims=skims,
relationships=(
"base.otaz_idx->skims.otaz",
"base.dtaz_idx->skims.dtaz",
"base.timeperiod5->skims.time_period",
),
)

ss = tree.setup_flow(
{
"income": "base.income",
"sov_time_by_income": "skims.SOV_TIME/base.income",
"sov_time_by_workers": "np.where(base.workers > 0, skims.SOV_TIME / base.workers, 0)",
}
)
result = ss._load(tree, as_dataframe=True, dtype=np.float32)
dataframe_regression.check(result)

ss_undot = tree.setup_flow(
{
"income": "income",
"sov_time_by_income": "SOV_TIME/income",
"sov_time_by_workers": "np.where(workers > 0, SOV_TIME / workers, 0)",
}
)
result = ss_undot._load(tree, as_dataframe=True, dtype=np.float32)
dataframe_regression.check(result)


def _get_target(q):
skims_ = Dataset.shm.from_shared_memory("skims")
q.put(skims_.SOV_TIME.sum())
Expand Down
Loading

0 comments on commit ab7ac06

Please sign in to comment.