From a92fba71915f4e0f4731a41c330feb1d634342d2 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 14 Mar 2024 22:45:02 +0800 Subject: [PATCH] Fix `axis=None` and `keepdims=True` for `jnp.quantile` and `jnp.median` Fix `axis=None` and `keepdims=True` for `jnp.quantile` and `jnp.median` Fix `axis=None` and `keepdims=True` for `jnp.quantile` and `jnp.median` Remove `print` Update tests Remove `print` Update tests --- jax/_src/numpy/reductions.py | 1 - tests/lax_numpy_reducers_test.py | 14 ++------------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index b6bf214c31ed..0e628fc5e5b8 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -858,7 +858,6 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) else: raise ValueError(f"interpolation={interpolation!r} not recognized") - print("keepdims", keepdims, "keepdim", keepdim) if keepdims and keepdim: if q_ndim > 0: keepdim = [np.shape(q)[0], *keepdim] diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index ac81870d193c..b72e47d9d179 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -662,6 +662,7 @@ def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, [dict(a_shape=a_shape, axis=axis) for a_shape, axis in ( ((7,), None), + ((6, 7,), None), ((47, 7), 0), ((47, 7), ()), ((4, 101), 1), @@ -705,12 +706,6 @@ def np_fun(*args): tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) - def testQuantileKeepdims(self): - x = np.ones([2, 3, 4]).astype(np.float32) - expected = np.quantile(x, 0.5, axis=None, keepdims=True) - actual = jnp.quantile(x, 0.5, axis=None, keepdims=True) - self.assertAllClose(expected, actual, atol=0) - @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): @@ -722,6 +717,7 @@ def testPercentilePrecision(self): [dict(a_shape=a_shape, axis=axis) for a_shape, axis in ( ((7,), None), + ((6, 7,), None), ((47, 7), 0), ((4, 101), 1), ) @@ -750,12 +746,6 @@ def np_fun(*args): tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) - def testMedianKeepdims(self): - x = np.ones([2, 3, 4], dtype=np.float32) - expected = np.median(x, axis=None, keepdims=True) - actual = jnp.median(x, axis=None, keepdims=True) - self.assertAllClose(expected, actual, atol=0) - def testMeanLargeArray(self): # https://github.com/google/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!")