Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for max_version, dl_device, copy kwargs in __dlpack__ to match Array API #20198

Merged
merged 1 commit into from
Apr 11, 2024

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented Mar 12, 2024

Towards #20200

cf. data-apis/array-api#741, data-apis/array-api#602, https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html

Note

"In principle, arbitrary cross-device copies could be allowed too, but the consensus in data-apis/array-api#626 was that limiting to device-to-host copies is enough for now". This PR includes the optional device-to-device transfer.

Default behavior is preserved when max_version=None, dl_device=None, copy=None,

This PR also adds new versioning information to jax/_src/dlpack.py:

DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)

jax/_src/array.py Outdated Show resolved Hide resolved
jax/_src/dlpack.py Outdated Show resolved Hide resolved
jax/_src/dlpack.py Outdated Show resolved Hide resolved
@jakevdp jakevdp self-assigned this Mar 12, 2024
@Micky774 Micky774 force-pushed the array_api_dlpack branch 3 times, most recently from 35543cf to b4b1164 Compare March 12, 2024 18:05
jax/_src/array.py Outdated Show resolved Hide resolved
jax/_src/dlpack.py Outdated Show resolved Hide resolved
jax/_src/array.py Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented Mar 12, 2024

Minor comment, but in general it's useful to have meaningful commit messages ("Update" doesn't communicate much about what the change includes).

@Micky774
Copy link
Collaborator Author

Micky774 commented Mar 25, 2024

@yashk2810 @jakevdp This PR should be ready for review again. I've added tests and updated the documentation of to_dlpack.

@Micky774 Micky774 force-pushed the array_api_dlpack branch 2 times, most recently from 4a85f4a to 65be6b3 Compare March 26, 2024 16:38
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 11, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 11, 2024

Tests are failing because jax2tf uses the take_ownership parameter removed in this PR: https://github.com/google/jax/blob/abfbb0ae2b85a6c887caa5d11609b347ee6b8f2d/jax/experimental/jax2tf/call_tf.py#L337

(jax2tf tests are not run in github CI because they are too expensive).
I think a sufficient fix will be removing that parameter from the call site, since it's deprecated and currently does nothing.

@Micky774
Copy link
Collaborator Author

@jakevdp Done -- I've also simplified the tests a bit and refactored DLDeviceType to jax._src.typing

@copybara-service copybara-service bot merged commit d9d11a3 into jax-ml:main Apr 11, 2024
13 checks passed
@Micky774 Micky774 deleted the array_api_dlpack branch April 11, 2024 19:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants