From ca2d1584f8707e876d331350b77e7c7941153826 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 19 Oct 2024 15:48:05 -0700 Subject: [PATCH] Remove `mesh_utils.create_device_mesh` from docs PiperOrigin-RevId: 687695419 --- docs/jep/14273-shard-map.md | 4 +--- ...arrays_and_automatic_parallelization.ipynb | 24 +++++++------------ ...ed_arrays_and_automatic_parallelization.md | 24 +++++++------------ docs/notebooks/shard_map.ipynb | 17 ++++--------- docs/notebooks/shard_map.md | 17 ++++--------- jax/_src/mesh.py | 2 +- 6 files changed, 30 insertions(+), 58 deletions(-) diff --git a/docs/jep/14273-shard-map.md b/docs/jep/14273-shard-map.md index 4e52d21034f7..8e66a675a522 100644 --- a/docs/jep/14273-shard-map.md +++ b/docs/jep/14273-shard-map.md @@ -66,11 +66,9 @@ import numpy as np import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P -from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map -devices = mesh_utils.create_device_mesh((4, 2)) -mesh = Mesh(devices, axis_names=('i', 'j')) +mesh = jax.make_mesh((4, 2), ('i', 'j')) a = jnp.arange( 8 * 16.).reshape(8, 16) b = jnp.arange(16 * 32.).reshape(16, 32) diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 32d332d9ac7e..b1d12abdd251 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -85,8 +85,7 @@ }, "outputs": [], "source": [ - "from jax.experimental import mesh_utils\n", - "from jax.sharding import Mesh, PartitionSpec as P, NamedSharding" + "from jax.sharding import PartitionSpec as P, NamedSharding" ] }, { @@ -98,8 +97,7 @@ "outputs": [], "source": [ "# Create a Sharding object to distribute a value across devices:\n", - "mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),\n", - " axis_names=('x', 'y'))" + "mesh = 
jax.make_mesh((4, 2), ('x', 'y'))" ] }, { @@ -372,7 +370,7 @@ "source": [ "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n", "\n", - "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:" + "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we create a `Mesh` of `Devices` using `jax.make_mesh`, which takes hardware topology into account for the `Device` order:" ] }, { @@ -419,12 +417,10 @@ ], "source": [ "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n", - "from jax.experimental import mesh_utils\n", "\n", "P = PartitionSpec\n", "\n", - "devices = mesh_utils.create_device_mesh((4, 2))\n", - "mesh = Mesh(devices, axis_names=('a', 'b'))\n", + "mesh = jax.make_mesh((4, 2), ('a', 'b'))\n", "y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n", "jax.debug.visualize_array_sharding(y)" ] @@ -446,8 +442,7 @@ }, "outputs": [], "source": [ - "devices = mesh_utils.create_device_mesh((4, 2))\n", - "default_mesh = Mesh(devices, axis_names=('a', 'b'))\n", + "default_mesh = jax.make_mesh((4, 2), ('a', 'b'))\n", "\n", "def mesh_sharding(\n", " pspec: PartitionSpec, mesh: Optional[Mesh] = None,\n", @@ -836,8 +831,7 @@ }, "outputs": [], "source": [ - "devices = mesh_utils.create_device_mesh((4, 2))\n", - "mesh = Mesh(devices, axis_names=('a', 'b'))" + "mesh = jax.make_mesh((4, 2), ('a', 'b'))" ] }, { @@ -1449,7 +1443,7 @@ }, "outputs": [], "source": [ - "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))" + "mesh = jax.make_mesh((4, 2), ('x', 'y'))" ] }, { @@ -1785,7 +1779,7 @@ }, "outputs": [], "source": [ - "mesh = 
Mesh(mesh_utils.create_device_mesh((8,)), 'batch')" + "mesh = jax.make_mesh((8,), ('batch',))" ] }, { @@ -1943,7 +1937,7 @@ }, "outputs": [], "source": [ - "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))" + "mesh = jax.make_mesh((4, 2), ('batch', 'model'))" ] }, { diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index 2142db9866ae..18da0cc78715 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -61,16 +61,14 @@ First, we'll create a `jax.Array` sharded across multiple devices: ```{code-cell} :id: Gf2lO4ii3vGG -from jax.experimental import mesh_utils -from jax.sharding import Mesh, PartitionSpec as P, NamedSharding +from jax.sharding import PartitionSpec as P, NamedSharding ``` ```{code-cell} :id: q-XBTEoy3vGG # Create a Sharding object to distribute a value across devices: -mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)), - axis_names=('x', 'y')) +mesh = jax.make_mesh((4, 2), ('x', 'y')) ``` ```{code-cell} @@ -173,7 +171,7 @@ jax.debug.visualize_array_sharding(x) Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring! -But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order: +But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. 
First, we create a `Mesh` of `Devices` using `jax.make_mesh`, which takes hardware topology into account for the `Device` order: ```{code-cell} --- @@ -184,12 +182,10 @@ id: zpB1JxyK3vGN outputId: 8e385462-1c2c-4256-c38a-84299d3bd02c --- from jax.sharding import Mesh, PartitionSpec, NamedSharding -from jax.experimental import mesh_utils P = PartitionSpec -devices = mesh_utils.create_device_mesh((4, 2)) -mesh = Mesh(devices, axis_names=('a', 'b')) +mesh = jax.make_mesh((4, 2), ('a', 'b')) y = jax.device_put(x, NamedSharding(mesh, P('a', 'b'))) jax.debug.visualize_array_sharding(y) ``` @@ -201,8 +197,7 @@ We can define a helper function to make things simpler: ```{code-cell} :id: 8g0Md2Gd3vGO -devices = mesh_utils.create_device_mesh((4, 2)) -default_mesh = Mesh(devices, axis_names=('a', 'b')) +default_mesh = jax.make_mesh((4, 2), ('a', 'b')) def mesh_sharding( pspec: PartitionSpec, mesh: Optional[Mesh] = None, @@ -318,8 +313,7 @@ For example, the simplest computation is an elementwise one: ```{code-cell} :id: _EmQwggc3vGQ -devices = mesh_utils.create_device_mesh((4, 2)) -mesh = Mesh(devices, axis_names=('a', 'b')) +mesh = jax.make_mesh((4, 2), ('a', 'b')) ``` ```{code-cell} @@ -522,7 +516,7 @@ While the compiler will attempt to decide how a function's intermediate values a ```{code-cell} :id: jniSFm5V3vGT -mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y')) +mesh = jax.make_mesh((4, 2), ('x', 'y')) ``` ```{code-cell} @@ -657,7 +651,7 @@ params, batch = init_model(jax.random.key(0), layer_sizes, batch_size) ```{code-cell} :id: mJLqRPpSDX0i -mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch') +mesh = jax.make_mesh((8,), ('batch',)) ``` ```{code-cell} @@ -735,7 +729,7 @@ outputId: d66767b7-3f17-482f-b811-919bb1793277 ```{code-cell} :id: k1hxOfgRDwo0 -mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model')) +mesh = jax.make_mesh((4, 2), ('batch', 'model')) ``` ```{code-cell} diff --git a/docs/notebooks/shard_map.ipynb 
b/docs/notebooks/shard_map.ipynb index aa355f471a20..37c27ce2728a 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -56,7 +56,6 @@ "import jax.numpy as jnp\n", "\n", "from jax.sharding import Mesh, PartitionSpec as P\n", - "from jax.experimental import mesh_utils\n", "from jax.experimental.shard_map import shard_map" ] }, @@ -67,8 +66,7 @@ "metadata": {}, "outputs": [], "source": [ - "devices = mesh_utils.create_device_mesh((4, 2))\n", - "mesh = Mesh(devices, axis_names=('x', 'y'))\n", + "mesh = jax.make_mesh((4, 2), ('x', 'y'))\n", "\n", "a = jnp.arange( 8 * 16.).reshape(8, 16)\n", "b = jnp.arange(16 * 4.).reshape(16, 4)\n", @@ -296,8 +294,7 @@ "metadata": {}, "outputs": [], "source": [ - "devices = mesh_utils.create_device_mesh((4, 2))\n", - "mesh = Mesh(devices, ('i', 'j'))\n", + "mesh = jax.make_mesh((4, 2), ('i', 'j'))\n", "\n", "@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j'))\n", "def f1(x_block):\n", @@ -1549,12 +1546,10 @@ "\n", "from jax.sharding import NamedSharding, Mesh, PartitionSpec as P\n", "from jax.experimental.shard_map import shard_map\n", - "from jax.experimental import mesh_utils\n", "\n", - "devices = mesh_utils.create_device_mesh((8,))\n", + "mesh = jax.make_mesh((8,), ('batch',))\n", "\n", "# replicate initial params on all devices, shard data batch over devices\n", - "mesh = Mesh(devices, ('batch',))\n", "batch = jax.device_put(batch, NamedSharding(mesh, P('batch')))\n", "params = jax.device_put(params, NamedSharding(mesh, P()))\n", "\n", @@ -1723,8 +1718,7 @@ "metadata": {}, "outputs": [], "source": [ - "devices = mesh_utils.create_device_mesh((8,))\n", - "mesh = Mesh(devices, ('feats',))\n", + "mesh = jax.make_mesh((8,), ('feats',))\n", "\n", "batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats')))\n", "params = jax.device_put(params, NamedSharding(mesh, P('feats')))\n", @@ -1766,8 +1760,7 @@ "metadata": {}, "outputs": [], "source": [ - "devices = 
mesh_utils.create_device_mesh((4, 2))\n", - "mesh = Mesh(devices, ('batch', 'feats'))\n", + "mesh = jax.make_mesh((4, 2), ('batch', 'feats'))\n", "\n", "batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats')))\n", "params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats'))))\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index d77dec652068..47b11079e27d 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -46,13 +46,11 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec as P -from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map ``` ```{code-cell} -devices = mesh_utils.create_device_mesh((4, 2)) -mesh = Mesh(devices, axis_names=('x', 'y')) +mesh = jax.make_mesh((4, 2), ('x', 'y')) a = jnp.arange( 8 * 16.).reshape(8, 16) b = jnp.arange(16 * 4.).reshape(16, 4) @@ -196,8 +194,7 @@ input array axis size.) If an input's pspec does not mention a mesh axis name, then there's no splitting over that mesh axis. For example: ```{code-cell} -devices = mesh_utils.create_device_mesh((4, 2)) -mesh = Mesh(devices, ('i', 'j')) +mesh = jax.make_mesh((4, 2), ('i', 'j')) @partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', 'j')) def f1(x_block): @@ -1083,12 +1080,10 @@ from functools import partial from jax.sharding import NamedSharding, Mesh, PartitionSpec as P from jax.experimental.shard_map import shard_map -from jax.experimental import mesh_utils -devices = mesh_utils.create_device_mesh((8,)) +mesh = jax.make_mesh((8,), ('batch',)) # replicate initial params on all devices, shard data batch over devices -mesh = Mesh(devices, ('batch',)) batch = jax.device_put(batch, NamedSharding(mesh, P('batch'))) params = jax.device_put(params, NamedSharding(mesh, P())) @@ -1203,8 +1198,7 @@ multiplications followed by a `psum_scatter` to sum the local results and efficiently scatter the result's shards. 
```{code-cell} -devices = mesh_utils.create_device_mesh((8,)) -mesh = Mesh(devices, ('feats',)) +mesh = jax.make_mesh((8,), ('feats',)) batch = jax.device_put(batch, NamedSharding(mesh, P(None, 'feats'))) params = jax.device_put(params, NamedSharding(mesh, P('feats'))) @@ -1234,8 +1228,7 @@ def loss_tp(params, batch): We can compose these strategies together, using multiple axes of parallelism. ```{code-cell} -devices = mesh_utils.create_device_mesh((4, 2)) -mesh = Mesh(devices, ('batch', 'feats')) +mesh = jax.make_mesh((4, 2), ('batch', 'feats')) batch_ = jax.device_put(batch, NamedSharding(mesh, P('batch', 'feats'))) params_ = jax.device_put(params, NamedSharding(mesh, P(('batch', 'feats')))) diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index a3d529fcffbe..8cb508378129 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -165,7 +165,7 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], if isinstance(axis_names, str): axis_names = (axis_names,) axis_names = tuple(axis_names) - if not all(i is not None for i in axis_names): + if any(i is None for i in axis_names): raise ValueError(f"Mesh axis names cannot be None. Got: {axis_names}") if devices.ndim != len(axis_names):