diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index 36e6ceae6638..a83804dd3d3c 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -94,8 +94,11 @@ namespace; they are listed below. bincount bitwise_and bitwise_count + bitwise_invert + bitwise_left_shift bitwise_not bitwise_or + bitwise_right_shift bitwise_xor blackman block diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 275b2575d60d..46e018114fbc 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -149,6 +149,8 @@ def _arccosh(x: ArrayLike, /) -> Array: return out fabs = _one_to_one_unop(np.fabs, lax.abs, True) +bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not) +bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not) bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) invert = _one_to_one_unop(np.invert, lax.bitwise_not) negative = _one_to_one_unop(np.negative, lax.neg) @@ -176,6 +178,7 @@ def _arccosh(x: ArrayLike, /) -> Array: add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) +bitwise_left_shift = _one_to_one_binop(getattr(np, "bitwise_left_shift", np.left_shift), lax.shift_left, promote_to_numeric=True) bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True) @@ -225,6 +228,13 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) +@_wraps(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy') +@partial(jit, inline=True) +def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: + x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2) + lax_fn = lax.shift_right_logical if \ + np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic + return lax_fn(x1, x2) @_wraps(np.absolute, module='numpy') @partial(jit, inline=True) diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index 373d29098a16..a3084473bfd7 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -92,13 +92,13 @@ def bitwise_and(x1, x2, /): def bitwise_left_shift(x1, x2, /): """Shifts the bits of each element x1_i of the input array x1 to the left by appending x2_i (i.e., the respective element in the input array x2) zeros to the right of x1_i.""" x1, x2 = _promote_dtypes("bitwise_left_shift", x1, x2) - return jax.numpy.left_shift(x1, x2) + return jax.numpy.bitwise_left_shift(x1, x2) def bitwise_invert(x, /): """Inverts (flips) each bit for each element x_i of the input array x.""" x, = _promote_dtypes("bitwise_invert", x) - return jax.numpy.bitwise_not(x) + return jax.numpy.bitwise_invert(x) def bitwise_or(x1, x2, /): @@ -110,7 +110,7 @@ def bitwise_or(x1, x2, /): def bitwise_right_shift(x1, x2, /): """Shifts the bits of each element x1_i of the input array x1 to the right according to the respective element x2_i of the input array x2.""" x1, x2 = _promote_dtypes("bitwise_right_shift", x1, x2) - return jax.numpy.right_shift(x1, x2) + return jax.numpy.bitwise_right_shift(x1, x2) def bitwise_xor(x1, x2, /): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 6b51f45fa2d1..f71770f0a5d2 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -346,7 +346,10 @@ atan2 as atan2, bitwise_and as bitwise_and, bitwise_count as bitwise_count, + bitwise_invert as bitwise_invert, + bitwise_left_shift as bitwise_left_shift, bitwise_not as bitwise_not, + bitwise_right_shift as bitwise_right_shift, bitwise_or as bitwise_or, bitwise_xor as bitwise_xor, cbrt as cbrt, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index e60fa5a4e098..7b572f43c26d 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -148,8 +148,11 @@ def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ..., minlength: int = ..., *, length: Optional[int] = ...) -> Array: ... def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_count(x: ArrayLike, /) -> Array: ... +def bitwise_invert(x: ArrayLike, /) -> Array: ... +def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_not(x: ArrayLike, /) -> Array: ... def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... +def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... def blackman(M: int) -> Array: ... def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ... diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 8b9f55e0dbfc..a169f6ffd695 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -308,6 +308,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, JAX_BITWISE_OP_RECORDS = [ op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_fullrange, []), + op_record("bitwise_invert", 1, int_dtypes + unsigned_dtypes, all_shapes, + jtu.rand_fullrange, [], alias='bitwise_not'), op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_fullrange, []), op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes, @@ -574,7 +576,7 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): @parameterized.parameters(itertools.chain.from_iterable( jtu.sample_product_testcases( - [dict(name=rec.name, rng_factory=rec.rng_factory)], + [dict(name=rec.name, rng_factory=rec.rng_factory, alias=rec.alias)], shapes=filter( _shapes_are_broadcast_compatible, itertools.combinations_with_replacement(rec.shapes, rec.nargs)), @@ -584,8 +586,8 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): ) for rec in JAX_BITWISE_OP_RECORDS)) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testBitwiseOp(self, name, rng_factory, shapes, dtypes): - np_op = getattr(np, name) + def testBitwiseOp(self, name, rng_factory, shapes, dtypes, alias): + np_op = getattr(np, name) if hasattr(np, name) else getattr(np, alias) jnp_op = getattr(jnp, name) rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes) @@ -617,7 +619,7 @@ def testBitwiseCount(self, shape, dtype): for dtypes in itertools.product( *(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes)) ], - op=[jnp.left_shift, jnp.right_shift], + op=[jnp.left_shift, jnp.bitwise_left_shift, jnp.right_shift, jnp.bitwise_right_shift], ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. def testShiftOpAgainstNumpy(self, op, dtypes, shapes): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index e0d34eac6f9f..d6d4489f00c1 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5652,8 +5652,7 @@ def testWrappedSignaturesMatch(self): if dtype != dtypes.bfloat16] # TODO(jakevdp): implement missing ufuncs -UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_invert', 'bitwise_left_shift', - 'bitwise_right_shift', 'pow'} +UNIMPLEMENTED_UFUNCS = {'spacing', 'pow'} def _all_numpy_ufuncs() -> Iterator[str]: