Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for device kwarg in astype, and add matching utility func #21086

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented May 6, 2024

Towards #20200

This PR adds a private device-placement utility jax._src.numpy.util._place_array to manage array API compliant array placement behavior for use in the jax.numpy namespace. Copies are mediated by the lax._array_copy utility, while device transfer is performed via api.device_put.

cc: @jakevdp

jax/_src/numpy/util.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
jax/_src/numpy/util.py Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the astype-device branch 3 times, most recently from 593f01c to 5445230 Compare May 9, 2024 22:27
@Micky774
Copy link
Collaborator Author

@yashk2810 Sorry for the delay in getting this message out. I've updated the PR with simplifications to the _place_array utility, as you suggested. Originally I intended to maximize strict compliance with array API semantics wrt setting copy=False while still specifying device, hence the attempts at detecting when a transfer is not required in the first place (in hindsight, the checks weren't quite right, since I'm still getting familiar with the sharding system).

Is there already a mechanism to predict when two shards are equivalent in the sense that calling device_put(x, shard) would be a no-op if x has sharding equivalent to shard? If so, or if you think something like that is feasible, then it may be worth including here.

cc: @jakevdp if you have any insights regarding how we want to handle copy=False, device=... behavior here

@jakevdp
Copy link
Collaborator

jakevdp commented May 10, 2024

The easiest thing would just be to error if copy=False and device is specified.

@Micky774 Micky774 force-pushed the astype-device branch 2 times, most recently from f0f52a5 to 60eaabc Compare May 10, 2024 17:58
@yashk2810
Copy link
Collaborator

Is there already a mechanism to predict when two shards are equivalent in the sense that calling device_put(x, shard) would be a no-op if x has sharding equivalent to shard?

sorry, I don't understand what this means. Also, why do you need such a mechanism?

@Micky774
Copy link
Collaborator Author

sorry, I don't understand what this means. Also, why do you need such a mechanism?

I was looking for a way to tell if x already follows whatever sharding/layout device would specify. That way we could short-cut to a no-op and avoid raising an error in situations like _place_array(x, device=x.sharding, copy=False).

@yashk2810
Copy link
Collaborator

I was looking for a way to tell if x already follows whatever sharding/layout device would specify

device_put should do that for you :)

@jakevdp
Copy link
Collaborator

jakevdp commented May 10, 2024

Yash - the issue is that the semantics of copy=False are "error if a copy is required", so device_put handling that transparently doesn't help because we need to know what's happening in order to know whether to error.

@yashk2810
Copy link
Collaborator

Ok, then why not error instead of doing the no-op logic here? If you want to transfer or have it be a no-op should be device_put's job. If device is not None and copy=True, then we should error since it makes no sense right?

@jakevdp
Copy link
Collaborator

jakevdp commented May 10, 2024

That's what I suggested above: the most conservative thing would be to error if copy and device are used together. But the more "correct" thing would be to somehow introspect whether or not device_put forces a copy in any particular case.

@yashk2810
Copy link
Collaborator

We can abstract away a function which can determine that which we can share here but device_put is complex so let's error for now and file a bug against me to give you such a function which you can call here.

@@ -37,6 +37,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.28 (May 9, 2024)

* New Functionality
* {func}`jax.numpy.astype` supports a new `device` keyword argument.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move up to 0.4.29

Note that JAX may not always strictly adhere to array API device semantics when
using ``jax.jit``. In particular, specifying the ``device`` argument is
equivalent to calling ``jax.device_put(x, device)``. For up-to-date details on
device placement, see the documentation of ``jax.device_put`` for more details.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually references go at the bottom of the docstring. You can put this paragraph above the .. _Python... line


This utility uses `jax.device_put` for device placement.
"""
out = x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why out=x?

)
out = api.device_put(out, device)

# TODO(micky774): Avoid copy if data has already been copied via device
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This todo doesn't make sense? In this branch the device is None. Remove the todo?

device = jax.devices("cpu")[0]
expected_sharding = SingleDeviceSharding(device)
else:
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use (2, 2) mesh to coverage across more hardware?

expected_sharding = SingleDeviceSharding(device)
else:
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
device = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to try other PartitionSpecs too (can you parameterize the test on that too)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants