
Commit 68297e6

Support 1D optimizer slot variables
PiperOrigin-RevId: 797129907
1 parent b8387f7 commit 68297e6

File tree

5 files changed: +381 -103 lines changed


jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_optimizer_grad.py

Lines changed: 69 additions & 91 deletions
@@ -20,20 +20,23 @@
 generator is a function that generates the MLIR code for the optimizer update
 computation. Take a look in optimizer.py for examples.
 
-Depending on the optimizer, different number of embedding variables may be
+Depending on the optimizer, a different number of embedding variables may be
 required. For example, for Adagrad, the optimizer update computation requires
 both the embedding table and the accumulator.
 
-These variables are passed in as an 3D array of shape [num_tables, vocab_size,
-emb_size].
-The order in which the variables are stacked _must_ be identical to the order
+These variables are passed in individually as positional operands:
+- the embedding table (2D: [vocab_size, emb_dim]) first,
+- followed by any number of slot variables, each of which may be
+  2D ([vocab_size, emb_dim]) **or** 1D ([vocab_size]).
+
+The order in which the variables are provided _must_ be identical to the order
 that the XLA compiler expects. For example, for Adagrad, the embedding table
 must be at index 0 and the accumulator must be at index 1.
 
-The hyperparameters are passed in as a 1D array of shape [num_hyperparameters].
-The order of the hyperparameters _must_ be identical to the order that the XLA
-compiler expects. For example, for SGD and Adagrad, the learning rate must be at
-index 0.
+The hyperparameters are passed as trailing scalar operands (0D tensors) after
+the activation gradients argument. The order of the hyperparameters _must_ be
+identical to the order that the XLA compiler expects. For example, for SGD and
+Adagrad, the learning rate must be at index 0.
 """
 
 import functools
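
For readers new to this calling convention, here is a minimal sketch (NumPy only) of the operand ordering the revised docstring describes. The table names and sizes are illustrative assumptions, not part of the library's API:

import numpy as np

vocab_size, emb_dim, batch = 128, 16, 32

# Tables come first: the 2D embedding table, then any slot variables,
# which may now be 2D or 1D.
embedding_table = np.zeros((vocab_size, emb_dim), np.float32)  # index 0
accumulator = np.zeros((vocab_size, emb_dim), np.float32)      # 2D slot
row_scale = np.zeros((vocab_size,), np.float32)                # 1D slot (newly supported)

# The activation gradients follow the tables, and scalar (0D)
# hyperparameters trail at the end.
activations_grad = np.zeros((batch, emb_dim), np.float32)
learning_rate = np.float32(0.01)

operands = [embedding_table, accumulator, row_scale, activations_grad, learning_rate]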
@@ -72,10 +75,7 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
     lhs_local_embedding_ids: np.ndarray,
     lhs_local_sample_ids: np.ndarray,
     lhs_gains: np.ndarray,
-    embedding_variables: np.ndarray,
-    activations_grad: np.ndarray,
-    hyperparameters: np.ndarray,
-    *_,
+    *args,
     optimizer_generator: Callable[[mlir.LoweringRuleContext, str, int], None],
     max_ids_per_partition: int,
     max_unique_ids_per_partition: int,
@@ -99,21 +99,44 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
   )
   if lhs_gains.dtype != np.float32:
     raise ValueError(f"lhs_gains must have type float32, got {lhs_gains.dtype}")
+  # Args Layout:
+  # num_minibatches, tables, activations_grad, *hparams
+  arg_list = list(args)
+  # Strip optional num_minibatches if present (scalar of any dtype).
+  if arg_list and not arg_list[0].shape:
+    arg_list = arg_list[1:]
 
-  if embedding_variables.dtype != np.float32:
-    raise ValueError(
-        "embedding_table must have type float32, got"
-        f" {embedding_variables.dtype}"
-    )
-  if hyperparameters.dtype != np.float32 or len(hyperparameters.shape) != 1:
+  # Split trailing scalar hyperparameters.
+  split = len(arg_list)
+  while split > 0 and not arg_list[split - 1].shape:
+    split -= 1
+  non_hparams = arg_list[:split]
+
+  if not non_hparams:
+    raise ValueError("Missing activations_grad and table operands.")
+  activations_grad = non_hparams[-1]
+  tables = non_hparams[:-1]
+
+  if not tables:
+    raise ValueError("At least one table (the embedding variable) is required.")
+  if activations_grad.dtype != np.float32 or len(activations_grad.shape) != 2:
     raise ValueError(
-        "hyperparameters must be 1 dimensional with dtype float32, got"
-        f" {hyperparameters.dtype} and shape {hyperparameters.shape}"
+        "activations_grad must be rank-2 with dtype float32, got"
+        f" dtype {activations_grad.dtype} and shape {activations_grad.shape}"
     )
-  if activations_grad.dtype != np.float32:
+  # Validate tables: embedding table first (2D), slots may be 1D or 2D.
+  if tables[0].dtype != np.float32 or len(tables[0].shape) != 2:
     raise ValueError(
-        f"activations_grad must have type float32, got {activations_grad.dtype}"
+        "The first table must be the embedding table (rank-2, dtype float32),"
+        f" got dtype {tables[0].dtype} and shape {tables[0].shape}"
     )
+  for t in tables:
+    if t.dtype != np.float32:
+      raise ValueError("All tables must have dtype float32.")
+    if len(t.shape) not in (1, 2):
+      raise ValueError(
+          f"Slot variables must be rank-1 or rank-2; got shape {t.shape}."
+      )
   if len(lhs_row_pointers.shape) != 1:
     raise ValueError(
         f"lhs_row_pointers must have rank 1, got {lhs_row_pointers.shape}"
@@ -128,20 +151,6 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
         f"equal rank 1 shapes, got shapes {lhs_local_sample_ids.shape}, "
         f"{lhs_local_embedding_ids.shape} and {lhs_gains.shape}"
     )
-  if len(embedding_variables.shape) != 3:
-    raise ValueError(
-        f"embedding_table must have rank 3, got {embedding_variables.shape}"
-    )
-  if len(activations_grad.shape) != 2:
-    raise ValueError(
-        f"activations_grad must have rank 2, got {activations_grad.shape}"
-    )
-  if embedding_variables.shape[-1] != activations_grad.shape[-1]:
-    raise ValueError(
-        "embedding_table and activations_grad must have equal feature (minor)"
-        f" dimensions, got {embedding_variables.shape},"
-        f" {activations_grad.shape}"
-    )
   if not callable(optimizer_generator):
     raise ValueError("optimizer_generator must be callable")
 
@@ -163,15 +172,6 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
   if not computation_name:
     raise ValueError("computation_name must be non-empty")
 
-  num_tables = embedding_variables.shape[0]
-  return tuple(
-      core.ShapedArray(
-          (embedding_variables.shape[1], embedding_variables.shape[2]),
-          dtype=jnp.float32,
-      )
-      for _ in range(num_tables)
-  )
-
 
 tpu_sparse_dense_matmul_optimizer_grad_primitive.def_abstract_eval(
     _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval
@@ -184,21 +184,36 @@ def _tpu_sparse_dense_matmul_optimizer_grad_lowering(
     lhs_local_embedding_ids: mlir.ir.BlockArgument,
     lhs_local_sample_ids: mlir.ir.BlockArgument,
     lhs_gains: mlir.ir.BlockArgument,
-    embedding_variables: mlir.ir.BlockArgument,
-    activations_grad: mlir.ir.BlockArgument,
-    hyperparameters: mlir.ir.BlockArgument,
-    *,
+    *args: mlir.ir.Value,
     optimizer_generator: Callable[[mlir.LoweringRuleContext, str, int], None],
     max_ids_per_partition: int,
     max_unique_ids_per_partition: int,
     computation_name: str = "sparse_dense_matmul_optimizer_grad",
     sharding_strategy: int = 1,
 ) -> Tuple[np.ndarray, ...]:
   """Lowering for sparse_dense_matmul_optimizer_grad."""
-  num_slot_variables = (
-      embedding_variables.type.maybe_downcast().get_dim_size(0) - 1
-  )
-  num_hyperparameters = hyperparameters.type.maybe_downcast().get_dim_size(0)
+  args = list(args)
+  if args and args[0].type.maybe_downcast().get_rank() == 0:
+    args = args[1:]
+
+  # Split trailing scalar hyperparameters.
+  split = len(args)
+  while split > 0 and args[split - 1].type.maybe_downcast().get_rank() == 0:
+    split -= 1
+  non_hparams = args[:split]
+  hyperparams = args[split:]
+  if not non_hparams:
+    raise ValueError("Missing activations_grad and table operands.")
+  activations_grad = non_hparams[-1]
+  tables = non_hparams[:-1]
+
+  if not tables:
+    raise ValueError("At least one embedding table is required.")
+  emb_rank = tables[0].type.maybe_downcast().get_rank()
+  if emb_rank != 2:
+    raise ValueError("First table must be rank-2 embedding variable.")
+  num_slot_variables = len(tables) - 1
+  num_hyperparameters = len(hyperparams)
   sdmm_sgd_config = {
       "max_ids_per_partition": max_ids_per_partition,
       "max_unique_ids_per_partition": max_unique_ids_per_partition,
@@ -214,48 +229,11 @@ def _tpu_sparse_dense_matmul_optimizer_grad_lowering(
 
   optimizer_update_computation_name = computation_name
 
-  # Because we cannot take in a tuple or list of Nd arrays, we need to slice
-  # the embedding tables into individual tables. The order of the user input
-  # must be kept intact.
-  tables = []
-  table_shape = (
-      embedding_variables.type.maybe_downcast().get_dim_size(1),
-      embedding_variables.type.maybe_downcast().get_dim_size(2),
-  )
-  for i in range(num_slot_variables + 1):
-    sliced = hlo.slice(
-        embedding_variables,
-        mlir.dense_int_array([i, 0, 0]),
-        mlir.dense_int_array([i + 1, table_shape[0], table_shape[1]]),
-        mlir.dense_int_array([1, 1, 1]),
-    )
-    sliced = hlo.reshape(
-        ir.RankedTensorType.get(
-            [table_shape[0], table_shape[1]],
-            ir.F32Type.get(),
-        ),
-        sliced,
-    )
-    tables.append(sliced)
   optimizer_generator(
       ctx,
       optimizer_update_computation_name,
       tables[0].type.maybe_downcast().get_dim_size(1),
   )
-  hyperparams = []
-  f32type = mlir.aval_to_ir_type(core.ShapedArray((), np.float32))
-  for i in range(num_hyperparameters):
-    sliced_param = hlo.slice(
-        hyperparameters,
-        mlir.dense_int_array([i]),
-        mlir.dense_int_array([i + 1]),
-        mlir.dense_int_array([1]),
-    )
-    sliced_param = hlo.reshape(
-        f32type,
-        sliced_param,
-    )
-    hyperparams.append(sliced_param)
 
   operands = (
       [
@@ -271,7 +249,7 @@ def _tpu_sparse_dense_matmul_optimizer_grad_lowering(
   op = jax.ffi.ffi_lowering(
       "SparseDenseMatmulGradOpWithOptimizerUpdate",
       result_types=[
-          ir.TupleType.get_tuple([tables[0].type for _ in range(len(tables))])  # pylint: disable=attribute-error
+          ir.TupleType.get_tuple([t.type for t in tables])  # pylint: disable=attribute-error
       ],
       backend_config=backend_config,
       called_computations=[optimizer_update_computation_name],
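
Building the result tuple from each table's own type ([t.type for t in tables]) instead of repeating the embedding table's type is what lets a 1D slot variable round-trip with its original shape. A hedged sketch of the intended shape behaviour, using jax.ShapeDtypeStruct with illustrative sizes:

import jax
import jax.numpy as jnp

table_specs = [
    jax.ShapeDtypeStruct((128, 16), jnp.float32),  # embedding table (2D)
    jax.ShapeDtypeStruct((128, 16), jnp.float32),  # 2D slot, e.g. an accumulator
    jax.ShapeDtypeStruct((128,), jnp.float32),     # 1D slot, now supported
]
# Previously every result type was forced to the embedding table's 2D shape;
# with per-table result types, each updated table keeps its own shape.
result_specs = [jax.ShapeDtypeStruct(t.shape, t.dtype) for t in table_specs]
assert [r.shape for r in result_specs] == [(128, 16), (128, 16), (128,)]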
