Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

array api: add jnp.bitwise_* aliases #19278

Merged
merged 1 commit into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,11 @@ namespace; they are listed below.
bincount
bitwise_and
bitwise_count
bitwise_invert
bitwise_left_shift
bitwise_not
bitwise_or
bitwise_right_shift
bitwise_xor
blackman
block
Expand Down
10 changes: 10 additions & 0 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def _arccosh(x: ArrayLike, /) -> Array:
return out

fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not)
bitwise_invert = _one_to_one_unop(getattr(np, 'bitwise_invert', np.invert), lax.bitwise_not)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
Expand Down Expand Up @@ -176,6 +178,7 @@ def _arccosh(x: ArrayLike, /) -> Array:

add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_left_shift = _one_to_one_binop(getattr(np, "bitwise_left_shift", np.left_shift), lax.shift_left, promote_to_numeric=True)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left, promote_to_numeric=True)
Expand Down Expand Up @@ -225,6 +228,13 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)

@_wraps(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy')
@partial(jit, inline=True)
def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2)
lax_fn = lax.shift_right_logical if \
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
return lax_fn(x1, x2)

@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def bitwise_and(x1, x2, /):
def bitwise_left_shift(x1, x2, /):
"""Shifts the bits of each element x1_i of the input array x1 to the left by appending x2_i (i.e., the respective element in the input array x2) zeros to the right of x1_i."""
x1, x2 = _promote_dtypes("bitwise_left_shift", x1, x2)
return jax.numpy.left_shift(x1, x2)
return jax.numpy.bitwise_left_shift(x1, x2)


def bitwise_invert(x, /):
"""Inverts (flips) each bit for each element x_i of the input array x."""
x, = _promote_dtypes("bitwise_invert", x)
return jax.numpy.bitwise_not(x)
return jax.numpy.bitwise_invert(x)


def bitwise_or(x1, x2, /):
Expand All @@ -110,7 +110,7 @@ def bitwise_or(x1, x2, /):
def bitwise_right_shift(x1, x2, /):
"""Shifts the bits of each element x1_i of the input array x1 to the right according to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("bitwise_right_shift", x1, x2)
return jax.numpy.right_shift(x1, x2)
return jax.numpy.bitwise_right_shift(x1, x2)


def bitwise_xor(x1, x2, /):
Expand Down
3 changes: 3 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@
atan2 as atan2,
bitwise_and as bitwise_and,
bitwise_count as bitwise_count,
bitwise_invert as bitwise_invert,
bitwise_left_shift as bitwise_left_shift,
bitwise_not as bitwise_not,
bitwise_right_shift as bitwise_right_shift,
bitwise_or as bitwise_or,
bitwise_xor as bitwise_xor,
cbrt as cbrt,
Expand Down
3 changes: 3 additions & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,11 @@ def bincount(x: ArrayLike, weights: Optional[ArrayLike] = ...,
minlength: int = ..., *, length: Optional[int] = ...) -> Array: ...
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_count(x: ArrayLike, /) -> Array: ...
def bitwise_invert(x: ArrayLike, /) -> Array: ...
def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_not(x: ArrayLike, /) -> Array: ...
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def blackman(M: int) -> Array: ...
def block(arrays: Union[ArrayLike, Sequence[ArrayLike], Sequence[Sequence[ArrayLike]]]) -> Array: ...
Expand Down
10 changes: 6 additions & 4 deletions tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
JAX_BITWISE_OP_RECORDS = [
op_record("bitwise_and", 2, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("bitwise_invert", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, [], alias='bitwise_not'),
op_record("bitwise_not", 1, int_dtypes + unsigned_dtypes, all_shapes,
jtu.rand_fullrange, []),
op_record("invert", 1, int_dtypes + unsigned_dtypes, all_shapes,
Expand Down Expand Up @@ -574,7 +576,7 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):

@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(name=rec.name, rng_factory=rec.rng_factory)],
[dict(name=rec.name, rng_factory=rec.rng_factory, alias=rec.alias)],
shapes=filter(
_shapes_are_broadcast_compatible,
itertools.combinations_with_replacement(rec.shapes, rec.nargs)),
Expand All @@ -584,8 +586,8 @@ def testBinaryOperatorDefers(self, op_name, rng_factory, dtype):
)
for rec in JAX_BITWISE_OP_RECORDS))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testBitwiseOp(self, name, rng_factory, shapes, dtypes):
np_op = getattr(np, name)
def testBitwiseOp(self, name, rng_factory, shapes, dtypes, alias):
np_op = getattr(np, name) if hasattr(np, name) else getattr(np, alias)
jnp_op = getattr(jnp, name)
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
Expand Down Expand Up @@ -617,7 +619,7 @@ def testBitwiseCount(self, shape, dtype):
for dtypes in itertools.product(
*(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes))
],
op=[jnp.left_shift, jnp.right_shift],
op=[jnp.left_shift, jnp.bitwise_left_shift, jnp.right_shift, jnp.bitwise_right_shift],
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testShiftOpAgainstNumpy(self, op, dtypes, shapes):
Expand Down
3 changes: 1 addition & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5652,8 +5652,7 @@ def testWrappedSignaturesMatch(self):
if dtype != dtypes.bfloat16]

# TODO(jakevdp): implement missing ufuncs
UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_invert', 'bitwise_left_shift',
'bitwise_right_shift', 'pow'}
UNIMPLEMENTED_UFUNCS = {'spacing', 'pow'}


def _all_numpy_ufuncs() -> Iterator[str]:
Expand Down