Remove mesh_utils.create_device_mesh from docs #24409

Merged: 1 commit, Oct 19, 2024
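The pattern applied across the docs replaces the two-step construction (build a device array with `mesh_utils.create_device_mesh`, then wrap it in `Mesh`) with a single `jax.make_mesh` call. A minimal before/after sketch of that pattern, not taken from this diff; the (4, 2) shape and axis names are illustrative, and it assumes 8 visible devices (for example via `XLA_FLAGS=--xla_force_host_platform_device_count=8` on CPU):

```python
import jax
from jax.sharding import Mesh
from jax.experimental import mesh_utils

# Assumes 8 visible devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8.

# Old pattern being removed from the docs: ndarray of devices, then a Mesh.
devices = mesh_utils.create_device_mesh((4, 2))
mesh_old = Mesh(devices, axis_names=('x', 'y'))

# New pattern: jax.make_mesh builds the Mesh directly and, per the updated
# docs text, still takes hardware topology into account for device order.
mesh_new = jax.make_mesh((4, 2), ('x', 'y'))

assert mesh_old.axis_names == mesh_new.axis_names == ('x', 'y')
assert mesh_old.devices.shape == mesh_new.devices.shape == (4, 2)
```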
4 changes: 1 addition & 3 deletions docs/jep/14273-shard-map.md
@@ -66,11 +66,9 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
mesh = jax.make_mesh((4, 2), ('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)
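For context, a self-contained version of the JEP's opening example using the new mesh construction; the diff above shows only the top of that example, so this is a sketch: the `matmul_basic` body is the JEP's, reproduced from memory rather than from this diff, and it assumes 8 visible devices (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=8` on CPU):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Assumes 8 visible devices so the (4, 2) mesh can be built.
mesh = jax.make_mesh((4, 2), ('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # Per-device blocks: a_block is (2, 8), b_block is (8, 32).
  c_partialsum = jnp.dot(a_block, b_block)
  c_block = jax.lax.psum(c_partialsum, 'j')  # sum partial products over 'j'
  return c_block

c = matmul_basic(a, b)  # shape (8, 32)
```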
docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
@@ -85,8 +85,7 @@
},
"outputs": [],
"source": [
"from jax.experimental import mesh_utils\n",
"from jax.sharding import Mesh, PartitionSpec as P, NamedSharding"
"from jax.sharding import PartitionSpec as P, NamedSharding"
]
},
{
@@ -98,8 +97,7 @@
"outputs": [],
"source": [
"# Create a Sharding object to distribute a value across devices:\n",
"mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),\n",
" axis_names=('x', 'y'))"
"mesh = jax.make_mesh((4, 2), ('x', 'y'))"
]
},
{
@@ -372,7 +370,7 @@
"source": [
"Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n",
"\n",
"But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:"
"But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `jax.make_mesh`, which takes hardware topology into account for the `Device` order:"
]
},
{
@@ -419,12 +417,10 @@
],
"source": [
"from jax.sharding import Mesh, PartitionSpec, NamedSharding\n",
"from jax.experimental import mesh_utils\n",
"\n",
"P = PartitionSpec\n",
"\n",
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, axis_names=('a', 'b'))\n",
"mesh = jax.make_mesh((4, 2), ('a', 'b'))\n",
"y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n",
"jax.debug.visualize_array_sharding(y)"
]
@@ -446,8 +442,7 @@
},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"default_mesh = Mesh(devices, axis_names=('a', 'b'))\n",
"default_mesh = jax.make_mesh((4, 2), ('a', 'b'))\n",
"\n",
"def mesh_sharding(\n",
" pspec: PartitionSpec, mesh: Optional[Mesh] = None,\n",
@@ -836,8 +831,7 @@
},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, axis_names=('a', 'b'))"
"mesh = jax.make_mesh((4, 2), ('a', 'b'))"
]
},
{
@@ -1449,7 +1443,7 @@
},
"outputs": [],
"source": [
"mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))"
"mesh = jax.make_mesh((4, 2), ('x', 'y'))"
]
},
{
@@ -1785,7 +1779,7 @@
},
"outputs": [],
"source": [
"mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')"
"mesh = jax.make_mesh((8,), ('batch',))"
]
},
{
@@ -1943,7 +1937,7 @@
},
"outputs": [],
"source": [
"mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))"
"mesh = jax.make_mesh((4, 2), ('batch', 'model'))"
]
},
{
24 changes: 9 additions & 15 deletions docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
@@ -61,16 +61,14 @@ First, we'll create a `jax.Array` sharded across multiple devices:
```{code-cell}
:id: Gf2lO4ii3vGG

from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.sharding import PartitionSpec as P, NamedSharding
```

```{code-cell}
:id: q-XBTEoy3vGG

# Create a Sharding object to distribute a value across devices:
mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),
axis_names=('x', 'y'))
mesh = jax.make_mesh((4, 2), ('x', 'y'))
```

```{code-cell}
@@ -173,7 +171,7 @@ jax.debug.visualize_array_sharding(x)

Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!

But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:
But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `jax.make_mesh`, which takes hardware topology into account for the `Device` order:

```{code-cell}
---
@@ -184,12 +182,10 @@ id: zpB1JxyK3vGN
outputId: 8e385462-1c2c-4256-c38a-84299d3bd02c
---
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

P = PartitionSpec

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
mesh = jax.make_mesh((4, 2), ('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
```
@@ -201,8 +197,7 @@ We can define a helper function to make things simpler:
```{code-cell}
:id: 8g0Md2Gd3vGO

devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))
default_mesh = jax.make_mesh((4, 2), ('a', 'b'))

def mesh_sharding(
pspec: PartitionSpec, mesh: Optional[Mesh] = None,
@@ -318,8 +313,7 @@ For example, the simplest computation is an elementwise one:
```{code-cell}
:id: _EmQwggc3vGQ

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
mesh = jax.make_mesh((4, 2), ('a', 'b'))
```

```{code-cell}
@@ -522,7 +516,7 @@ While the compiler will attempt to decide how a function's intermediate values a
```{code-cell}
:id: jniSFm5V3vGT

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))
mesh = jax.make_mesh((4, 2), ('x', 'y'))
```

```{code-cell}
@@ -657,7 +651,7 @@ params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
```{code-cell}
:id: mJLqRPpSDX0i

mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')
mesh = jax.make_mesh((8,), ('batch',))
```

```{code-cell}
@@ -735,7 +729,7 @@ outputId: d66767b7-3f17-482f-b811-919bb1793277
```{code-cell}
:id: k1hxOfgRDwo0

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
```

```{code-cell}
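The hunk above around old line 522 sits in the notebook's discussion of how intermediate values get sharded (likely the `jax.lax.with_sharding_constraint` material). As an aside, an illustrative sketch of constraining an intermediate on a mesh built with `jax.make_mesh`; this code is not part of the PR and assumes 8 visible devices:

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Assumes 8 visible devices; not part of this PR's diff.
mesh = jax.make_mesh((4, 2), ('x', 'y'))

@jax.jit
def f(x):
  y = jnp.sin(x) * 2.
  # Hint to the compiler: lay this intermediate out over the ('x', 'y') axes.
  return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P('x', 'y')))

x = jnp.arange(8 * 16.).reshape(8, 16)
print(f(x).sharding)  # expected: a NamedSharding with spec ('x', 'y')
```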
17 changes: 5 additions & 12 deletions docs/notebooks/shard_map.ipynb
@@ -56,7 +56,6 @@
"import jax.numpy as jnp\n",
"\n",
"from jax.sharding import Mesh, PartitionSpec as P\n",
"from jax.experimental import mesh_utils\n",
"from jax.experimental.shard_map import shard_map"
]
},
@@ -67,8 +66,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh((4, 2), ('x', 'y'))\n",
"\n",
"a = jnp.arange( 8 * 16.).reshape(8, 16)\n",
"b = jnp.arange(16 * 4.).reshape(16, 4)\n",
@@ -296,8 +294,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, ('i', 'j'))\n",
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
"\n",
"@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n",
"def f1(x_block):\n",
@@ -1549,12 +1546,10 @@
"\n",
"from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n",
"from jax.experimental.shard_map import shard_map\n",
"from jax.experimental import mesh_utils\n",
"\n",
"devices = mesh_utils.create_device_mesh((8,))\n",
"mesh = jax.make_mesh((8,), ('batch',))\n",
"\n",
"# replicate initial params on all devices, shard data batch over devices\n",
"mesh = Mesh(devices, ('batch',))\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P()))\n",
"\n",
@@ -1723,8 +1718,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((8,))\n",
"mesh = Mesh(devices, ('feats',))\n",
"mesh = jax.make_mesh((8,), ('feats',))\n",
"\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n",
@@ -1766,8 +1760,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, ('batch', 'feats'))\n",
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n",
"\n",
"batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n",
"params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n",
17 changes: 5 additions & 12 deletions docs/notebooks/shard_map.md
@@ -46,13 +46,11 @@ import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
```

```{code-cell}
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))
mesh = jax.make_mesh((4, 2), ('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
@@ -196,8 +194,7 @@ input array axis size.) If an input's pspec does not mention a mesh axis name,
then there's no splitting over that mesh axis. For example:

```{code-cell}
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))
mesh = jax.make_mesh((4, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
@@ -1083,12 +1080,10 @@ from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((8,))
mesh = jax.make_mesh((8,), ('batch',))

# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

@@ -1203,8 +1198,7 @@ multiplications followed by a `psum_scatter` to sum the local results and
efficiently scatter the result's shards.

```{code-cell}
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))
mesh = jax.make_mesh((8,), ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
Expand Down Expand Up @@ -1234,8 +1228,7 @@ def loss_tp(params, batch):
We can compose these strategies together, using multiple axes of parallelism.

```{code-cell}
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))
mesh = jax.make_mesh((4, 2), ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
2 changes: 1 addition & 1 deletion jax/_src/mesh.py
@@ -165,7 +165,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
if isinstance(axis_names, str):
axis_names = (axis_names,)
axis_names = tuple(axis_names)
if not all(i is not None for i in axis_names):
if any(i is None for i in axis_names):
raise ValueError(f"Mesh axis names cannot be None. Got: {axis_names}")

if devices.ndim != len(axis_names):
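The `jax/_src/mesh.py` hunk is a readability tweak: `any(i is None for i in axis_names)` is the De Morgan equivalent of `not all(i is not None for i in axis_names)`, so the `ValueError` fires under exactly the same conditions. A quick standalone check, independent of JAX:

```python
# The two predicates from the hunk above always agree.
for axis_names in [('x', 'y'), ('x', None), (None,), ()]:
  old = not all(i is not None for i in axis_names)
  new = any(i is None for i in axis_names)
  assert old == new
```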