
Commit b8387f7

patnotz authored and Google-ML-Automation committed
Distributed Flax Shakespeare model training example
PiperOrigin-RevId: 797963267
1 parent 8bf2049 commit b8387f7

File tree

5 files changed, +70 -32 lines changed


jax_tpu_embedding/sparsecore/examples/models/shakespeare/flax_model.py

Lines changed: 45 additions & 7 deletions
@@ -20,7 +20,6 @@
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 
-
 shard_map = jax.experimental.shard_map.shard_map
 Nested = embedding.Nested
 
@@ -30,19 +29,39 @@
 ################################################################################
 class Model(nn.Module):
   """Shakespeare model using embedding layer."""
-  feature_specs: Nested[embedding_spec.FeatureSpec]
 
+  feature_specs: Nested[embedding_spec.FeatureSpec]
   global_batch_size: int
   vocab_size: int
   seq_len: int
   embedding_size: int
-  table_name: str = 'shakespeare_table'
   feature_name: str = 'shakespeare_feature'
   mesh: jax.sharding.Mesh | None = None
   sharding_axis: str = 'sparsecore_sharding'
 
+  def add_sharding_constraint(self, x: jax.Array, names: tuple[str | None]):
+    # Add a sharding constraint to the array.
+    #
+    # Add a sharding constraint to the array to ensure that the sharding
+    # information is not lost during compilation. This may not be necessary but
+    # it helps SPMD and ensures that the sharding information is as expected.
+    #
+    # Args:
+    #   x: The array to add the sharding constraint to.
+    #   names: The mesh axes for the partition spec.
+    #
+    # Returns:
+    #   The array with the sharding constraint added.
+    return jax.lax.with_sharding_constraint(
+        x,
+        jax.sharding.NamedSharding(
+            self.mesh, jax.sharding.PartitionSpec(*names)
+        ),
+    )
+
   @nn.compact
-  def __call__(self, embedding_lookup_inputs: embed.EmbeddingLookupInput):
+  def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput):
+    # Run the embedding layer.
     x = embed.SparseCoreEmbed(
         feature_specs=self.feature_specs,
         mesh=self.mesh,
@@ -52,9 +71,28 @@ def __call__(self, embedding_lookup_inputs: embed.EmbeddingLookupInput):
     # Unpack the activations.
     x = x[self.feature_name]
     x = jnp.reshape(x, (self.global_batch_size, -1))
+    x = self.add_sharding_constraint(x, (self.sharding_axis,))
 
-    # Apply the model.
-    x = nn.Dense(self.embedding_size)(x)
-    x = nn.Dense(self.vocab_size)(x)
+    # Apply the dense portion of the model.
+    x = nn.Dense(
+        self.embedding_size,
+        kernel_init=nn.with_partitioning(
+            nn.initializers.xavier_uniform(), (self.sharding_axis,)
+        ),
+        bias_init=nn.with_partitioning(
+            nn.initializers.zeros, (self.sharding_axis,)
+        ),
+    )(x)
+    x = self.add_sharding_constraint(x, (self.sharding_axis,))
+    x = nn.Dense(
+        self.vocab_size,
+        kernel_init=nn.with_partitioning(
+            nn.initializers.xavier_uniform(), (self.sharding_axis,)
+        ),
+        bias_init=nn.with_partitioning(
+            nn.initializers.zeros, (self.sharding_axis,)
+        ),
+    )(x)
+    x = self.add_sharding_constraint(x, (self.sharding_axis,))
 
     return x
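
The sharding pattern this file introduces can be exercised on its own. The sketch below is illustrative only: the Head module, its layer sizes, the batch shape, and the single 'sparsecore_sharding' mesh axis are assumptions made for the example, not code from the repository. It shows how jax.lax.with_sharding_constraint and nn.with_partitioning combine, in the spirit of add_sharding_constraint and the partitioned Dense layers above.

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

# A 1-D device mesh; the axis name mirrors the model's default sharding_axis.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('sparsecore_sharding',))


def constrain(x):
  # Pin the batch dimension to the mesh axis so the compiler keeps the
  # intended layout, as Model.add_sharding_constraint does.
  return jax.lax.with_sharding_constraint(
      x,
      jax.sharding.NamedSharding(
          mesh, jax.sharding.PartitionSpec('sparsecore_sharding')
      ),
  )


class Head(nn.Module):
  """Two partitioned Dense layers, echoing the dense stack in the commit."""
  hidden: int = 16
  out: int = 8

  @nn.compact
  def __call__(self, x):
    # kernel/bias carry partition metadata via nn.with_partitioning.
    dense = functools.partial(
        nn.Dense,
        kernel_init=nn.with_partitioning(
            nn.initializers.xavier_uniform(), ('sparsecore_sharding',)
        ),
        bias_init=nn.with_partitioning(
            nn.initializers.zeros, ('sparsecore_sharding',)
        ),
    )
    x = constrain(dense(self.hidden)(x))
    return constrain(dense(self.out)(x))


model = Head()
x = jnp.ones((len(jax.devices()) * 2, 4))  # batch divisible by the mesh size
params = jax.jit(model.init)(jax.random.PRNGKey(0), x)
y = jax.jit(model.apply)(params, x)
print(y.shape)  # (batch, 8)

Keeping the partitioning metadata on the parameters lets downstream tooling (for example flax.linen.get_partition_spec) recover the intended layout when sharding the train state.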

jax_tpu_embedding/sparsecore/lib/flax/embed.py

Lines changed: 18 additions & 16 deletions
@@ -32,7 +32,6 @@
 DLL = layout.DeviceLocalLayout  # type: ignore
 Layout = layout.Format
 LogicalNames = typing.LogicalNames
-P = jax.sharding.PartitionSpec
 shard_map = jax.experimental.shard_map.shard_map
 Nested = embedding.Nested
 EmbeddingLookupInput = embedding.PreprocessedInput
@@ -60,7 +59,8 @@ def with_sparsecore_layout(
     fn: Callable[..., Any],
     names: LogicalNames,
     mesh: jax.sharding.Mesh,
-):
+) -> Callable[..., Any]:
+  """Wraps a function to add a SparseCore layout."""
   @functools.wraps(fn)
   def wrapper(*args, **kwargs):
     return WithSparseCoreLayout(fn(*args, **kwargs), names, mesh=mesh)
@@ -73,7 +73,7 @@ class SparseCoreEmbed(nn.Module):
 
   # A sequence of FeatureSpecs to specify the configurations for the
   # input feature.
-  feature_specs: Nested[embedding_spec.FeatureSpec]
+  feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
   # Axis in the mesh to use for sharding.
   sharding_axis: str = 'sparsecore_sharding'
   # Mesh to use for the embedding layer.
@@ -94,8 +94,10 @@ def __post_init__(self):
     super().__post_init__()
 
   def setup(self):
-    self.embedding_table_partition = P(self.sharding_axis, None)
-    self.data_partition = P(self.sharding_axis)
+    self.embedding_table_partition = jax.sharding.PartitionSpec(
+        self.sharding_axis, None
+    )
+    self.data_partition = jax.sharding.PartitionSpec(self.sharding_axis)
     self.num_shards = self.mesh.shape[self.sharding_axis]
 
     initializer = functools.partial(
@@ -118,17 +120,17 @@ def _wrap_initializer(
       self, initializer: Callable[[jax.Array], tuple[jax.Array, ...]]
   ):
     return with_sparsecore_layout(
-        initializer,
-        (self.sharding_axis,),
+        fn=initializer,
+        names=(self.sharding_axis, None),
         mesh=self.mesh,
     )
 
   def preprocess_inputs(
       self,
       step: int,
-      features: Nested[np.ndarray],
-      features_weights: Nested[np.ndarray],
-  ) -> EmbeddingLookupInput:
+      features: embedding.Nested[np.ndarray],
+      features_weights: embedding.Nested[np.ndarray],
+  ) -> embedding.PreprocessedInput:
     """Preprocesses the input for sparse dense matmul.
 
     This method do not need to be invoked with module.apply().
@@ -157,8 +159,8 @@ def preprocess_inputs(
     )[0]
 
   def __call__(
-      self, embedding_lookup_inputs: EmbeddingLookupInput
-  ) -> Nested[jax.Array]:
+      self, embedding_lookup_inputs: embedding.PreprocessedInput
+  ) -> embedding.Nested[jax.Array]:
     """Computes the embedding activations.
 
     Args:
@@ -175,8 +177,8 @@ def __call__(
 
   def apply_gradient(
       self,
-      gradients: Nested[jax.Array],
-      embedding_lookup_inputs: EmbeddingLookupInput,
+      gradients: embedding.Nested[jax.Array],
+      embedding_lookup_inputs: embedding.PreprocessedInput,
   ) -> Mapping[str, Mapping[str, jax.Array]]:
     """Apply the gradients to the embedding variables.
 
@@ -202,7 +204,7 @@ def apply_gradient(
 @functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
 def _emb_lookup(
     embedding_layer: SparseCoreEmbed,
-    embedding_lookup_inputs: EmbeddingLookupInput,
+    embedding_lookup_inputs: embedding.PreprocessedInput,
     emb_table: Mapping[str, tuple[jax.Array, ...]],
 ):
   pt = embedding_layer.embedding_table_partition
@@ -226,7 +228,7 @@ def _emb_lookup(
 
 def _emb_lookup_fwd(
     embedding_layer: SparseCoreEmbed,
-    embedding_lookup_inputs: EmbeddingLookupInput,
+    embedding_lookup_inputs: embedding.PreprocessedInput,
     emb_table: Mapping[str, tuple[jax.Array, ...]],
 ):
   return _emb_lookup(
jax_tpu_embedding/sparsecore/lib/flax/tests/autograd_test.py

Lines changed: 1 addition & 2 deletions
@@ -22,7 +22,6 @@
 import jax.numpy as jnp
 from jax_tpu_embedding.sparsecore.examples.models.shakespeare import dataset as shakespeare_data
 from jax_tpu_embedding.sparsecore.examples.models.shakespeare import flax_model as shakespeare_model
-from jax_tpu_embedding.sparsecore.lib.flax import embed
 from jax_tpu_embedding.sparsecore.lib.flax import embed_optimizer
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
@@ -173,7 +172,7 @@ def process_inputs(batch_number, feature_batch):
     )
     def train_step(
         params: Any,
-        embedding_lookup_inputs: embed.EmbeddingLookupInput,
+        embedding_lookup_inputs: embedding.PreprocessedInput,
         labels: jax.Array,
         opt_state,
     ):

jax_tpu_embedding/sparsecore/lib/flax/tests/embed_test.py

Lines changed: 2 additions & 3 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-from typing import Tuple
 
 from absl.testing import absltest
 import einops
@@ -197,8 +196,8 @@ class EmbeddingLayerTest(absltest.TestCase):
 
   def _row_initialize_with_padding(
       self,
-      shape: Tuple[int, ...],
-      padded_shape: Tuple[int, ...],
+      shape: tuple[int, ...],
+      padded_shape: tuple[int, ...],
       offset: int = 0,
       pad_value: float = _PAD_VALUE,
   ):

jax_tpu_embedding/sparsecore/lib/nn/embedding.py

Lines changed: 4 additions & 4 deletions
@@ -16,7 +16,7 @@
 import collections
 import dataclasses
 import functools
-from typing import List, Mapping, NamedTuple, Sequence, Tuple, TypeAlias, TypeVar, Union
+from typing import List, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar, Union
 import warnings
 
 from absl import logging
@@ -119,8 +119,8 @@ class PreprocessedInput(struct.PyTreeNode):
   """
 
   sparse_dense_matmul_input: SparseDenseMatmulInput
-  num_minibatches: jnp.ndarray = struct.field(
-      default_factory=lambda: jnp.array(1)
+  num_minibatches: np.ndarray = struct.field(
+      default_factory=lambda: np.array(1)
   )
 
   # Backward compatibility properties and functions. This class acts as a
@@ -1068,7 +1068,7 @@ def _init_stacked_embedding_table(
     stack_name: str,
     table_specs: List[embedding_spec.TableSpec],
     global_sharding: jax.sharding.NamedSharding,
-    sharding_axis: str | Tuple[str, ...],
+    sharding_axis: str | tuple[str, ...],
     num_sparsecore_per_device: int | None = None,
 ) -> EmbeddingVariables:
   """Initializes a stacked embedding table."""
