Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix axis=None, keepdims=True for jnp.quantile and jnp.median #20251

Merged
merged 1 commit into from
Mar 20, 2024

Conversation

james77777778
Copy link
Contributor

This PR fixes the incorrect behavior:

>>> x = np.ones([2, 3, 4])
>>> np.median(x, axis=None, keepdims=True).shape
(1, 1, 1)
>>> jnp.median(x, axis=None, keepdims=True).shape
(1,)  # mismatched!
>>> np.quantile(x, 0.5, axis=None, keepdims=True).shape
(1, 1, 1)
>>> jnp.quantile(x, 0.5, axis=None, keepdims=True).shape
(1,)  # mismatched!

The corresponding tests have been included

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 14, 2024

Good catch, thanks for the fix! One minor suggestion: rather than new specific test cases, we could add relevant cases to the old tests. For example here: https://github.com/google/jax/blob/64bd95ded50fc647d6e80a55ab1cb3baef368cf6/tests/lax_numpy_reducers_test.py#L664
and here: https://github.com/google/jax/blob/64bd95ded50fc647d6e80a55ab1cb3baef368cf6/tests/lax_numpy_reducers_test.py#L718
We could add another case that looks something like ((6, 7), None) and that would lead to this case being covered for our runs with larger JAX_NUM_GENERATED_CASES. What do you think?

@jakevdp jakevdp self-requested a review March 14, 2024 15:19
@jakevdp jakevdp self-assigned this Mar 14, 2024
@james77777778
Copy link
Contributor Author

We could add another case that looks something like ((6, 7), None) and that would lead to this case being covered for our runs with larger JAX_NUM_GENERATED_CASES. What do you think?

I'm not familar with the test configuration of JAX. Thanks for the suggestion.
The tests have been updated and the new tests should be included using JAX_NUM_GENERATED_CASES=40

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

Last thing: could you please squash the changes into a single commit? See https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests. Thanks!

Fix `axis=None` and `keepdims=True` for `jnp.quantile` and `jnp.median`

Remove `print`

Update tests
@james77777778
Copy link
Contributor Author

Last thing: could you please squash the changes into a single commit? See https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests. Thanks!

It should be a one-commit PR right now. Thanks for the tip.

@james77777778
Copy link
Contributor Author

Kindly ping @jakevdp
Is there any unresolved issue? Happy to update.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@copybara-service copybara-service bot merged commit 8da8d03 into jax-ml:main Mar 20, 2024
13 checks passed
@james77777778 james77777778 deleted the fix-quantile-keepdims branch March 21, 2024 07:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants