@@ -20,20 +20,23 @@
 generator is a function that generates the MLIR code for the optimizer update
 computation. Take a look in optimizer.py for examples.
 
-Depending on the optimizer, different number of embedding variables may be
+Depending on the optimizer, a different number of embedding variables may be
 required. For example, for Adagrad, the optimizer update computation requires
 both the embedding table and the accumulator.
 
-These variables are passed in as an 3D array of shape [num_tables, vocab_size,
-emb_size].
-The order in which the variables are stacked _must_ be identical to the order
+These variables are passed in individually as positional operands:
+- the embedding table (2D: [vocab_size, emb_dim]) first,
+- followed by any number of slot variables, each of which may be
+  2D ([vocab_size, emb_dim]) **or** 1D ([vocab_size]).
+
+The order in which the variables are provided _must_ be identical to the order
 that the XLA compiler expects. For example, for Adagrad, the embedding table
 must be at index 0 and the accumulator must be at index 1.
 
-The hyperparameters are passed in as a 1D array of shape [num_hyperparameters].
-The order of the hyperparameters _must_ be identical to the order that the XLA
-compiler expects. For example, for SGD and Adagrad, the learning rate must be at
-index 0.
+The hyperparameters are passed as trailing scalar operands (0D tensors) after
+the activation gradients argument. The order of the hyperparameters _must_ be
+identical to the order that the XLA compiler expects. For example, for SGD and
+Adagrad, the learning rate must be at index 0.
 """
 
 import functools
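To make the new calling convention concrete, here is a minimal sketch of the operand order for Adagrad. The shapes, values, and variable names are illustrative assumptions; only the ordering rule comes from the docstring above.

```python
# Hypothetical operand assembly for Adagrad under the new convention.
import numpy as np

vocab_size, emb_dim, batch = 1024, 16, 32

embedding_table = np.zeros((vocab_size, emb_dim), np.float32)  # table index 0
accumulator = np.zeros((vocab_size, emb_dim), np.float32)      # slot at index 1
activations_grad = np.zeros((batch, emb_dim), np.float32)      # after all tables
learning_rate = np.float32(0.01)                               # trailing 0-D scalar

# Positional order the XLA compiler expects for Adagrad:
operands = (embedding_table, accumulator, activations_grad, learning_rate)
```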
@@ -72,10 +75,7 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
     lhs_local_embedding_ids: np.ndarray,
     lhs_local_sample_ids: np.ndarray,
     lhs_gains: np.ndarray,
-    embedding_variables: np.ndarray,
-    activations_grad: np.ndarray,
-    hyperparameters: np.ndarray,
-    *_,
+    *args,
     optimizer_generator: Callable[[mlir.LoweringRuleContext, str, int], None],
     max_ids_per_partition: int,
     max_unique_ids_per_partition: int,
@@ -99,21 +99,44 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
     )
   if lhs_gains.dtype != np.float32:
     raise ValueError(f"lhs_gains must have type float32, got {lhs_gains.dtype}")
+  # Args layout:
+  #   [num_minibatches,] *tables, activations_grad, *hyperparameters
+  arg_list = list(args)
+  # Strip optional num_minibatches if present (scalar of any dtype).
+  if arg_list and not arg_list[0].shape:
+    arg_list = arg_list[1:]
 
-  if embedding_variables.dtype != np.float32:
-    raise ValueError(
-        "embedding_table must have type float32, got"
-        f" {embedding_variables.dtype}"
-    )
-  if hyperparameters.dtype != np.float32 or len(hyperparameters.shape) != 1:
+  # Split trailing scalar hyperparameters.
+  split = len(arg_list)
+  while split > 0 and not arg_list[split - 1].shape:
+    split -= 1
+  non_hparams = arg_list[:split]
+
+  if not non_hparams:
+    raise ValueError("Missing activations_grad and table operands.")
+  activations_grad = non_hparams[-1]
+  tables = non_hparams[:-1]
+
+  if not tables:
+    raise ValueError("At least one table (the embedding variable) is required.")
+  if activations_grad.dtype != np.float32 or len(activations_grad.shape) != 2:
     raise ValueError(
-        "hyperparameters must be 1 dimensional with dtype float32, got"
-        f" {hyperparameters.dtype} and shape {hyperparameters.shape}"
+        "activations_grad must be rank-2 with dtype float32, got"
+        f" dtype {activations_grad.dtype} and shape {activations_grad.shape}"
     )
-  if activations_grad.dtype != np.float32:
+  # Validate tables: embedding table first (2D), slots may be 1D or 2D.
+  if tables[0].dtype != np.float32 or len(tables[0].shape) != 2:
     raise ValueError(
-        f"activations_grad must have type float32, got {activations_grad.dtype}"
+        "The first table must be the embedding table (rank-2, dtype float32),"
+        f" got dtype {tables[0].dtype} and shape {tables[0].shape}"
     )
+  for t in tables:
+    if t.dtype != np.float32:
+      raise ValueError("All tables must have dtype float32.")
+    if len(t.shape) not in (1, 2):
+      raise ValueError(
+          f"Slot variables must be rank-1 or rank-2; got shape {t.shape}."
+      )
   if len(lhs_row_pointers.shape) != 1:
     raise ValueError(
         f"lhs_row_pointers must have rank 1, got {lhs_row_pointers.shape}"
@@ -128,20 +151,6 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
         f"equal rank 1 shapes, got shapes {lhs_local_sample_ids.shape}, "
         f"{lhs_local_embedding_ids.shape} and {lhs_gains.shape}"
     )
-  if len(embedding_variables.shape) != 3:
-    raise ValueError(
-        f"embedding_table must have rank 3, got {embedding_variables.shape}"
-    )
-  if len(activations_grad.shape) != 2:
-    raise ValueError(
-        f"activations_grad must have rank 2, got {activations_grad.shape}"
-    )
-  if embedding_variables.shape[-1] != activations_grad.shape[-1]:
-    raise ValueError(
-        "embedding_table and activations_grad must have equal feature (minor)"
-        f" dimensions, got {embedding_variables.shape},"
-        f" {activations_grad.shape}"
-    )
   if not callable(optimizer_generator):
     raise ValueError("optimizer_generator must be callable")
 
@@ -163,15 +172,6 @@ def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
   if not computation_name:
     raise ValueError("computation_name must be non-empty")
 
-  num_tables = embedding_variables.shape[0]
-  return tuple(
-      core.ShapedArray(
-          (embedding_variables.shape[1], embedding_variables.shape[2]),
-          dtype=jnp.float32,
-      )
-      for _ in range(num_tables)
-  )
-
 
 tpu_sparse_dense_matmul_optimizer_grad_primitive.def_abstract_eval(
     _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval
@@ -184,21 +184,36 @@ def _tpu_sparse_dense_matmul_optimizer_grad_lowering(
     lhs_local_embedding_ids: mlir.ir.BlockArgument,
     lhs_local_sample_ids: mlir.ir.BlockArgument,
     lhs_gains: mlir.ir.BlockArgument,
-    embedding_variables: mlir.ir.BlockArgument,
-    activations_grad: mlir.ir.BlockArgument,
-    hyperparameters: mlir.ir.BlockArgument,
-    *,
+    *args: mlir.ir.Value,
     optimizer_generator: Callable[[mlir.LoweringRuleContext, str, int], None],
     max_ids_per_partition: int,
     max_unique_ids_per_partition: int,
     computation_name: str = "sparse_dense_matmul_optimizer_grad",
     sharding_strategy: int = 1,
 ) -> Tuple[np.ndarray, ...]:
   """Lowering for sparse_dense_matmul_optimizer_grad."""
-  num_slot_variables = (
-      embedding_variables.type.maybe_downcast().get_dim_size(0) - 1
-  )
-  num_hyperparameters = hyperparameters.type.maybe_downcast().get_dim_size(0)
+  args = list(args)
+  if args and args[0].type.maybe_downcast().get_rank() == 0:
+    args = args[1:]  # strip optional leading num_minibatches scalar
+
+  # Split trailing scalar hyperparameters.
+  split = len(args)
+  while split > 0 and args[split - 1].type.maybe_downcast().get_rank() == 0:
+    split -= 1
+  non_hparams = args[:split]
+  hyperparams = args[split:]
+  if not non_hparams:
+    raise ValueError("Missing activations_grad and table operands.")
+  activations_grad = non_hparams[-1]
+  tables = non_hparams[:-1]
+
+  if not tables:
+    raise ValueError("At least one embedding table is required.")
+  emb_rank = tables[0].type.maybe_downcast().get_rank()
+  if emb_rank != 2:
+    raise ValueError("The first table must be a rank-2 embedding variable.")
+  num_slot_variables = len(tables) - 1
+  num_hyperparameters = len(hyperparams)
   sdmm_sgd_config = {
       "max_ids_per_partition": max_ids_per_partition,
       "max_unique_ids_per_partition": max_unique_ids_per_partition,
@@ -214,48 +229,11 @@ def _tpu_sparse_dense_matmul_optimizer_grad_lowering(
 
   optimizer_update_computation_name = computation_name
 
-  # Because we cannot take in a tuple or list of Nd arrays, we need to slice
-  # the embedding tables into individual tables. The order of the user input
-  # must be kept intact.
-  tables = []
-  table_shape = (
-      embedding_variables.type.maybe_downcast().get_dim_size(1),
-      embedding_variables.type.maybe_downcast().get_dim_size(2),
-  )
-  for i in range(num_slot_variables + 1):
-    sliced = hlo.slice(
-        embedding_variables,
-        mlir.dense_int_array([i, 0, 0]),
-        mlir.dense_int_array([i + 1, table_shape[0], table_shape[1]]),
-        mlir.dense_int_array([1, 1, 1]),
-    )
-    sliced = hlo.reshape(
-        ir.RankedTensorType.get(
-            [table_shape[0], table_shape[1]],
-            ir.F32Type.get(),
-        ),
-        sliced,
-    )
-    tables.append(sliced)
   optimizer_generator(
       ctx,
       optimizer_update_computation_name,
       tables[0].type.maybe_downcast().get_dim_size(1),
   )
-  hyperparams = []
-  f32type = mlir.aval_to_ir_type(core.ShapedArray((), np.float32))
-  for i in range(num_hyperparameters):
-    sliced_param = hlo.slice(
-        hyperparameters,
-        mlir.dense_int_array([i]),
-        mlir.dense_int_array([i + 1]),
-        mlir.dense_int_array([1]),
-    )
-    sliced_param = hlo.reshape(
-        f32type,
-        sliced_param,
-    )
-    hyperparams.append(sliced_param)
 
   operands = (
       [
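The deletions above are the payoff of the new convention: since tables now arrive as separate operands, the `hlo.slice`/`hlo.reshape` loop that un-stacked the old 3-D array is unnecessary, and `tables` is simply the operand list. In numpy terms (an analogy for the removed HLO, not the MLIR code itself):

```python
import numpy as np

# Old convention: one stacked [num_tables, vocab_size, emb_dim] array had to
# be sliced back into individual rank-2 tables before building the update.
stacked = np.zeros((3, 8, 4), np.float32)
old_tables = [stacked[i] for i in range(3)]

# New convention: the caller passes each table (and each scalar
# hyperparameter) as its own operand, so no slicing or reshaping is needed.
```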
@@ -271,7 +249,7 @@ def _tpu_sparse_dense_matmul_optimizer_grad_lowering(
   op = jax.ffi.ffi_lowering(
       "SparseDenseMatmulGradOpWithOptimizerUpdate",
       result_types=[
-          ir.TupleType.get_tuple([tables[0].type for _ in range(len(tables))])  # pylint: disable=attribute-error
+          ir.TupleType.get_tuple([t.type for t in tables])  # pylint: disable=attribute-error
       ],
       backend_config=backend_config,
       called_computations=[optimizer_update_computation_name],
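With `result_types` now built from each table's own type, the FFI op returns one result per table with that input's exact shape, which appears to be what lets rank-1 slot variables round-trip through the update. A toy illustration (shapes assumed):

```python
import numpy as np

tables = [
    np.zeros((8, 4), np.float32),  # embedding table, rank-2
    np.zeros((8,), np.float32),    # slot variable, rank-1 is now representable
]
# One tuple element per table, mirroring each input's type exactly:
result_shapes = [t.shape for t in tables]
assert result_shapes == [(8, 4), (8,)]
```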