Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/nnx_toy_examples/10_fsdp_and_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import dataclasses
import os

from jax._src import sharding
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

from matplotlib import pyplot as plt
Expand Down Expand Up @@ -56,15 +58,15 @@ class MLP(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.w1 = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (din, dmid)),
sharding_names=mesh_rules('embed', 'mlp'),
sharding_metadata=mesh_rules('embed', 'mlp'),
)
self.b1 = nnx.Param(
jnp.zeros((dmid,)),
sharding_names=mesh_rules('mlp'),
sharding_metadata=mesh_rules('mlp'),
)
self.w2 = nnx.Param(
nnx.initializers.lecun_normal()(rngs.params(), (dmid, dout)),
sharding_names=mesh_rules('embed', 'mlp'),
sharding_metadata=mesh_rules('embed', 'mlp'),
)

def __call__(self, x: jax.Array):
Expand Down
4 changes: 2 additions & 2 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,13 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = dict(vars(self))
metadata['sharding_names'] = metadata.pop('names')
metadata['sharding_metadata'] = metadata.pop('names')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
metadata['names'] = metadata.pop('sharding_names')
metadata['names'] = metadata.pop('sharding_metadata')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})

Expand Down
45 changes: 24 additions & 21 deletions flax/core/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,39 @@

import jax
from jax.sharding import PartitionSpec, NamedSharding
from flax.core import meta
from jax.experimental.layout import Format
from flax.typing import (
LogicalRules,
Sharding,
)

def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec:
"""Given an `nnx.Variable`, return its `PartitionSpec`."""
def map_sharding(f, sharding):
if isinstance(sharding, PartitionSpec) or isinstance(sharding, tuple):
return PartitionSpec(*map(f, sharding))
elif isinstance(sharding, NamedSharding):
return NamedSharding(sharding.mesh, map_sharding(f, sharding.sharding)) # type: ignore
elif isinstance(sharding, Format):
return Format(sharding.layout, map_sharding(f, sharding.format))

def apply_rules(sharding, sharding_rules):
if get_logical_axis_rules() or sharding_rules:
context_rules = get_logical_axis_rules()
rules = composite_rules(context_rules, sharding_rules)
return PartitionSpec(*from_sharding_rules(sharding_names, rules))
return PartitionSpec(*sharding_names)
rules = {alias: on_mesh for (alias, on_mesh) in composite_rules(context_rules, sharding_rules)}
else:
rules = {}
return map_sharding(lambda a: rules.get(a, a), sharding)

def _apply_sharding(value, sharding):
with jax.disable_jit(False):
return jax.jit(lambda x: x, out_shardings=sharding)(value)

def shard_value(value, sharding_names, sharding_rules, mesh):
if not sharding_names:
return value
if not mesh and not meta.global_mesh_defined():
raise ValueError(
'An auto mesh context or metadata is required if creating a variable'
f' with annotation {sharding_names=}. '
'For more guidance, see https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html.')
pspec = get_pspec(sharding_names, sharding_rules)
if mesh is not None:
return _apply_sharding(value, NamedSharding(mesh, pspec))
return _apply_sharding(value, pspec)


def shard_value(value, sharding, sharding_rules, mesh):
sharding = apply_rules(sharding, sharding_rules)
if isinstance(sharding, PartitionSpec) and mesh is not None:
sharding = NamedSharding(mesh, sharding)
if hasattr(sharding, 'mesh'):
assert mesh == sharding.mesh
return _apply_sharding(value, sharding)


# Dynamic Axis Mapping Context
Expand Down Expand Up @@ -107,8 +108,10 @@ def composite_rules(rule1, rule2):


def from_sharding_rules(
sharding: Sharding, sharding_rules: LogicalRules
sharding, sharding_rules: LogicalRules
) -> Sharding:
if isinstance(sharding, NamedSharding):
sharding = sharding.spec
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
return tuple(
rules[str(s)] if (s and str(s) in rules) else s for s in sharding
Expand Down
4 changes: 2 additions & 2 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,15 @@ def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
if 'names' in metadata:
metadata['sharding_names'] = metadata.pop('names')
metadata['sharding_metadata'] = metadata.pop('names')
if 'rules' in metadata:
metadata['sharding_rules'] = metadata.pop('rules')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
metadata['names'] = metadata.pop('sharding_names')
metadata['names'] = metadata.pop('sharding_metadata')
metadata['rules'] = metadata.pop('sharding_rules')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})
Expand Down
18 changes: 9 additions & 9 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def insert_field(fields, index, value):
def _add_axis(x: tp.Any):
if isinstance(x, variablelib.Variable):
metadata = x.get_metadata()
if 'sharding_names' in metadata and metadata['sharding_names']:
sharding = metadata['sharding_names']
x.set_metadata(sharding_names=insert_field(sharding, index, axis_name))
if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
sharding = metadata['sharding_metadata']
x.set_metadata(sharding_metadata=insert_field(sharding, index, axis_name))

for k, v in other_meta.items():
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
Expand All @@ -74,9 +74,9 @@ def remove_field(fields, index, value):

def _remove_axis(x: tp.Any):
if isinstance(x, variablelib.Variable):
if hasattr(x, 'sharding_names') and x.sharding_names is not None:
if hasattr(x, 'sharding_metadata') and x.sharding_metadata is not None:
x.set_metadata(
sharding_names=remove_field(x.sharding_names, index, axis_name)
sharding_metadata=remove_field(x.sharding_metadata, index, axis_name)
)

for k, v in other_meta.items():
Expand Down Expand Up @@ -119,7 +119,7 @@ def with_partitioning(
"""A wrapper over any initializer to add sharding annotation data to a `Variable`."""
return variablelib.with_metadata(
initializer,
sharding_names=sharding,
sharding_metadata=sharding,
mesh=mesh,
**metadata,
)
Expand All @@ -128,8 +128,8 @@ def with_partitioning(
def get_var_pspec(v: variablelib.Variable) -> PartitionSpec | None:
"""Given an `nnx.Variable`, return its `PartitionSpec`."""
metadata = v.get_metadata()
if 'sharding_names' in metadata and metadata['sharding_names']:
sharding = metadata['sharding_names']
if 'sharding_metadata' in metadata and metadata['sharding_metadata']:
sharding = metadata['sharding_metadata']
if core_spmd.get_logical_axis_rules() or 'sharding_rules' in metadata:
context_rules = core_spmd.get_logical_axis_rules()
local_rules = metadata.get('sharding_rules', ())
Expand Down Expand Up @@ -174,4 +174,4 @@ def get_abstract_model(init_fn, mesh):
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_state, get_named_sharding(abs_state, mesh)
)
return gdef, abs_state
return gdef, abs_state
14 changes: 11 additions & 3 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import threading
import typing as tp
from typing import Any
import warnings

from flax import config
from flax import errors
Expand Down Expand Up @@ -1093,7 +1094,11 @@ def __init__(
metadata['on_remove_axis'] = var_t.on_remove_axis

if 'sharding' in metadata:
metadata['sharding_names'] = metadata.pop('sharding')
metadata['sharding_metadata'] = metadata.pop('sharding')

if 'sharding_names' in metadata: # for bw compat
warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
metadata['sharding_metadata'] = metadata.pop('sharding_names')

# run create_value hooks
if 'on_create_value' in metadata:
Expand All @@ -1103,10 +1108,10 @@ def __init__(
# run create_value hook
value = self.create_value(value) # type: ignore
# shard the _value if applicable
if eager_sharding and 'sharding_names' in metadata:
if eager_sharding and 'sharding_metadata' in metadata:
value = core_spmd.shard_value(
value,
metadata['sharding_names'],
metadata['sharding_metadata'],
metadata.get('sharding_rules', None),
metadata.get('mesh', None),
)
Expand All @@ -1133,6 +1138,9 @@ def _check_can_update(self):
)

def __getattr__(self, name: str) -> tp.Any:
if name == 'sharding_names': # for backward compatibility
warnings.warn("'sharding_names' is deprecated. Use 'sharding_metadata' instead.", DeprecationWarning)
return self.sharding_metadata
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use `has_metadata` and `get_metadata` in the lines below, and teach these methods how to handle `sharding_names`.

if name in object.__getattribute__(self, '_var_metadata'):
return self._var_metadata[name]
return getattr(object.__getattribute__(self, '_raw_value'), name)
Expand Down
10 changes: 5 additions & 5 deletions tests/nnx/bridge/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ def create_sharded_nnx_module(x):
self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned)
self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned)
self.assertIsInstance(nnx_model.kernel, nnx.Variable)
assert nnx_model.kernel.sharding_names == ('in', 'out')
assert nnx_model.kernel.sharding_metadata == ('in', 'out')
assert nnx_model.kernel[...].sharding.is_equivalent_to(
jax.sharding.NamedSharding(
self.mesh, jax.sharding.PartitionSpec('in', 'out')
),
ndim=2,
), f'{nnx_model.kernel[...].sharding = }'

assert nnx_model.bias.sharding_names == ('out-alias',)
assert nnx_model.bias.sharding_metadata == ('out-alias',)
assert nnx_model.bias.sharding_rules == (('out-alias', 'out'),)
assert nnx_model.bias[...].sharding.is_equivalent_to(
jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out')),
Expand Down Expand Up @@ -410,7 +410,7 @@ def test_nnx_to_linen_metadata(self):
pspec_tree = nn.get_partition_spec(variables)
assert y.shape == (1, 64)
self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta)
assert variables['params']['kernel'].metadata['sharding_names'] == ('in', 'out')
assert variables['params']['kernel'].metadata['sharding_metadata'] == ('in', 'out')
self.assertEqual(pspec_tree['params']['kernel'],
jax.sharding.PartitionSpec('in', 'out'))
np.testing.assert_allclose(y, x @ variables['params']['kernel'].value)
Expand Down Expand Up @@ -519,8 +519,8 @@ def __call__(self, x):
w, b = model.inner.dot['w'], model.inner.b
np.testing.assert_allclose(model(x), x @ w + b)
self.assertIsInstance(w, nnx.Param)
assert hasattr(w, 'sharding_names') and w.sharding_names == ('in', 'out')
assert hasattr(b, 'sharding_names') and b.sharding_names == ('out-alias', )
assert hasattr(w, 'sharding_metadata') and w.sharding_metadata == ('in', 'out')
assert hasattr(b, 'sharding_metadata') and b.sharding_metadata == ('out-alias', )

def test_linen_nnx_linen(self):
# TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without
Expand Down
6 changes: 3 additions & 3 deletions tests/nnx/nn/linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def test(self, module_args_kwargs_initargs):
kwargs = {"rngs": nnx.Rngs(0)}
sharding_names = ("din", "dout")
metadata_kwargs = {
f"{key}_metadata": {"sharding_names": sharding_names[:le]}
f"{key}_metadata": {"sharding_metadata": sharding_names[:le]}
for key, le, _ in metadata_argnames
}

Expand All @@ -410,8 +410,8 @@ def test(self, module_args_kwargs_initargs):
for attr_name, param_name in attrs:
attr = getattr(module, attr_name) if attr_name is not None else module
param = getattr(attr, param_name)
self.assertIsNotNone(param.sharding_names)
self.assertEqual(param.sharding_names, sharding_names[:le])
self.assertIsNotNone(param.sharding_metadata)
self.assertEqual(param.sharding_metadata, sharding_names[:le])


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_sharding_propagation(self):
state = nnx.state(optimizer)
partition_spec = nnx.get_partition_spec(state)

self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_names, ('a', 'b'))
self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_metadata, ('a', 'b'))
self.assertEqual(
partition_spec['opt_state'][0]['mu']['kernel'].value,
jax.sharding.PartitionSpec('a', 'b'),
Expand Down
14 changes: 7 additions & 7 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __init__(self, rngs: nnx.Rngs):
4,
kernel_init=nnx.with_metadata(
nnx.initializers.lecun_normal(),
sharding_names=('din', 'dout'),
sharding_metadata=('din', 'dout'),
nickname=('in', 'out'),
on_add_axis=lambda _, idx, name: kadds.append((idx, name)),
on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)),
Expand All @@ -160,7 +160,7 @@ def __call__(self, x: jax.Array):
x = self.linear(x)
# test sharding layer axes is not present inside scan
test.assertEqual(self.linear.kernel.shape, (4, 4))
test.assertEqual(self.linear.kernel.sharding_names, ('din', 'dout'))
test.assertEqual(self.linear.kernel.sharding_metadata, ('din', 'dout'))
# at least a remove_axis was already called to remove the layer axis
test.assertEqual(kremoves[-1], (0, 'layers'))
test.assertEqual(bremoves[-1], (0, 'layers'))
Expand All @@ -175,7 +175,7 @@ def __call__(self, x: jax.Array):
with jax.set_mesh(mesh):
m = MLP(rngs=nnx.Rngs(0))
self.assertEqual(m.linear.kernel.shape, (5, 4, 4))
self.assertEqual(m.linear.kernel.sharding_names, ('layers', 'din', 'dout'))
self.assertEqual(m.linear.kernel.sharding_metadata, ('layers', 'din', 'dout'))
self.assertEqual(m.linear.kernel.nickname, ('nick', 'in', 'out'))
self.assertEqual(m.linear.bias.shape, (5, 4))
# One add_axis called to add the `nnx.vmap` dimension
Expand Down Expand Up @@ -205,7 +205,7 @@ def test_eager_sharding_context(self, use_eager_sharding):
with jax.set_mesh(mesh):
w = nnx.Param(
rngs.lecun_normal()((4, 8)),
sharding_names=(None, 'model'))
sharding_metadata=(None, 'model'))
if use_eager_sharding:
assert has_sharding_spec(w)
else:
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_explicit_sharding(self):
)
v = nnx.Variable(
jnp.ones((4, 4)),
sharding_names=('row', 'col'),
sharding_metadata=('row', 'col'),
mesh=mesh,
)
self.assertEqual(v.sharding.mesh, mesh)
Expand All @@ -295,7 +295,7 @@ def test_explicit_sharding_disable_jit(self):
with jax.disable_jit(True):
v = nnx.Variable(
jnp.ones((4, 4)),
sharding_names=('row', 'col'),
sharding_metadata=('row', 'col'),
mesh=mesh,
)
self.assertEqual(v.sharding.mesh, mesh)
Expand All @@ -313,7 +313,7 @@ def test_explicit_sharding_mesh_context(self):
with jax.set_mesh(mesh):
v = nnx.Variable(
jnp.ones((4, 4)),
sharding_names=('row', 'col'),
sharding_metadata=('row', 'col'),
)
self.assertEqual(v.sharding.mesh, mesh)
self.assertEqual(
Expand Down
Loading
Loading