Skip to content

Commit 1e48380

Browse files
Authored and committed by Flax Authors
Merge pull request #4637 from IvyZX:custom-meta
PiperOrigin-RevId: 738191979
2 parents: dd3c18b + a920dd2; commit 1e48380

File tree

2 files changed: +38 -17 lines changed

2 files changed: +38 -17 lines changed

flax/nnx/spmd.py

+35 -16
Original file line number | Diff line number | Diff line change
@@ -38,17 +38,24 @@ class HasSharding(tp.Protocol):
3838
def _has_sharding(x: tp.Any) -> tp.TypeGuard[HasSharding]:
3939
return hasattr(x, 'sharding') and x.sharding is not None
4040

41-
def add_axis(tree: A, index: int, params: tp.Mapping) -> A:
42-
axis_name = _get_partition_name(params)
41+
def add_axis(tree: A, index: int, transform_metadata: tp.Mapping) -> A:
42+
axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata)
43+
44+
def insert_field(fields, index, value):
45+
iterable = list(fields)
46+
while len(iterable) < index:
47+
iterable.append(None)
48+
iterable.insert(index, value)
49+
return tuple(iterable)
4350

4451
def _add_axis(x: tp.Any):
4552
if isinstance(x, variablelib.VariableState):
4653
if _has_sharding(x) and x.sharding is not None:
47-
sharding: list[str | None] = list(x.sharding)
48-
while len(sharding) < index:
49-
sharding.append(None)
50-
sharding.insert(index, axis_name)
51-
x.sharding = tuple(sharding) # type: ignore
54+
x.sharding = insert_field(x.sharding, index, axis_name)
55+
56+
for k, v in other_meta.items():
57+
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
58+
setattr(x, k, insert_field(t, index, v))
5259

5360
assert isinstance(x, variablelib.VariableState)
5461
x.add_axis(index, axis_name)
@@ -59,15 +66,23 @@ def _add_axis(x: tp.Any):
5966
)
6067

6168

62-
def remove_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A:
63-
axis_name = _get_partition_name(params)
69+
def remove_axis(tree: A, index: int, transform_metadata: tp.Mapping[tp.Any, tp.Any]) -> A:
70+
axis_name, other_meta = _get_partition_name_and_metadata(transform_metadata)
71+
72+
def remove_field(fields, index, value):
73+
iterable = list(fields)
74+
assert iterable.pop(index) == value
75+
return tuple(iterable)
6476

6577
def _remove_axis(x: tp.Any):
6678
if isinstance(x, variablelib.VariableState):
6779
if hasattr(x, 'sharding') and x.sharding is not None:
68-
sharding = list(x.sharding)
69-
assert sharding.pop(index) == axis_name
70-
x.sharding = tuple(sharding)
80+
x.sharding = remove_field(x.sharding, index, axis_name)
81+
82+
for k, v in other_meta.items():
83+
if hasattr(x, k) and (t := getattr(x, k)) and isinstance(t, tuple):
84+
setattr(x, k, remove_field(t, index, v))
85+
7186
x.remove_axis(index, axis_name)
7287
return x
7388

@@ -78,13 +93,17 @@ def _remove_axis(x: tp.Any):
7893
)
7994

8095

81-
def _get_partition_name(params: tp.Mapping[tp.Any, tp.Any]) -> str:
82-
if PARTITION_NAME not in params:
96+
def _get_partition_name_and_metadata(
97+
transform_metadata: tp.Mapping[tp.Any, tp.Any]
98+
) -> tuple[str, tp.Mapping[tp.Any, tp.Any]]:
99+
if PARTITION_NAME not in transform_metadata:
83100
raise ValueError(
84101
'Trying to transform a Partitioned variable but "partition_name" '
85-
f'is not specified in scan_metadata: {params}'
102+
f'is not specified in transform_metadata: {transform_metadata}'
86103
)
87-
return params[PARTITION_NAME]
104+
other_meta = dict(transform_metadata) # shallow copy
105+
other_meta.pop(PARTITION_NAME)
106+
return transform_metadata[PARTITION_NAME], other_meta
88107

89108

90109
def get_partition_spec(tree: A) -> A:

tests/nnx/spmd_test.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ class MLP(nnx.Module):
108108
@nnx.split_rngs(splits=5)
109109
@nnx.vmap(
110110
in_axes=(0, 0),
111-
transform_metadata={nnx.PARTITION_NAME: 'layers'},
111+
transform_metadata={nnx.PARTITION_NAME: 'layers', 'nickname': 'nick'},
112112
)
113113
def __init__(self, rngs: nnx.Rngs):
114114
self.linear = nnx.Linear(
@@ -117,6 +117,7 @@ def __init__(self, rngs: nnx.Rngs):
117117
kernel_init=nnx.with_metadata(
118118
nnx.initializers.lecun_normal(),
119119
sharding=('din', 'dout'),
120+
nickname=('in', 'out'),
120121
on_add_axis=lambda _, idx, name: kadds.append((idx, name)),
121122
on_remove_axis=lambda _, idx, name: kremoves.append((idx, name)),
122123
),
@@ -145,6 +146,7 @@ def __call__(self, x: jax.Array):
145146
m = MLP(rngs=nnx.Rngs(0))
146147
self.assertEqual(m.linear.kernel.shape, (5, 3, 3))
147148
self.assertEqual(m.linear.kernel.sharding, ('layers', 'din', 'dout'))
149+
self.assertEqual(m.linear.kernel.nickname, ('nick', 'in', 'out'))
148150
self.assertEqual(m.linear.bias.shape, (5, 3))
149151
# One add_axis called to add the `nnx.vmap` dimension
150152
self.assertEqual(kadds, [(0, 'layers')])

0 commit comments

Comments (0)