Skip to content

Commit

Permalink
Fix bitwise operations (#31)
Browse files Browse the repository at this point in the history
* xxx.isna()

* isna for unicode/obj

* fix test for py3.7

* make boolean wrapping of bitwise operations optional
  • Loading branch information
jpn-- authored Sep 23, 2022
1 parent 22e5a08 commit c15d9c0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sharrow/aster.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def __init__(
preferred_spacename=None,
extra_vars=None,
blenders=None,
bool_wrapping=False,
):
self.spacename = spacename
self.dim_slots = dim_slots
Expand All @@ -347,6 +348,7 @@ def __init__(
self.preferred_spacename = preferred_spacename
self.extra_vars = extra_vars or {}
self.blenders = blenders or {}
self.bool_wrapping = bool_wrapping

def log_event(self, tag, node1=None, node2=None):
if logger.getEffectiveLevel() <= 0:
Expand Down Expand Up @@ -691,7 +693,7 @@ def visit_Name(self, node):

def visit_UnaryOp(self, node):
# convert bitflip `~x` operator into `~np.bool_(x)`
if isinstance(node.op, ast.Invert):
if self.bool_wrapping and isinstance(node.op, ast.Invert):
return ast.UnaryOp(
op=node.op,
operand=bool_wrap(self.visit(node.operand)),
Expand All @@ -707,7 +709,9 @@ def visit_BinOp(self, node):
left = self.visit(node.left)
right = self.visit(node.right)

if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor)):
if self.bool_wrapping and isinstance(
node.op, (ast.BitAnd, ast.BitOr, ast.BitXor)
):

result = ast.BinOp(
left=bool_wrap(left),
Expand Down Expand Up @@ -924,6 +928,7 @@ def expression_for_numba(
prefer_name=None,
extra_vars=None,
blenders=None,
bool_wrapping=False,
):
"""
Rewrite an expression so numba can compile it.
Expand Down Expand Up @@ -958,6 +963,7 @@ def expression_for_numba(
prefer_name,
extra_vars,
blenders,
bool_wrapping,
).visit(ast.parse(expr))
)

Expand Down
20 changes: 20 additions & 0 deletions sharrow/digital_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ def set(self, name, *args, **kwargs):
by_dict : {8, 16, 32}, optional
Encode by dictionary, using this bitwidth. If given, all
arguments other than this and `x` are ignored.
joint_dict : bool or str, optional
If given as a string, the variables in `name` will be encoded
with joint dictionary encoding under this name. Or simply give
a `True` value to apply the same with a random unique name.
Returns
-------
Expand Down Expand Up @@ -287,6 +291,22 @@ def baggage(self, names):


def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None):
"""
Apply a joint dictionary encoding to a collection of Dataset variables.
Parameters
----------
ds : Dataset
encode_vars : Collection[str], optional
The collection of dataset variable names that will be encoded with
a joint dictionary.
encoding_name : str, optional
Use this name for the newly created joint encoding.
Returns
-------
"""
logger = logging.getLogger("sharrow")
if not isinstance(encoding_name, str):
i = 0
Expand Down
9 changes: 9 additions & 0 deletions sharrow/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ def __new__(
hashing_level=1,
dim_order=None,
dim_exclude=None,
bool_wrapping=False,
):
assert isinstance(tree, DataTree)
tree.digitize_relationships(inplace=True)
Expand All @@ -603,6 +604,7 @@ def __new__(
boundscheck=boundscheck,
nopython=nopython,
fastmath=fastmath,
bool_wrapping=bool_wrapping,
)
# return from library if available
if flow_library is not None and self.flow_hash in flow_library:
Expand Down Expand Up @@ -641,6 +643,7 @@ def __initialize_1(
hashing_level=1,
dim_order=None,
dim_exclude=None,
bool_wrapping=False,
):
"""
Initialize up to the flow_hash
Expand All @@ -660,6 +663,7 @@ def __initialize_1(
self._secondary_flows = {}
self.dim_order = dim_order
self.dim_exclude = dim_exclude
self.bool_wrapping = bool_wrapping

all_raw_names = set()
all_name_tokens = set()
Expand Down Expand Up @@ -776,6 +780,7 @@ def _flow_hash_push(x):
_flow_hash_push(f"boundscheck={boundscheck}")
_flow_hash_push(f"error_model={error_model}")
_flow_hash_push(f"fastmath={fastmath}")
_flow_hash_push(f"bool_wrapping={bool_wrapping}")

self.flow_hash = base64.b32encode(flow_hash.digest()).decode()
self.flow_hash_audit = "]\n# [".join(flow_hash_audit)
Expand Down Expand Up @@ -873,6 +878,7 @@ def init_sub_funcs(
digital_encodings=digital_encodings,
extra_vars=self.tree.extra_vars,
blenders=blenders,
bool_wrapping=self.bool_wrapping,
)
except KeyError as key_err:
if ".." in key_err.args[0]:
Expand All @@ -896,6 +902,7 @@ def init_sub_funcs(
prefer_name=other_spacename,
extra_vars=self.tree.extra_vars,
blenders=blenders,
bool_wrapping=self.bool_wrapping,
)
except KeyError as err: # noqa: F841
pass
Expand All @@ -921,6 +928,7 @@ def init_sub_funcs(
self.output_name_positions,
"_outputs",
extra_vars=self.tree.extra_vars,
bool_wrapping=self.bool_wrapping,
)

aux_tokens = {
Expand All @@ -936,6 +944,7 @@ def init_sub_funcs(
spacevars=aux_tokens,
prefer_name="aux_var",
extra_vars=self.tree.extra_vars,
bool_wrapping=self.bool_wrapping,
)

if (k == init_expr) and (init_expr == expr) and k.isidentifier():
Expand Down

0 comments on commit c15d9c0

Please sign in to comment.