From 1cc5511df6c5bc6d9fb35bd94f8e5c41ef2e40c5 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Thu, 13 Nov 2025 16:09:55 -0600
Subject: [PATCH 1/3] Rename sharding_names to sharding_metadata

---
 .../nnx_toy_examples/10_fsdp_and_optimizer.py |  8 +++++---
 flax/core/meta.py                             |  4 ++--
 flax/core/spmd.py                             |  6 +++++-
 flax/linen/spmd.py                            |  4 ++--
 flax/nnx/spmd.py                              | 18 +++++++++---------
 flax/nnx/variablelib.py                       | 14 +++++++++++---
 tests/nnx/bridge/wrappers_test.py             | 10 +++++-----
 tests/nnx/nn/linear_test.py                   |  6 +++---
 tests/nnx/optimizer_test.py                   |  2 +-
 tests/nnx/spmd_test.py                        | 14 +++++++-------
 tests/nnx/transforms_test.py                  | 18 +++++++++---------
 11 files changed, 59 insertions(+), 45 deletions(-)

diff --git a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py
index b9695e01a..f68ce8c2c 100644
--- a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py
+++ b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py
@@ -14,6 +14,8 @@
 
 import dataclasses
 import os
+
+from jax._src import sharding
 os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
 
 from matplotlib import pyplot as plt
@@ -56,15 +58,15 @@ class MLP(nnx.Module):
   def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
     self.w1 = nnx.Param(
       nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
-      sharding_names=mesh_rules('embed', 'mlp'),
+      sharding_metadata=mesh_rules('embed', 'mlp'),
     )
     self.b1 = nnx.Param(
       jnp.zeros((dmid,)),
-      sharding_names=mesh_rules('mlp'),
+      sharding_metadata=mesh_rules('mlp'),
     )
     self.w2 = nnx.Param(
       nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
-      sharding_names=mesh_rules('embed', 'mlp'),
+      sharding_metadata=mesh_rules('embed', 'mlp'),
     )
 
   def __call__(self, x: jax.Array):
diff --git a/flax/core/meta.py b/flax/core/meta.py
index 98eb643f5..b8151be31 100644
--- a/flax/core/meta.py
+++ b/flax/core/meta.py
@@ -300,13 +300,13 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
 
   def to_nnx_metadata(self) -> dict[str, Any]:
     """Return a dict of metadata that can translate into an `nnx.Variable`."""
     metadata = dict(vars(self))
-    metadata['sharding_names'] = metadata.pop('names')
+    metadata['sharding_metadata'] = metadata.pop('names')
     return metadata
 
   @classmethod
   def from_nnx_metadata(cls, metadata: dict[str, Any]):
     """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
-    metadata['names'] = metadata.pop('sharding_names')
+    metadata['names'] = metadata.pop('sharding_metadata')
     fields = {x.name for x in dataclasses.fields(cls)}
     return cls(**{k: v for k, v in metadata.items() if k in fields})
diff --git a/flax/core/spmd.py b/flax/core/spmd.py
index 35c120c59..67d5bff52 100644
--- a/flax/core/spmd.py
+++ b/flax/core/spmd.py
@@ -45,6 +45,8 @@ def shard_value(value, sharding_names, sharding_rules, mesh):
       f' with annotation {sharding_names=}. '
       'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
   pspec = get_pspec(sharding_names, sharding_rules)
+  if isinstance(sharding_names, NamedSharding) and mesh is not None:
+    assert sharding_names.mesh == mesh
   if mesh is not None:
     return _apply_sharding(value, NamedSharding(mesh, pspec))
   return _apply_sharding(value, pspec)
@@ -107,8 +109,10 @@ def composite_rules(rule1, rule2):
 
 
 def from_sharding_rules(
-  sharding: Sharding, sharding_rules: LogicalRules
+  sharding, sharding_rules: LogicalRules
 ) -> Sharding:
+  if isinstance(sharding, NamedSharding):
+    sharding = sharding.spec
   rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
   return tuple(
     rules[str(s)] if (s and str(s) in rules) else s for s in sharding
diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py
index b68487b47..20597de2f 100644
--- a/flax/linen/spmd.py
+++ b/flax/linen/spmd.py
@@ -290,7 +290,7 @@ def to_nnx_metadata(self) -> dict[str, Any]:
     """Return a dict of metadata that can translate into an `nnx.Variable`."""
     metadata = vars(self)
     if 'names' in metadata:
-      metadata['sharding_names'] = metadata.pop('names')
+      metadata['sharding_metadata'] = metadata.pop('names')
     if 'rules' in metadata:
       metadata['sharding_rules'] = metadata.pop('rules')
     return metadata
@@ -298,7 +298,7 @@ def from_nnx_metadata(cls, metadata: dict[str, Any]):
     """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
-    metadata['names'] = metadata.pop('sharding_names')
+    metadata['names'] = metadata.pop('sharding_metadata')
     metadata['rules'] = metadata.pop('sharding_rules')
     fields = {x.name for x in dataclasses.fields(cls)}
     return cls(**{k: v for k, v in metadata.items() if k in fields})
diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py
index 756165af9..46becb25c 100644
--- a/flax/nnx/spmd.py
+++ b/flax/nnx/spmd.py
@@ -45,9 +45,9 @@ def insert_field(fields, index, value):
   def _add_axis(x: tp.Any):
     if isinstance(x, variablelib.Variable):
       metadata = x.get_metadata()
-      if 'sharding_names' in metadata and metadata['sharding_names']:
-        sharding = metadata['sharding_names']
-        x.set_metadata(sharding_names=insert_field(sharding, index, axis_name))
+      if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
+        sharding = metadata['sharding_metadata']
+        x.set_metadata(sharding_metadata=insert_field(sharding, index, axis_name))
 
       for k, v in other_meta.items():
         if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
@@ -74,9 +74,9 @@ def remove_field(fields, index, value):
 
   def _remove_axis(x: tp.Any):
     if isinstance(x, variablelib.Variable):
-      if hasattr(x, 'sharding_names') and x.sharding_names is not None:
+      if hasattr(x, 'sharding_metadata') and x.sharding_metadata is not None:
         x.set_metadata(
-          sharding_names=remove_field(x.sharding_names, index, axis_name)
+          sharding_metadata=remove_field(x.sharding_metadata, index, axis_name)
         )
 
       for k, v in other_meta.items():
@@ -119,7 +119,7 @@ def with_partitioning(
   """A wrapper over any initializer to add sharding annotation data to a `Variable`."""
   return variablelib.with_metadata(
     initializer,
-    sharding_names=sharding,
+    sharding_metadata=sharding,
     mesh=mesh,
     **metadata,
   )
@@ -128,8 +128,8 @@ def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
   """Given an `nnx.Variable`, return its `PartitionSpec`."""
   metadata = v.get_metadata()
-  if 'sharding_names' in metadata and metadata['sharding_names']:
-    sharding = metadata['sharding_names']
+  if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
+    sharding = metadata['sharding_metadata']
     if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata:
       context_rules = core_spmd.get_logical_axis_rules()
       local_rules = metadata.get('sharding_rules', ())
@@ -174,4 +174,4 @@ def get_abstract_model(init_fn, mesh):
     lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
     abs_state, get_named_sharding(abs_state, mesh)
   )
-  return gdef, abs_state
\ No newline at end of file
+  return gdef, abs_state
diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index e9f93e163..9180b9bf5 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -21,6 +21,7 @@
 import threading
 import typing as tp
 from typing import Any
+import warnings
 
 from flax import config
 import jax
@@ -375,16 +376,20 @@ def __init__(
         metadata['on_remove_axis'] = var_t.on_remove_axis
 
     if 'sharding' in metadata:
-      metadata['sharding_names'] = metadata.pop('sharding')
+      metadata['sharding_metadata'] = metadata.pop('sharding')
+
+    if 'sharding_names' in metadata:  # for bw compat
+      warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
+      metadata['sharding_metadata'] = metadata.pop('sharding_names')
 
     object.__setattr__(self, '_var_metadata', metadata)
     # run create_value hooks
     value = self.create_value(self.raw_value)
 
     # shard the value if applicable
-    if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_names' in metadata:
+    if metadata.get('eager_sharding', using_eager_sharding()) and 'sharding_metadata' in metadata:
       value = core_spmd.shard_value(
-        value, metadata['sharding_names'], metadata.get('sharding_rules', None),
+        value, metadata['sharding_metadata'], metadata.get('sharding_rules', None),
         metadata.get('mesh', None))
 
     # Create the ref out of the array value
@@ -394,6 +399,9 @@ def __init__(
     object.__setattr__(self, 'raw_value', value)
 
   def __getattr__(self, name: str) -> tp.Any:
+    if name == 'sharding_names':  # for backward compatibility
+      warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
+      return self.sharding_metadata
     if name in object.__getattribute__(self, '_var_metadata'):
       return self._var_metadata[name]
     return getattr(self.raw_value, name)
diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py
index 8e827bd24..3feeb2a76 100644
--- a/tests/nnx/bridge/wrappers_test.py
+++ b/tests/nnx/bridge/wrappers_test.py
@@ -174,7 +174,7 @@ def create_sharded_nnx_module(x):
       self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned)
       self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned)
       self.assertIsInstance(nnx_model.kernel, nnx.Variable)
-      assert nnx_model.kernel.sharding_names == ('in', 'out')
+      assert nnx_model.kernel.sharding_metadata == ('in', 'out')
       assert nnx_model.kernel[...].sharding.is_equivalent_to(
         jax.sharding.NamedSharding(
           self.mesh, jax.sharding.PartitionSpec('in', 'out')
@@ -182,7 +182,7 @@
         ndim=2,
       ), f'{nnx_model.kernel[...].sharding = }'
-      assert nnx_model.bias.sharding_names == ('out-alias',)
+      assert nnx_model.bias.sharding_metadata == ('out-alias',)
       assert nnx_model.bias.sharding_rules == (('out-alias', 'out'),)
       assert nnx_model.bias[...].sharding.is_equivalent_to(
         jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out')),
@@ -410,7 +410,7 @@ def test_nnx_to_linen_metadata(self):
     pspec_tree = nn.get_partition_spec(variables)
     assert y.shape == (1, 64)
     self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta)
-    assert variables['params']['kernel'].metadata['sharding_names'] == ('in', 'out')
+    assert variables['params']['kernel'].metadata['sharding_metadata'] == ('in', 'out')
     self.assertEqual(pspec_tree['params']['kernel'], jax.sharding.PartitionSpec('in', 'out'))
     np.testing.assert_allclose(y, x @ variables['params']['kernel'].value)
 
@@ -519,8 +519,8 @@ def __call__(self, x):
     w, b = model.inner.dot['w'], model.inner.b
     np.testing.assert_allclose(model(x), x @ w + b)
     self.assertIsInstance(w, nnx.Param)
-    assert hasattr(w, 'sharding_names') and w.sharding_names == ('in', 'out')
-    assert hasattr(b, 'sharding_names') and b.sharding_names == ('out-alias', )
+    assert hasattr(w, 'sharding_metadata') and w.sharding_metadata == ('in', 'out')
+    assert hasattr(b, 'sharding_metadata') and b.sharding_metadata == ('out-alias', )
 
   def test_linen_nnx_linen(self):
     # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without
diff --git a/tests/nnx/nn/linear_test.py b/tests/nnx/nn/linear_test.py
index 47a938ebe..c552e5e38 100644
--- a/tests/nnx/nn/linear_test.py
+++ b/tests/nnx/nn/linear_test.py
@@ -393,7 +393,7 @@ def test(self, module_args_kwargs_initargs):
     kwargs = {"rngs": nnx.Rngs(0)}
     sharding_names = ("din", "dout")
     metadata_kwargs = {
-      f"{key}_metadata": {"sharding_names": sharding_names[:le]}
+      f"{key}_metadata": {"sharding_metadata": sharding_names[:le]}
       for key, le, _ in metadata_argnames
     }
 
@@ -410,8 +410,8 @@ def test(self, module_args_kwargs_initargs):
     for attr_name, param_name in attrs:
       attr = getattr(module, attr_name) if attr_name is not None else module
       param = getattr(attr, param_name)
-      self.assertIsNotNone(param.sharding_names)
-      self.assertEqual(param.sharding_names, sharding_names[:le])
+      self.assertIsNotNone(param.sharding_metadata)
+      self.assertEqual(param.sharding_metadata, sharding_names[:le])
 
 
 if __name__ == '__main__':
diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py
index 8a22b06fe..cbc230eca 100644
--- a/tests/nnx/optimizer_test.py
+++ b/tests/nnx/optimizer_test.py
@@ -91,7 +91,7 @@ def test_sharding_propagation(self):
     state = nnx.state(optimizer)
     partition_spec = nnx.get_partition_spec(state)
 
-    self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_names, ('a', 'b'))
+    self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_metadata, ('a', 'b'))
     self.assertEqual(
       partition_spec['opt_state'][0]['mu']['kernel'].value,
       jax.sharding.PartitionSpec('a', 'b'),
diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py
index 35194c2b8..fb0d606d7 100644
--- a/tests/nnx/spmd_test.py
+++ b/tests/nnx/spmd_test.py
@@ -139,7 +139,7 @@ def __init__(self, rngs: nnx.Rngs):
           4,
           kernel_init=nnx.with_metadata(
             nnx.initializers.lecun_normal(),
-            sharding_names=('din', 'dout'),
+            sharding_metadata=('din', 'dout'),
             nickname=('in', 'out'),
             on_add_axis=lambda _, idx, name: kadds.append((idx, name)),
             on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)),
@@ -160,7 +160,7 @@ def __call__(self, x: jax.Array):
         x = self.linear(x)
         # test sharding layer axes is not present inside scan
         test.assertEqual(self.linear.kernel.shape, (4, 4))
-        test.assertEqual(self.linear.kernel.sharding_names, ('din', 'dout'))
+        test.assertEqual(self.linear.kernel.sharding_metadata, ('din', 'dout'))
         # at least a remove_axis was already called to remove the layer axis
         test.assertEqual(kremoves[-1], (0, 'layers'))
         test.assertEqual(bremoves[-1], (0, 'layers'))
@@ -175,7 +175,7 @@ def __call__(self, x: jax.Array):
     with jax.set_mesh(mesh):
       m = MLP(rngs=nnx.Rngs(0))
     self.assertEqual(m.linear.kernel.shape, (5, 4, 4))
-    self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout'))
+    self.assertEqual(m.linear.kernel.sharding_metadata, ('layers', 'din', 'dout'))
     self.assertEqual(m.linear.kernel.nickname, ('nick', 'in', 'out'))
     self.assertEqual(m.linear.bias.shape, (5, 4))
     # One add_axis called to add the `nnx.vmap` dimension
@@ -201,7 +201,7 @@ def test_eager_sharding_context(self, use_eager_sharding):
     with jax.set_mesh(mesh):
       w = nnx.Param(
         rngs.lecun_normal()((4, 8)),
-        sharding_names=(None, 'model'))
+        sharding_metadata=(None, 'model'))
     if use_eager_sharding:
       assert has_sharding_spec(w)
     else:
@@ -273,7 +273,7 @@ def test_explicit_sharding(self):
     )
     v = nnx.Variable(
       jnp.ones((4, 4)),
-      sharding_names=('row', 'col'),
+      sharding_metadata=('row', 'col'),
       mesh=mesh,
     )
     self.assertEqual(v.sharding.mesh, mesh)
@@ -291,7 +291,7 @@ def test_explicit_sharding_disable_jit(self):
     with jax.disable_jit(True):
       v = nnx.Variable(
         jnp.ones((4, 4)),
-        sharding_names=('row', 'col'),
+        sharding_metadata=('row', 'col'),
         mesh=mesh,
       )
     self.assertEqual(v.sharding.mesh, mesh)
@@ -309,7 +309,7 @@ def test_explicit_sharding_mesh_context(self):
     with jax.set_mesh(mesh):
       v = nnx.Variable(
         jnp.ones((4, 4)),
-        sharding_names=('row', 'col'),
+        sharding_metadata=('row', 'col'),
       )
     self.assertEqual(v.sharding.mesh, mesh)
     self.assertEqual(
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index d83e25d9d..71f63b4aa 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -1852,10 +1852,10 @@ def __init__(self, rngs: nnx.Rngs):
           3,
           3,
           kernel_init=nnx.with_metadata(
-            nnx.initializers.lecun_normal(), sharding_names=('din', 'dout')
+            nnx.initializers.lecun_normal(), sharding_metadata=('din', 'dout')
           ),
           bias_init=nnx.with_metadata(
-            nnx.initializers.zeros_init(), sharding_names=('dout',)
+            nnx.initializers.zeros_init(), sharding_metadata=('dout',)
           ),
           rngs=rngs,
         )
@@ -1867,9 +1867,9 @@ def __call__(self, x: jax.Array):
         x = self.linear(x)
         # test sharding layer axes is not present inside scan
         test.assertEqual(self.linear.kernel.shape, (3, 3))
-        test.assertEqual(self.linear.kernel.sharding_names, ('din', 'dout'))
+        test.assertEqual(self.linear.kernel.sharding_metadata, ('din', 'dout'))
         test.assertEqual(self.linear.bias.shape, (3,))
-        test.assertEqual(self.linear.bias.sharding_names, ('dout',))
+        test.assertEqual(self.linear.bias.sharding_metadata, ('dout',))
         return x, None
 
     mesh = jax.make_mesh((1, 1, 1), ('layers', 'din', 'dout'), axis_types=(jax.sharding.AxisType.Auto,) * len(('layers', 'din', 'dout')))
@@ -1878,9 +1878,9 @@ def __call__(self, x: jax.Array):
 
     # test sharding layers axes is set
     self.assertEqual(m.linear.kernel.shape, (5, 3, 3))
-    self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout'))
+    self.assertEqual(m.linear.kernel.sharding_metadata, ('layers', 'din', 'dout'))
     self.assertEqual(m.linear.bias.shape, (5, 3))
-    self.assertEqual(m.linear.bias.sharding_names, ('layers', 'dout'))
+    self.assertEqual(m.linear.bias.sharding_metadata, ('layers', 'dout'))
 
     x = jnp.ones((1, 3))
     with jax.set_mesh(mesh):
@@ -1888,9 +1888,9 @@ def __call__(self, x: jax.Array):
 
     # test sharding axes is preserved
     self.assertEqual(m.linear.kernel.shape, (5, 3, 3))
-    self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout'))
+    self.assertEqual(m.linear.kernel.sharding_metadata, ('layers', 'din', 'dout'))
     self.assertEqual(m.linear.bias.shape, (5, 3))
-    self.assertEqual(m.linear.bias.sharding_names, ('layers', 'dout'))
+    self.assertEqual(m.linear.bias.sharding_metadata, ('layers', 'dout'))
 
   def test_cache_tracing_simple(self):
     n = 0
@@ -2650,7 +2650,7 @@ def create_block(rngs: nnx.Rngs):
     with jax.set_mesh(mesh):
       m = create_block(nnx.Rngs(0))
     self.assertEqual(m.kernel.shape, (5, 16, 32))
-    self.assertEqual(m.kernel.sharding_names, ('c', 'a', 'b'))
+    self.assertEqual(m.kernel.sharding_metadata, ('c', 'a', 'b'))
 
   def test_state_axes_from_state(self):
     class Model(nnx.Module):

From 0f22613d6c9161ae120e263482a19f1164319ace Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Thu, 20 Nov 2025 11:56:12 -0600
Subject: [PATCH 2/3] Fix error from merge conflict

---
 flax/nnx/variablelib.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index b73f55c71..c02f524fe 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -1108,10 +1108,10 @@ def __init__(
     # run create_value hook
     value = self.create_value(value)  # type: ignore
     # shard the _value if applicable
-    if eager_sharding and 'sharding_names' in metadata:
+    if eager_sharding and 'sharding_metadata' in metadata:
       value = core_spmd.shard_value(
         value,
-        metadata['sharding_names'],
+        metadata['sharding_metadata'],
         metadata.get('sharding_rules', None),
         metadata.get('mesh', None),
       )

From 6e5010d50859af3fddde806c9af7a04b06a12f49 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Thu, 20 Nov 2025 12:29:15 -0600
Subject: [PATCH 3/3] Handle Format sharding arguments

---
 flax/core/spmd.py | 43 +++++++++++++++++++++----------------------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/flax/core/spmd.py b/flax/core/spmd.py
index 67d5bff52..69f13ef9b 100644
--- a/flax/core/spmd.py
+++ b/flax/core/spmd.py
@@ -18,40 +18,39 @@
 import jax
 from jax.sharding import PartitionSpec, NamedSharding
-from flax.core import meta
+from jax.experimental.layout import Format
 from flax.typing import (
   LogicalRules,
   Sharding,
 )
 
 
-def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec:
-  """Given an `nnx.Variable`, return its `PartitionSpec`."""
+def map_sharding(f, sharding):
+  if isinstance(sharding, PartitionSpec) or isinstance(sharding, tuple):
+    return PartitionSpec(*map(f, sharding))
+  elif isinstance(sharding, NamedSharding):
+    return NamedSharding(sharding.mesh, map_sharding(f, sharding.spec))  # type: ignore
+  elif isinstance(sharding, Format):
+    return Format(sharding.layout, map_sharding(f, sharding.sharding))
+
+def apply_rules(sharding, sharding_rules):
   if get_logical_axis_rules() or sharding_rules:
     context_rules = get_logical_axis_rules()
-    rules = composite_rules(context_rules, sharding_rules)
-    return PartitionSpec(*from_sharding_rules(sharding_names, rules))
-  return PartitionSpec(*sharding_names)
+    rules = {alias: on_mesh for (alias, on_mesh) in composite_rules(context_rules, sharding_rules)}
+  else:
+    rules = {}
+  return map_sharding(lambda a: rules.get(a, a), sharding)
 
 
 def _apply_sharding(value, sharding):
   with jax.disable_jit(False):
     return jax.jit(lambda x: x, out_shardings=sharding)(value)
 
-def shard_value(value, sharding_names, sharding_rules, mesh):
-  if not sharding_names:
-    return value
-  if not mesh and not meta.global_mesh_defined():
-    raise ValueError(
-      'An auto mesh context or metadata is required if creating a variable'
-      f' with annotation {sharding_names=}. '
-      'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
-  pspec = get_pspec(sharding_names, sharding_rules)
-  if isinstance(sharding_names, NamedSharding) and mesh is not None:
-    assert sharding_names.mesh == mesh
-  if mesh is not None:
-    return _apply_sharding(value, NamedSharding(mesh, pspec))
-  return _apply_sharding(value, pspec)
-
-
+def shard_value(value, sharding, sharding_rules, mesh):
+  sharding = apply_rules(sharding, sharding_rules)
+  if isinstance(sharding, PartitionSpec) and mesh is not None:
+    sharding = NamedSharding(mesh, sharding)
+  if mesh is not None and hasattr(sharding, 'mesh'):
+    assert mesh == sharding.mesh
+  return _apply_sharding(value, sharding)
 
 # Dynamic Axis Mapping Context
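
Usage sketch of the renamed keyword. This snippet is not part of the patches above; the mesh shape and axis names are illustrative only, and the single-device mesh exists just to make the example self-contained. It mirrors the `nnx.Variable(..., sharding_metadata=..., mesh=...)` pattern exercised in tests/nnx/spmd_test.py.

    import jax
    import jax.numpy as jnp
    from flax import nnx

    # Toy 1x1 mesh so the named axes exist; real meshes would span devices.
    mesh = jax.make_mesh((1, 1), ('row', 'col'))

    # New spelling: annotate the variable with `sharding_metadata`.
    v = nnx.Variable(
      jnp.ones((4, 4)),
      sharding_metadata=('row', 'col'),
      mesh=mesh,
    )

    # The old keyword is still accepted for backward compatibility, but it now
    # emits a DeprecationWarning and is stored under `sharding_metadata`.
    w = nnx.Variable(jnp.ones((4, 4)), sharding_names=('row', 'col'), mesh=mesh)
    assert w.sharding_metadata == ('row', 'col')

With the third commit, `sharding_metadata` may also carry a `jax.sharding.NamedSharding` or a `jax.experimental.layout.Format` object rather than a tuple of logical axis names.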