Skip to content

Commit

Permalink
array api: add jnp.bitwise_* aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 10, 2024
1 parent 57f0559 commit 4e55086
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 9 deletions.
3 changes: 3 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ namespace; they are listed below.
bincount
bitwise_and
bitwise_count
bitwise_invert
bitwise_left_shift
bitwise_not
bitwise_or
bitwise_right_shift
bitwise_xor
blackman
block
Expand Down
10 changes: 10 additions & 0 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def _arccosh(x: ArrayLike, /) -> Array:
return out

fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not)
bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
Expand Down Expand Up @@ -176,6 +178,7 @@ def _arccosh(x: ArrayLike, /) -> Array:

add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_left_shift = _one_to_one_binop(getattr(np, "bitwise_left_shift", np.left_shift), lax.shift_left, promote_to_numeric=True)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True)
Expand Down Expand Up @@ -225,6 +228,13 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)

@_wraps(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy')
@partial(jit, inline=True)
def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2)
lax_fn = lax.shift_right_logical if \
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)

@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def bitwise_and(x1, x2, /):
def bitwise_left_shift(x1, x2, /):
"""Shifts the bits of each element x1_i of the input array x1 to the left by appending x2_i (i.e., the respective element in the input array x2) zeros to the right of x1_i."""
x1, x2 = _promote_dtypes("bitwise_left_shift", x1, x2)
return jax.numpy.left_shift(x1, x2)
return jax.numpy.bitwise_left_shift(x1, x2)


def bitwise_invert(x, /):
"""Inverts (flips) each bit for each element x_i of the input array x."""
x, = _promote_dtypes("bitwise_invert", x)
return jax.numpy.bitwise_not(x)
return jax.numpy.bitwise_invert(x)


def bitwise_or(x1, x2, /):
Expand All @@ -110,7 +110,7 @@ def bitwise_or(x1, x2, /):
def bitwise_right_shift(x1, x2, /):
"""Shifts the bits of each element x1_i of the input array x1 to the right according to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("bitwise_right_shift", x1, x2)
return jax.numpy.right_shift(x1, x2)
return jax.numpy.bitwise_right_shift(x1, x2)


def bitwise_xor(x1, x2, /):
Expand Down
3 changes: 3 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@
atan2 as atan2,
bitwise_and as bitwise_and,
bitwise_count as bitwise_count,
bitwise_invert as bitwise_invert,
bitwise_left_shift as bitwise_left_shift,
bitwise_not as bitwise_not,
bitwise_right_shift as bitwise_right_shift,
bitwise_or as bitwise_or,
bitwise_xor as bitwise_xor,
cbrt as cbrt,
Expand Down
3 changes: 3 additions & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,11 @@ def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ...,
minlength: int = ..., *, length: Optional[int] = ...) -> Array: ...
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_count(x: ArrayLike, /) -> Array: ...
def bitwise_invert(x: ArrayLike, /) -> Array: ...
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_not(x: ArrayLike, /) -> Array: ...
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def blackman(M: int) -> Array: ...
def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ...
Expand Down
10 changes: 6 additions & 4 deletions tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
JAX_BITWISE_OP_RECORDS = [
op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("bitwise_invert", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, [], alias='bitwise_not'),
op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes,
Expand Down Expand Up @@ -574,7 +576,7 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):

@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory)],
[dict(name=rec.name, rng_factory=rec.rng_factory, alias=rec.alias)],
shapes=filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(rec.shapes, rec.nargs)),
Expand All @@ -584,8 +586,8 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
)
for rec in JAX_BITWISE_OP_RECORDS))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testBitwiseOp(self, name, rng_factory, shapes, dtypes):
np_op = getattr(np, name)
def testBitwiseOp(self, name, rng_factory, shapes, dtypes, alias):
np_op = getattr(np, name) if hasattr(np, name) else getattr(np, alias)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
Expand Down Expand Up @@ -617,7 +619,7 @@ def testBitwiseCount(self, shape, dtype):
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes))
],
op=[jnp.left_shift, jnp.right_shift],
op=[jnp.left_shift, jnp.bitwise_left_shift, jnp.right_shift, jnp.bitwise_right_shift],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testShiftOpAgainstNumpy(self, op, dtypes, shapes):
Expand Down
3 changes: 1 addition & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5652,8 +5652,7 @@ def testWrappedSignaturesMatch(self):
if dtype != dtypes.bfloat16]

# TODO(jakevdp): implement missing ufuncs
UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_invert', 'bitwise_left_shift',
'bitwise_right_shift', 'pow'}
UNIMPLEMENTED_UFUNCS = {'spacing', 'pow'}


def _all_numpy_ufuncs() -> Iterator[str]:
Expand Down

0 comments on commit 4e55086

Please sign in to comment.