Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array API] add device property & to_device method #22597

Merged
merged 1 commit into from
Jul 23, 2024

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Jul 23, 2024

Part of #21088 and #18353

As discussed earlier, arr.device will return the device for single-device arrays, and the sharding for non-trivially sharded arrays. This lets arr.device be passed to device_put, as well as to other Array API arguments that accept Device | Sharding.

jax/_src/array.py Show resolved Hide resolved
jax/_src/earray.py Show resolved Hide resolved
jax/_src/numpy/array_methods.py Outdated Show resolved Hide resolved
copybara-service bot pushed a commit that referenced this pull request Jul 23, 2024
+ enable more jax/lax tests for XLA CPU thunks

FUTURE_COPYBARA_INTEGRATE_REVIEW=#22597 from jakevdp:arr-device 613a000
PiperOrigin-RevId: 654865806
@copybara-service copybara-service bot merged commit dc42ba0 into jax-ml:main Jul 23, 2024
12 of 15 checks passed
@jakevdp jakevdp deleted the arr-device branch July 23, 2024 18:28
@jakevdp jakevdp mentioned this pull request Jul 29, 2024
38 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants