Skip to content

Commit

Permalink
Bugfix for df.get (#34)
Browse files: browse the repository at this point in the history
* repair and test subspace fallbacks

* allow error on graphviz

* find "get" when inside another function

* fix root_dims
  • Loading branch information
jpn-- authored Jan 13, 2023
1 parent e3a3cc4 commit 5c6aecb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
4 changes: 2 additions & 2 deletions sharrow/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def filter_name_tokens(expr, matchable_names=None):
return name_tokens, arg_tokens


-class ExtractOptionalGetTokens(ast.NodeTransformer):
+class ExtractOptionalGetTokens(ast.NodeVisitor):
def __init__(self, from_names):
self.optional_get_tokens = set()
self.required_get_tokens = set()
Expand Down Expand Up @@ -178,7 +178,7 @@ def visit_Call(self, node):
raise ValueError(
f"{node.func.value.id}.get with more than 2 positional arguments"
)
-return node
+self.generic_visit(node)

def check(self, node):
if isinstance(node, str):
Expand Down
4 changes: 3 additions & 1 deletion sharrow/relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,9 @@ def shape(self):
def root_dims(self):
from .flows import presorted

-return tuple(presorted(self.root_dataset, self.dim_order, self.dim_exclude))
+return tuple(
+presorted(self.root_dataset.dims, self.dim_order, self.dim_exclude)
+)

def __shallow_copy_extras(self):
return dict(
Expand Down
17 changes: 17 additions & 0 deletions sharrow/tests/test_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,23 @@ def test_get(dataframe_regression, households, skims):
assert flow4.flow_hash != flow1.flow_hash
dataframe_regression.check(result)

# test get when inside another function
flow5 = tree.setup_flow(
{
"income": "np.power(base.get('income', default=0) + df.get('missing_one', 0), 1)",
"sov_time_by_income": "skims.SOV_TIME/np.power(base.get('income', default=0), 1)",
"missing_data": "np.where(np.isnan(df.get('missing_data', default=1)), 0, df.get('missing_data', default=-1))",
"missing_skim": "(np.where(np.isnan(df.get('num_escortees', np.nan)), -2 , df.get('num_escortees', np.nan)))",
"sov_time_by_income_2": "skims.get('SOV_TIME', default=0)/base.income",
"sov_cost_by_income_2": "skims.get('HOV3_TIME', default=999)",
},
)
result = flow5._load(tree, as_dataframe=True)
assert "__skims__HOV3_TIME:True" in flow5._optional_get_tokens
assert "__df__missing_data:False" in flow5._optional_get_tokens
assert "__df__num_escortees:False" in flow5._optional_get_tokens
dataframe_regression.check(result)


def test_get_native():
data = example_data.get_data()
Expand Down

0 comments on commit 5c6aecb

Please sign in to comment.