Skip to content

Commit

Permalink
np.where conversion to IfExp (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpn-- authored Jun 4, 2024
1 parent 8438639 commit 8d63a66
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions sharrow/aster.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,12 +939,12 @@ def visit_Call(self, node):
):
if len(node.args) == 2 and len(node.keywords) == 0:
try:
return self._replacement(
result = self._replacement(
ast_String_value(node.args[0]), node.func.value.ctx, node
)
except KeyError:
if self.get_default:
return self.visit(node.args[1])
result = self.visit(node.args[1])
else:
raise
if (
Expand All @@ -953,19 +953,34 @@ def visit_Call(self, node):
and "default" == node.keywords[0].arg
):
try:
return self._replacement(
result = self._replacement(
ast_String_value(node.args[0]), node.func.value.ctx, node
)
except KeyError:
if self.get_default:
return self.visit(node.keywords[0].value)
result = self.visit(node.keywords[0].value)
else:
raise
if len(node.args) == 1 and len(node.keywords) == 0:
return self._replacement(
result = self._replacement(
ast_String_value(node.args[0]), node.func.value.ctx, node
)

# change np.where(x, y, z) to (y if x else z), for performance reasons
if (
isinstance(node.func, ast.Attribute)
and node.func.attr == "where"
and isinstance(node.func.value, ast.Name)
and (node.func.value.id == "np" or node.func.value.id == "numpy")
and len(node.args) == 3
):
if len(node.args) == 3 and len(node.keywords) == 0:
result = ast.IfExp(
test=self.visit(node.args[0]),
body=self.visit(node.args[1]),
orelse=self.visit(node.args[2]),
)

# if no other changes
if result is None:
args = [self.visit(i) for i in node.args]
Expand Down

0 comments on commit 8d63a66

Please sign in to comment.