JAX support #83

Closed
imh opened this issue Feb 3, 2024 · 11 comments

imh commented Feb 3, 2024

> If you want support for other array libraries, or if you encounter any issues, please open an issue.

Okay, well here it is. I don't know if I'd have time to add it myself in the near future, but it would be great to have.

Member

asmeurer commented Feb 5, 2024

@jakevdp has been working on adding array API support to JAX. I think his plan was to have full support in a JAX submodule so that array-api-compat library support would be unnecessary, but maybe he can clarify if that's still the plan and what the latest status of that is.

Contributor

jakevdp commented Feb 5, 2024

Hi - thanks for the tag! We are tracking JAX Array API support here: jax-ml/jax#18353

You can experimentally enable it by importing jax.experimental.array_api:

import jax.numpy as jnp

# This is a side-effecting import that among other
# things adds __array_namespace__ to JAX arrays.
import jax.experimental.array_api as xp

print(xp is jnp.array(0).__array_namespace__())  # Standard access to array_api
# True

print(xp.arange(10))  # most Array API functions are already implemented at HEAD.
# [0 1 2 3 4 5 6 7 8 9]

There are still some fundamental questions, though: for example, JAX arrays are immutable so they cannot support pieces of the API standard that require mutability, and JAX arrays had a pre-existing device() method (deprecation in progress) that conflicts with the device property in the API standard. But we're gradually ironing out those details and are hoping to have the main jax.numpy namespace be array API-compliant in a future release.

Member

asmeurer commented Feb 5, 2024

So with __array_namespace__ defined, array-api-compat will automatically work:

>>> import jax.experimental.array_api as xp
>>> import array_api_compat
>>> array_api_compat.array_namespace(xp.arange(10)) is xp
True

I think the main thing to do here is to add JAX support to the device() helper function.

I don't think it's necessary to add an array_api_compat.jax wrapper namespace to wrap JAX functions that aren't compliant because that is effectively already the jax.experimental.array_api namespace.

> # This is a side-effecting import that among other
> # things adds __array_namespace__ to JAX arrays.

What are the other things it does? I would suggest not using a side-effecting import, and instead just adding __array_namespace__ to JAX arrays. If there is any patching or something else that needs to happen, do it when __array_namespace__() is called.

Otherwise, whether jax.numpy arrays will work with a library like scipy will depend on whether the user (or someone else) has already imported jax.experimental.array_api or not.

Contributor

jakevdp commented Feb 5, 2024

The only other patching currently is adding the to_device method, which is a bit of a hack for the time being until we can finish the device method deprecation and replace it with a device property.

> Otherwise, whether jax.numpy arrays will work with a library like scipy will depend on whether the user (or someone else) has already imported jax.experimental.array_api or not.

This is the intent right now: JAX is not yet ready to non-experimentally support the array API. But in the future when we've gotten to the point that we can advertise non-experimental support, the import will no longer be necessary.

Member

asmeurer commented Feb 5, 2024

array-api-compat also has a to_device helper (JAX isn't the only library without these methods). In practice people using the array API today are doing it through array-api-compat, so if we just add JAX support for those helper methods here, it will be good enough for most array API users.

> This is the intent right now: JAX is not yet ready to non-experimentally support the array API. But in the future when we've gotten to the point that we can advertise non-experimental support, the import will no longer be necessary.

So should we go ahead and fix the array-api-compat device and to_device helpers now to support JAX, or wait for the support to be fleshed out more?
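
For illustration, a rough sketch of what such helpers can look like. This is not the actual array-api-compat implementation; `device`, `to_device`, `StandardArray`, `LegacyArray`, and `with_device` below are hypothetical stand-ins that only show the special-casing idea: a callable `device` attribute is treated as the legacy method form, and a plain attribute as the standard property.

```python
class StandardArray:
    """Hypothetical array exposing `device` as a property plus a `to_device` method."""

    def __init__(self, dev="cpu"):
        self._dev = dev

    @property
    def device(self):
        return self._dev

    def to_device(self, dev):
        return StandardArray(dev)


class LegacyArray:
    """Hypothetical array with a legacy `device()` method and no `to_device`."""

    def __init__(self, dev="cpu"):
        self._dev = dev

    def device(self):
        return self._dev

    def with_device(self, dev):  # hypothetical library-specific transfer call
        return LegacyArray(dev)


def device(x):
    """Return the device of `x`, whether `device` is a property or a method."""
    dev = x.device
    return dev() if callable(dev) else dev


def to_device(x, dev):
    """Move `x` to `dev`, preferring the standard method when it exists."""
    if hasattr(x, "to_device"):
        return x.to_device(dev)
    if isinstance(x, LegacyArray):  # special case for the non-compliant library
        return x.with_device(dev)
    raise TypeError(f"unsupported array type: {type(x).__name__}")


print(device(StandardArray("gpu")), device(LegacyArray("cpu")))
print(device(to_device(LegacyArray("cpu"), "gpu")))
```

Callers only ever use the two free functions, so the special case can be dropped once the library exposes the standard property and method itself.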

Contributor

jakevdp commented Feb 5, 2024

Interesting - thanks. Realistically it will be 1-2 months more before we can do a JAX release with compatible device and to_device support, so we could support it via array-api-compat until then. What do you think? I can put together a PR if you'd like.

@lucascolley
Contributor

IMO it would make sense to just wait the 1-2 months to avoid extra work here. Support from consumer libraries is still quite sparse and probably won't move forward a huge amount in the next 2 months (and being able to test with JAX a month or so earlier probably won't accelerate things).

Member

asmeurer commented Feb 6, 2024

Either way. It really isn't hard to add it here.

ntessore commented Feb 6, 2024

As a consumer library maintainer, I am accessing the Array API through array_api_compat.array_namespace() for the foreseeable future anyway. As far as I understand, once JAX supports __array_namespace__ natively, all it takes is to remove the special casing here, right? Having JAX access early, and being able to potentially feed issues back, would be nice.
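
The consumer-side pattern described here might look like the following sketch. To keep it runnable without any array library installed, `array_namespace` below is a simplified, hypothetical stand-in for `array_api_compat.array_namespace`, and `ToyArray`, `toy_xp`, and `total` are hypothetical as well.

```python
import types


class ToyArray:
    """Hypothetical array type that already implements __array_namespace__."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, *, api_version=None):
        return toy_xp


# Hypothetical namespace exposing a single Array-API-style function.
toy_xp = types.SimpleNamespace(
    sum=lambda x: sum(x.values),
)


def array_namespace(*arrays):
    """Simplified stand-in: ask each array for its namespace and require agreement."""
    namespaces = {id(ns): ns for ns in (a.__array_namespace__() for a in arrays)}
    if len(namespaces) != 1:
        raise TypeError("expected arrays from exactly one namespace")
    return next(iter(namespaces.values()))


def total(x):
    """Consumer-library function written only against the namespace."""
    xp = array_namespace(x)
    return xp.sum(x)


print(total(ToyArray([1, 2, 3])))
```

Because `total` never names a specific array library, any special casing lives in the helper and can be removed there without touching consumer code.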

Member

asmeurer commented Feb 6, 2024

OK, I've added basic JAX support to the helper functions at #84.

Author

imh commented Feb 10, 2024

Thank you!!
