From c15d9c02688877ebb9f92b191ca8c94c06fe9d52 Mon Sep 17 00:00:00 2001 From: Jeffrey Newman Date: Fri, 23 Sep 2022 18:50:53 -0500 Subject: [PATCH] Fix bitwise operations (#31) * xxx.isna() * isna for unicode/obj * fix test for py3.7 * make boolean wrapping of bitwise operations optional --- sharrow/aster.py | 10 ++++++++-- sharrow/digital_encoding.py | 20 ++++++++++++++++++++ sharrow/flows.py | 9 +++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/sharrow/aster.py b/sharrow/aster.py index f0e655b..ce2e2ba 100755 --- a/sharrow/aster.py +++ b/sharrow/aster.py @@ -337,6 +337,7 @@ def __init__( preferred_spacename=None, extra_vars=None, blenders=None, + bool_wrapping=False, ): self.spacename = spacename self.dim_slots = dim_slots @@ -347,6 +348,7 @@ def __init__( self.preferred_spacename = preferred_spacename self.extra_vars = extra_vars or {} self.blenders = blenders or {} + self.bool_wrapping = bool_wrapping def log_event(self, tag, node1=None, node2=None): if logger.getEffectiveLevel() <= 0: @@ -691,7 +693,7 @@ def visit_Name(self, node): def visit_UnaryOp(self, node): # convert bitflip `~x` operator into `~np.bool_(x)` - if isinstance(node.op, ast.Invert): + if self.bool_wrapping and isinstance(node.op, ast.Invert): return ast.UnaryOp( op=node.op, operand=bool_wrap(self.visit(node.operand)), @@ -707,7 +709,9 @@ def visit_BinOp(self, node): left = self.visit(node.left) right = self.visit(node.right) - if isinstance(node.op, (ast.BitAnd, ast.BitOr, ast.BitXor)): + if self.bool_wrapping and isinstance( + node.op, (ast.BitAnd, ast.BitOr, ast.BitXor) + ): result = ast.BinOp( left=bool_wrap(left), @@ -924,6 +928,7 @@ def expression_for_numba( prefer_name=None, extra_vars=None, blenders=None, + bool_wrapping=False, ): """ Rewrite an expression so numba can compile it. @@ -958,6 +963,7 @@ def expression_for_numba( prefer_name, extra_vars, blenders, + bool_wrapping, ).visit(ast.parse(expr)) ) diff --git a/sharrow/digital_encoding.py b/sharrow/digital_encoding.py index 133b840..2b974c4 100644 --- a/sharrow/digital_encoding.py +++ b/sharrow/digital_encoding.py @@ -234,6 +234,10 @@ def set(self, name, *args, **kwargs): by_dict : {8, 16, 32}, optional Encode by dictionary, using this bitwidth. If given, all arguments other than this and `x` are ignored. + joint_dict : bool or str, optional + If given as a string, the variables in `name` will be encoded + with joint dictionary encoding under this name. Or simply give + a `True` value to apply the same with a random unique name. Returns ------- @@ -287,6 +291,22 @@ def baggage(self, names): def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None): + """ + Apply a joint dictionary encoding to a collection of Dataset variables. + + Parameters + ---------- + ds : Dataset + encode_vars : Collection[str], optional + The collection of dataset variable names that will be encoded with + a joint dictionary. + encoding_name : str, optional + Use this name for the newly created joint encoding. + + Returns + ------- + + """ logger = logging.getLogger("sharrow") if not isinstance(encoding_name, str): i = 0 diff --git a/sharrow/flows.py b/sharrow/flows.py index 35c2881..b96f823 100644 --- a/sharrow/flows.py +++ b/sharrow/flows.py @@ -582,6 +582,7 @@ def __new__( hashing_level=1, dim_order=None, dim_exclude=None, + bool_wrapping=False, ): assert isinstance(tree, DataTree) tree.digitize_relationships(inplace=True) @@ -603,6 +604,7 @@ def __new__( boundscheck=boundscheck, nopython=nopython, fastmath=fastmath, + bool_wrapping=bool_wrapping, ) # return from library if available if flow_library is not None and self.flow_hash in flow_library: @@ -641,6 +643,7 @@ def __initialize_1( hashing_level=1, dim_order=None, dim_exclude=None, + bool_wrapping=False, ): """ Initialize up to the flow_hash @@ -660,6 +663,7 @@ def __initialize_1( self._secondary_flows = {} self.dim_order = dim_order self.dim_exclude = dim_exclude + self.bool_wrapping = bool_wrapping all_raw_names = set() all_name_tokens = set() @@ -776,6 +780,7 @@ def _flow_hash_push(x): _flow_hash_push(f"boundscheck={boundscheck}") _flow_hash_push(f"error_model={error_model}") _flow_hash_push(f"fastmath={fastmath}") + _flow_hash_push(f"bool_wrapping={bool_wrapping}") self.flow_hash = base64.b32encode(flow_hash.digest()).decode() self.flow_hash_audit = "]\n# [".join(flow_hash_audit) @@ -873,6 +878,7 @@ def init_sub_funcs( digital_encodings=digital_encodings, extra_vars=self.tree.extra_vars, blenders=blenders, + bool_wrapping=self.bool_wrapping, ) except KeyError as key_err: if ".." in key_err.args[0]: @@ -896,6 +902,7 @@ def init_sub_funcs( prefer_name=other_spacename, extra_vars=self.tree.extra_vars, blenders=blenders, + bool_wrapping=self.bool_wrapping, ) except KeyError as err: # noqa: F841 pass @@ -921,6 +928,7 @@ def init_sub_funcs( self.output_name_positions, "_outputs", extra_vars=self.tree.extra_vars, + bool_wrapping=self.bool_wrapping, ) aux_tokens = { @@ -936,6 +944,7 @@ def init_sub_funcs( spacevars=aux_tokens, prefer_name="aux_var", extra_vars=self.tree.extra_vars, + bool_wrapping=self.bool_wrapping, ) if (k == init_expr) and (init_expr == expr) and k.isidentifier():