Skip to content

Commit

Permalink
Fix axis=None and keepdims=True for jnp.quantile and jnp.median
Browse files Browse the repository at this point in the history
Fix `axis=None` and `keepdims=True` for `jnp.quantile` and `jnp.median`

Fix `axis=None` and `keepdims=True` for `jnp.quantile` and `jnp.median`

Remove `print`

Update tests

Remove `print`

Update tests
  • Loading branch information
james77777778 committed Mar 15, 2024
1 parent f17f59b commit a92fba7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 13 deletions.
1 change: 0 additions & 1 deletion jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,6 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5))
else:
raise ValueError(f"interpolation={interpolation!r} not recognized")
print("keepdims", keepdims, "keepdim", keepdim)
if keepdims and keepdim:
if q_ndim > 0:
keepdim = [np.shape(q)[0], *keepdim]
Expand Down
14 changes: 2 additions & 12 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights,
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
((7,), None),
((6, 7,), None),
((47, 7), 0),
((47, 7), ()),
((4, 101), 1),
Expand Down Expand Up @@ -705,12 +706,6 @@ def np_fun(*args):
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)

def testQuantileKeepdims(self):
x = np.ones([2, 3, 4]).astype(np.float32)
expected = np.quantile(x, 0.5, axis=None, keepdims=True)
actual = jnp.quantile(x, 0.5, axis=None, keepdims=True)
self.assertAllClose(expected, actual, atol=0)

@unittest.skipIf(not config.enable_x64.value, "test requires X64")
@jtu.run_on_devices("cpu") # test is for CPU float64 precision
def testPercentilePrecision(self):
Expand All @@ -722,6 +717,7 @@ def testPercentilePrecision(self):
[dict(a_shape=a_shape, axis=axis)
for a_shape, axis in (
((7,), None),
((6, 7,), None),
((47, 7), 0),
((4, 101), 1),
)
Expand Down Expand Up @@ -750,12 +746,6 @@ def np_fun(*args):
tol=tol)
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol)

def testMedianKeepdims(self):
x = np.ones([2, 3, 4], dtype=np.float32)
expected = np.median(x, axis=None, keepdims=True)
actual = jnp.median(x, axis=None, keepdims=True)
self.assertAllClose(expected, actual, atol=0)

def testMeanLargeArray(self):
# https://github.com/google/jax/issues/15068
raise unittest.SkipTest("test is slow, but it passes!")
Expand Down

0 comments on commit a92fba7

Please sign in to comment.