Remove mesh_utils.create_device_mesh from docs
PiperOrigin-RevId: 687679312
yashk2810 authored and Google-ML-Automation committed Oct 19, 2024
1 parent 77fb1ee commit d01b1e9
Showing 6 changed files with 31 additions and 59 deletions.
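For reference, the pattern removed throughout these docs is the two-step `mesh_utils.create_device_mesh` + `Mesh` construction; `jax.make_mesh` does both in one call. A minimal before/after sketch (illustrative only, not part of the diff; it assumes a runtime with exactly 8 devices, e.g. a 4x2 slice):

```python
import jax
from jax.sharding import Mesh
from jax.experimental import mesh_utils

# Old pattern being removed from the docs: build a topology-aware ndarray of
# devices, then wrap it in a Mesh with named axes.
devices = mesh_utils.create_device_mesh((4, 2))
old_mesh = Mesh(devices, axis_names=('x', 'y'))

# New pattern: jax.make_mesh performs both steps in a single call.
new_mesh = jax.make_mesh((4, 2), ('x', 'y'))

assert old_mesh.axis_names == new_mesh.axis_names == ('x', 'y')
```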
4 changes: 1 addition & 3 deletions docs/jep/14273-shard-map.md
@@ -66,11 +66,9 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))
mesh = jax.make_mesh((4, 2), ('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)
24 changes: 9 additions & 15 deletions docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
@@ -85,8 +85,7 @@
},
"outputs": [],
"source": [
"from jax.experimental import mesh_utils\n",
"from jax.sharding import Mesh, PartitionSpec as P, NamedSharding"
"from jax.sharding import PartitionSpec as P, NamedSharding"
]
},
{
@@ -98,8 +97,7 @@
"outputs": [],
"source": [
"# Create a Sharding object to distribute a value across devices:\n",
"mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),\n",
" axis_names=('x', 'y'))"
"mesh = jax.make_mesh((4, 2), ('x', 'y'))"
]
},
{
@@ -372,7 +370,7 @@
"source": [
"Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n",
"\n",
"But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:"
"But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `jax.make_mesh`, which takes hardware topology into account for the `Device` order:"
]
},
{
@@ -419,12 +417,10 @@
],
"source": [
"from jax.sharding import Mesh, PartitionSpec, NamedSharding\n",
"from jax.experimental import mesh_utils\n",
"\n",
"P = PartitionSpec\n",
"\n",
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, axis_names=('a', 'b'))\n",
"mesh = jax.make_mesh((4, 2), ('a', 'b'))\n",
"y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n",
"jax.debug.visualize_array_sharding(y)"
]
@@ -446,8 +442,7 @@
},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"default_mesh = Mesh(devices, axis_names=('a', 'b'))\n",
"default_mesh = jax.make_mesh((4, 2), ('a', 'b'))\n",
"\n",
"def mesh_sharding(\n",
" pspec: PartitionSpec, mesh: Optional[Mesh] = None,\n",
@@ -836,8 +831,7 @@
},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, axis_names=('a', 'b'))"
"mesh = jax.make_mesh((4, 2), ('a', 'b'))"
]
},
{
@@ -1449,7 +1443,7 @@
},
"outputs": [],
"source": [
"mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))"
"mesh = jax.make_mesh((4, 2), ('x', 'y'))"
]
},
{
@@ -1785,7 +1779,7 @@
},
"outputs": [],
"source": [
"mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')"
"mesh = jax.make_mesh((8,), ('batch',))"
]
},
{
@@ -1943,7 +1937,7 @@
},
"outputs": [],
"source": [
"mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))"
"mesh = jax.make_mesh((4, 2), ('batch', 'model'))"
]
},
{
24 changes: 9 additions & 15 deletions docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
@@ -61,16 +61,14 @@ First, we'll create a `jax.Array` sharded across multiple devices:
```{code-cell}
:id: Gf2lO4ii3vGG
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.sharding import PartitionSpec as P, NamedSharding
```

```{code-cell}
:id: q-XBTEoy3vGG
# Create a Sharding object to distribute a value across devices:
mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),
axis_names=('x', 'y'))
mesh = jax.make_mesh((4, 2), ('x', 'y'))
```

```{code-cell}
@@ -173,7 +171,7 @@ jax.debug.visualize_array_sharding(x)

Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!

But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:
But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `jax.make_mesh`, which takes hardware topology into account for the `Device` order:

```{code-cell}
---
@@ -184,12 +182,10 @@ id: zpB1JxyK3vGN
outputId: 8e385462-1c2c-4256-c38a-84299d3bd02c
---
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
P = PartitionSpec
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
mesh = jax.make_mesh((4, 2), ('a', 'b'))
y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
jax.debug.visualize_array_sharding(y)
```
@@ -201,8 +197,7 @@ We can define a helper function to make things simpler:
```{code-cell}
:id: 8g0Md2Gd3vGO
devices = mesh_utils.create_device_mesh((4, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))
default_mesh = jax.make_mesh((4, 2), ('a', 'b'))
def mesh_sharding(
pspec: PartitionSpec, mesh: Optional[Mesh] = None,
@@ -318,8 +313,7 @@ For example, the simplest computation is an elementwise one:
```{code-cell}
:id: _EmQwggc3vGQ
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('a', 'b'))
mesh = jax.make_mesh((4, 2), ('a', 'b'))
```

```{code-cell}
@@ -522,7 +516,7 @@ While the compiler will attempt to decide how a function's intermediate values a
```{code-cell}
:id: jniSFm5V3vGT
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))
mesh = jax.make_mesh((4, 2), ('x', 'y'))
```

```{code-cell}
@@ -657,7 +651,7 @@ params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
```{code-cell}
:id: mJLqRPpSDX0i
mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')
mesh = jax.make_mesh((8,), ('batch',))
```

```{code-cell}
@@ -735,7 +729,7 @@ outputId: d66767b7-3f17-482f-b811-919bb1793277
```{code-cell}
:id: k1hxOfgRDwo0
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
```

```{code-cell}
19 changes: 6 additions & 13 deletions docs/notebooks/shard_map.ipynb
@@ -56,8 +56,7 @@
"import jax.numpy as jnp\n",
"\n",
"from jax.sharding import Mesh, PartitionSpec as P\n",
"from jax.experimental import mesh_utils\n",
"from jax.experimental.shard_map import shard_map"
] "from jax.experimental.shard_map import shard_map"
]
},
{
@@ -67,8 +66,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"mesh = jax.make_mesh((4, 2), ('x', 'y'))\n",
"\n",
"a = jnp.arange( 8 * 16.).reshape(8, 16)\n",
"b = jnp.arange(16 * 4.).reshape(16, 4)\n",
@@ -296,8 +294,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, ('i', 'j'))\n",
"mesh = jax.make_mesh((4, 2), ('i', 'j'))\n",
"\n",
"@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n",
"def f1(x_block):\n",
@@ -1549,12 +1546,10 @@
"\n",
"from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n",
"from jax.experimental.shard_map import shard_map\n",
"from jax.experimental import mesh_utils\n",
"\n",
"devices = mesh_utils.create_device_mesh((8,))\n",
"mesh = jax.make_mesh((8,), ('batch',))\n",
"\n",
"# replicate initial params on all devices, shard data batch over devices\n",
"mesh = Mesh(devices, ('batch',))\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P()))\n",
"\n",
@@ -1723,8 +1718,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((8,))\n",
"mesh = Mesh(devices, ('feats',))\n",
"mesh = jax.make_mesh((8,), ('feats',))\n",
"\n",
"batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n",
"params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n",
Expand Down Expand Up @@ -1766,8 +1760,7 @@
"metadata": {},
"outputs": [],
"source": [
"devices = mesh_utils.create_device_mesh((4, 2))\n",
"mesh = Mesh(devices, ('batch', 'feats'))\n",
"mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n",
"\n",
"batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n",
"params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n",
17 changes: 5 additions & 12 deletions docs/notebooks/shard_map.md
@@ -46,13 +46,11 @@ import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
```

```{code-cell}
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))
mesh = jax.make_mesh((4, 2), ('x', 'y'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 4.).reshape(16, 4)
@@ -196,8 +194,7 @@ input array axis size.) If an input's pspec does not mention a mesh axis name,
then there's no splitting over that mesh axis. For example:

```{code-cell}
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('i', 'j'))
mesh = jax.make_mesh((4, 2), ('i', 'j'))

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
@@ -1083,12 +1080,10 @@ from functools import partial

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((8,))
mesh = jax.make_mesh((8,), ('batch',))

# replicate initial params on all devices, shard data batch over devices
mesh = Mesh(devices, ('batch',))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))
params = jax.device_put(params, NamedSharding(mesh, P()))

@@ -1203,8 +1198,7 @@ multiplications followed by a `psum_scatter` to sum the local results and
efficiently scatter the result's shards.

```{code-cell}
devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, ('feats',))
mesh = jax.make_mesh((8,), ('feats',))

batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))
params = jax.device_put(params, NamedSharding(mesh, P('feats')))
Expand Down Expand Up @@ -1234,8 +1228,7 @@ def loss_tp(params, batch):
We can compose these strategies together, using multiple axes of parallelism.

```{code-cell}
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('batch', 'feats'))
mesh = jax.make_mesh((4, 2), ('batch', 'feats'))

batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))
params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))
2 changes: 1 addition & 1 deletion jax/_src/mesh.py
@@ -165,7 +165,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
if isinstance(axis_names, str):
axis_names = (axis_names,)
axis_names = tuple(axis_names)
if not all(i is not None for i in axis_names):
if any(i is None for i in axis_names):
raise ValueError(f"Mesh axis names cannot be None. Got: {axis_names}")

if devices.ndim != len(axis_names):
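The `jax/_src/mesh.py` change is a pure rewrite: `not all(i is not None ...)` and `any(i is None ...)` are equivalent by De Morgan's law, so a `None` axis name still raises the same `ValueError`. A quick sketch of the preserved behavior (illustrative only; it assumes 8 available devices):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# The old and new predicates agree on every input.
axis_names = ('x', None)
assert (not all(i is not None for i in axis_names)) == any(i is None for i in axis_names)

# Either way, constructing a Mesh with a None axis name is rejected.
devices = np.array(jax.devices()).reshape(4, 2)
try:
    Mesh(devices, axis_names)
except ValueError as err:
    print(err)  # Mesh axis names cannot be None. Got: ('x', None)
```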
