generator is a function that generates the MLIR code for the optimizer update
computation. Take a look in optimizer.py for examples.
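
A generator matches the following shape (a simplified sketch; the name and
body here are illustrative, see optimizer.py for the real implementations):

  def my_optimizer_generator(
      ctx: mlir.LoweringRuleContext,
      computation_name: str,
      emb_dim: int,
  ) -> None:
    ...  # Emit the MLIR update computation under `computation_name`.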

Depending on the optimizer, a different number of embedding variables may be
required. For example, for Adagrad, the optimizer update computation requires
both the embedding table and the accumulator.

These variables are passed in individually as positional operands:
- the embedding table (2D: [vocab_size, emb_dim]) first,
- followed by any number of slot variables, each of which may be
  2D ([vocab_size, emb_dim]) or 1D ([vocab_size]).

The order in which the variables are provided _must_ be identical to the order
that the XLA compiler expects. For example, for Adagrad, the embedding table
must be at index 0 and the accumulator must be at index 1.

The hyperparameters are passed as trailing scalar operands (0D tensors) after
the activation gradients argument. The order of the hyperparameters _must_ be
identical to the order that the XLA compiler expects. For example, for SGD and
Adagrad, the learning rate must be at index 0.
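
As a concrete illustration, a hypothetical Adagrad invocation would lay out
its positional operands roughly as follows (names are illustrative only):

  operands = [
      lhs_row_pointers,         # sparse input descriptors
      lhs_local_embedding_ids,
      lhs_local_sample_ids,
      lhs_gains,
      embedding_table,          # [vocab_size, emb_dim], table index 0
      accumulator,              # [vocab_size, emb_dim], slot variable, index 1
      activations_grad,         # rank-2, minor dim == emb_dim
      learning_rate,            # 0D scalar, hyperparameter index 0
  ]

An optional scalar num_minibatches operand, when present, appears just before
the tables.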
"""

import functools
def _tpu_sparse_dense_matmul_optimizer_grad_abstract_eval(
    lhs_row_pointers: np.ndarray,
lhs_local_embedding_ids: np.ndarray,
lhs_local_sample_ids: np.ndarray,
lhs_gains: np.ndarray,
*args,
optimizer_generator: Callable[[mlir.LoweringRuleContext, str, int], None],
max_ids_per_partition: int,
max_unique_ids_per_partition: int,
    computation_name: str = "sparse_dense_matmul_optimizer_grad",
    sharding_strategy: int = 1,
) -> Tuple[core.ShapedArray, ...]:
  """Abstract eval for sparse_dense_matmul_optimizer_grad."""
if lhs_gains.dtype != np.float32:
raise ValueError(f"lhs_gains must have type float32, got {lhs_gains.dtype}")
  # Args layout:
  #   [optional num_minibatches scalar], *tables, activations_grad, *hparams
arg_list = list(args)
# Strip optional num_minibatches if present (scalar of any dtype).
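  # Tables and activations_grad always have rank >= 1, so a leading rank-0
  # operand can only be num_minibatches.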
if arg_list and not arg_list[0].shape:
arg_list = arg_list[1:]

# Split trailing scalar hyperparameters.
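  # Scan backwards: every trailing rank-0 operand is a hyperparameter.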
split = len(arg_list)
while split > 0 and not arg_list[split - 1].shape:
split -= 1
non_hparams = arg_list[:split]

if not non_hparams:
raise ValueError("Missing activations_grad and table operands.")
activations_grad = non_hparams[-1]
tables = non_hparams[:-1]

if not tables:
raise ValueError("At least one table (the embedding variable) is required.")
if activations_grad.dtype != np.float32 or len(activations_grad.shape) != 2:
raise ValueError(
"hyperparameters must be 1 dimensional with dtype float32, got"
f" {hyperparameters.dtype} and shape {hyperparameters.shape}"
"activations_grad must be rank-2 with dtype float32, got"
f" dtype {activations_grad.dtype} and shape {activations_grad.shape}"
)
# Validate tables: embedding table first (2D), slots may be 1D or 2D.
if tables[0].dtype != np.float32 or len(tables[0].shape) != 2:
raise ValueError(
f"activations_grad must have type float32, got {activations_grad.dtype}"
"The first table must be the embedding table (rank-2, dtype float32),"
f" got dtype {tables[0].dtype} and shape {tables[0].shape}"
)
for t in tables:
if t.dtype != np.float32:
raise ValueError("All tables must have dtype float32.")
if len(t.shape) not in (1, 2):
raise ValueError(
"Slot variables must be rank-1 or rank-2; got shape {t.shape}."
)
if len(lhs_row_pointers.shape) != 1:
raise ValueError(
f"lhs_row_pointers must have rank 1, got {lhs_row_pointers.shape}"
    )
  if (
      lhs_local_sample_ids.shape != lhs_local_embedding_ids.shape
      or lhs_local_embedding_ids.shape != lhs_gains.shape
      or len(lhs_local_sample_ids.shape) != 1
  ):
    raise ValueError(
        "lhs_local_sample_ids, lhs_local_embedding_ids and lhs_gains must have "
f"equal rank 1 shapes, got shapes {lhs_local_sample_ids.shape}, "
f"{lhs_local_embedding_ids.shape} and {lhs_gains.shape}"
)
if not callable(optimizer_generator):
raise ValueError("optimizer_generator must be callable")

if not computation_name:
raise ValueError("computation_name must be non-empty")

return tuple(
core.ShapedArray(t.shape, jnp.float32) for t in tables
)


def _tpu_sparse_dense_matmul_optimizer_grad_lowering(
    ctx: mlir.LoweringRuleContext,
    lhs_row_pointers: mlir.ir.BlockArgument,
lhs_local_embedding_ids: mlir.ir.BlockArgument,
lhs_local_sample_ids: mlir.ir.BlockArgument,
lhs_gains: mlir.ir.BlockArgument,
*args: mlir.ir.Value,
optimizer_generator: Callable[[mlir.LoweringRuleContext, str, int], None],
max_ids_per_partition: int,
max_unique_ids_per_partition: int,
computation_name: str = "sparse_dense_matmul_optimizer_grad",
sharding_strategy: int = 1,
) -> Tuple[np.ndarray, ...]:
"""Lowering for sparse_dense_matmul_optimizer_grad."""
args = list(args)
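  # Strip the optional leading num_minibatches scalar, mirroring the
  # abstract eval above.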
if args and ir.RankedTensorType(args[0].type).rank == 0:
args = args[1:]

# Split trailing scalar hyperparameters.
split = len(args)
while split > 0 and ir.RankedTensorType(args[split - 1].type).rank == 0:
split -= 1
non_hparams = args[:split]
hyperparams = args[split:]
if not non_hparams:
raise ValueError("Missing activations_grad and table operands.")
activations_grad = non_hparams[-1]
tables = non_hparams[:-1]

if not tables:
raise ValueError("At least one embedding table is required.")
table_type = tables[0].type
emb_rank = ir.RankedTensorType(table_type).rank
if emb_rank != 2:
raise ValueError("First table must be rank-2 embedding variable.")
num_slot_variables = len(tables) - 1
num_hyperparameters = len(hyperparams)
sdmm_sgd_config = {
"max_ids_per_partition": max_ids_per_partition,
"max_unique_ids_per_partition": max_unique_ids_per_partition,
  }

optimizer_update_computation_name = computation_name

optimizer_generator(
ctx,
optimizer_update_computation_name,
tables[0].type.maybe_downcast().get_dim_size(1),
)

operands = (
[
          lhs_row_pointers,
          lhs_local_embedding_ids,
          lhs_local_sample_ids,
          lhs_gains,
      ]
      + tables
      + [activations_grad]
      + hyperparams
  )
op = jax.ffi.ffi_lowering(
"SparseDenseMatmulGradOpWithOptimizerUpdate",
result_types=[
        ir.TupleType.get_tuple([t.type for t in tables])  # pylint: disable=attribute-error
],
backend_config=backend_config,
called_computations=[optimizer_update_computation_name],