diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 918a734a130d..1644b4f475c2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5373,7 +5373,7 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: if bins_arr.ndim != 1: raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}") if bins_arr.shape[0] == 0: - return zeros(x, dtype=dtypes.canonicalize_dtype(int_)) + return zeros_like(x, dtype=dtypes.canonicalize_dtype(int_)) side = 'right' if not right else 'left' return where( bins_arr[-1] >= bins_arr[0], diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d2a5a52ef32a..ed3a16eda43a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -2597,7 +2597,7 @@ def testSearchsortedNans(self, dtype, side, method): @jtu.sample_product( xshape=[(20,), (5, 4)], - binshape=[(1,), (5,)], + binshape=[(0,), (1,), (5,)], right=[True, False], reverse=[True, False], dtype=default_dtypes,