diff --git a/jax_tpu_embedding/sparsecore/lib/core/BUILD b/jax_tpu_embedding/sparsecore/lib/core/BUILD
index 16285a66..1608e526 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/BUILD
+++ b/jax_tpu_embedding/sparsecore/lib/core/BUILD
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-load("//third_party/bazel/python:pybind11.bzl", "pybind_extension")
+load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library")
 load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
 load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library")
@@ -72,17 +72,41 @@ cc_test(
     ],
 )
 
+pybind_library(
+    name = "input_preprocessing_py_util",
+    srcs = [
+        "input_preprocessing_py_util.cc",
+    ],
+    hdrs = [
+        "input_preprocessing_py_util.h",
+    ],
+    deps = [
+        ":input_preprocessing_util",
+        "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log:check",
+        "@tsl//tsl/profiler/lib:traceme",
+    ],
+)
+
 pybind_extension(
     name = "input_preprocessing_cc",
-    srcs = ["input_preprocessing.cc"],
+    srcs = [
+        "input_preprocessing.cc",
+    ],
     deps = [
+        ":input_preprocessing_py_util",
         ":input_preprocessing_threads",
         ":input_preprocessing_util",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
+        "@highway//:hwy",
+        "@highway//hwy/contrib/sort:vqsort",
         "@tsl//tsl/profiler/lib:connected_traceme",
         "@tsl//tsl/profiler/lib:traceme",
     ],
diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
index 0b5d9ef4..ac7cf2e4 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include
-#include
 #include
 #include
 #include
@@ -24,6 +23,7 @@
 #include "absl/strings/string_view.h"  // from @com_google_absl
 #include "absl/synchronization/blocking_counter.h"  // from @com_google_absl
 #include "absl/types/span.h"  // from @com_google_absl
+#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
 #include "pybind11/cast.h"  // from @pybind11
@@ -148,48 +148,6 @@ int ExtractCooTensors(const py::array& features,
                       global_device_count, coo_tensors);
 }
 
-absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
-GetStackedTableMetadata(py::list feature_specs, py::list features) {
-  tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
-  absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
-      stacked_table_metadata;
-  for (int i = 0; i < feature_specs.size(); ++i) {
-    const py::object& feature_spec = feature_specs[i];
-    const py::array& feature = features[i].cast<py::array>();
-    const py::object& feature_transformation =
-        feature_spec.attr("_id_transformation");
-    const py::object& table_spec = feature_spec.attr("table_spec");
-    const py::object& stacked_table_spec =
-        table_spec.attr("stacked_table_spec");
-    const std::string stacked_table_name = py::cast<std::string>(
-        table_spec.attr("_setting_in_stack").attr("stack_name"));
-    int col_shift = 0;
-    int col_offset = 0;
-    int row_offset = 0;
-    const int max_ids_per_partition =
-        py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
-    const int max_unique_ids_per_partition =
-        py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
-    if (!feature_transformation.is_none()) {
-      row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
-      col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
-      col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
-    }
-    stacked_table_metadata[stacked_table_name].emplace_back(
-        i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
-        col_offset, col_shift,
-        /*batch_size=*/feature.shape(0));
-  }
-  // Sort the stacked tables by row_offset.
-  for (auto& [_, t] : stacked_table_metadata) {
-    std::sort(t.begin(), t.end(),
-              [](const StackedTableMetadata& a, const StackedTableMetadata& b) {
-                return a.row_offset < b.row_offset;
-              });
-  }
-  return stacked_table_metadata;
-}
-
 // Preprocess inputs for a single table. Stacked table here refers to a
 // a table that has no parent in the table stacking hierarchy. So in the case
 // of table stacking, the stacked table is the top level table and in the case
diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc
new file mode 100644
index 00000000..db1747af
--- /dev/null
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc
@@ -0,0 +1,88 @@
+// Copyright 2024 The JAX SC Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"  // from @com_google_absl
+#include "absl/log/check.h"  // from @com_google_absl
+#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
+#include "pybind11/cast.h"  // from @pybind11
+#include "pybind11/gil.h"  // from @pybind11
+#include "pybind11/numpy.h"  // from @pybind11
+#include "pybind11/pybind11.h"  // from @pybind11
+#include "pybind11/pytypes.h"  // from @pybind11
+#include "tsl/profiler/lib/traceme.h"  // from @tsl
+
+namespace jax_sc_embedding {
+
+namespace py = ::pybind11;
+
+absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
+GetStackedTableMetadata(const py::list& feature_specs, const int batch_size) {
+  tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
+  absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
+      stacked_table_metadata;
+  for (int i = 0; i < feature_specs.size(); ++i) {
+    const py::object& feature_spec = feature_specs[i];
+
+    const py::object& feature_transformation =
+        feature_spec.attr("_id_transformation");
+    const py::object& table_spec = feature_spec.attr("table_spec");
+    const py::object& stacked_table_spec =
+        table_spec.attr("stacked_table_spec");
+    const std::string stacked_table_name = py::cast<std::string>(
+        table_spec.attr("_setting_in_stack").attr("stack_name"));
+    int col_shift = 0;
+    int col_offset = 0;
+    int row_offset = 0;
+    const int max_ids_per_partition =
+        py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
+    const int max_unique_ids_per_partition =
+        py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
+    const int vocab_size =
+        py::cast<int>(stacked_table_spec.attr("stack_vocab_size"));
+    if (!feature_transformation.is_none()) {
+      row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
+      col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
+      col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
+    }
+    stacked_table_metadata[stacked_table_name].emplace_back(
+        i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
+        col_offset, col_shift,
+        /*batch_size=*/batch_size, vocab_size);
+  }
+  // Sort the stacked tables by row_offset.
+  for (auto& [_, t] : stacked_table_metadata) {
+    std::sort(t.begin(), t.end(),
+              [](const StackedTableMetadata& a, const StackedTableMetadata& b) {
+                return a.row_offset < b.row_offset;
+              });
+  }
+  return stacked_table_metadata;
+}
+
+absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
+GetStackedTableMetadata(const py::list& feature_specs,
+                        const py::list& features) {
+  tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
+  int batch_size = features[0].cast<py::array>().shape(0);
+  return GetStackedTableMetadata(feature_specs, batch_size);
+}
+
+}  // namespace jax_sc_embedding
diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h
new file mode 100644
index 00000000..290b0ecd
--- /dev/null
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h
@@ -0,0 +1,40 @@
+// Copyright 2024 The JAX SC Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_
+#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"  // from @com_google_absl
+#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
+#include "pybind11/numpy.h"  // from @pybind11
+#include "pybind11/pytypes.h"  // from @pybind11
+
+namespace jax_sc_embedding {
+
+namespace py = ::pybind11;
+
+// Copy information from feature_specs to StackedTableMetadata.
+// The features argument is only used to get the batch size.
+absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
+GetStackedTableMetadata(const py::list& feature_specs,
+                        const py::list& features);
+
+// Copy information from feature_specs to StackedTableMetadata.
+absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
+GetStackedTableMetadata(const py::list& feature_specs, int batch_size);
+
+}  // namespace jax_sc_embedding
+
+#endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_
diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
index 1359271c..dc9f6ab2 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
@@ -35,7 +35,7 @@ struct CooFormat {
 
 // Get adjusted col_id based on shift and offset.
 int GetColId(int col_id, int col_shift, int col_offset, int num_scs_mod,
-            int num_scs_mod_inv);
+             int num_scs_mod_inv);
 
 inline unsigned int CeilOfRatio(unsigned int numerator,
                                 unsigned int denominator) {
@@ -50,14 +50,16 @@ struct StackedTableMetadata {
   StackedTableMetadata() = delete;
   StackedTableMetadata(int feature_index, int max_ids_per_partition,
                        int max_unique_ids_per_partition, int row_offset,
-                       int col_offset, int col_shift, int batch_size)
+                       int col_offset, int col_shift, int batch_size,
+                       int stacked_table_vocab_size = 0)
       : feature_index(feature_index),
         max_ids_per_partition(max_ids_per_partition),
         max_unique_ids_per_partition(max_unique_ids_per_partition),
         row_offset(row_offset),
         col_offset(col_offset),
         col_shift(col_shift),
-        batch_size(batch_size) {}
+        batch_size(batch_size),
+        stacked_table_vocab_size(stacked_table_vocab_size) {}
   // The batch is given as a list of features (numpy arrays). `feature_index`
   // represents the index of the feature in the list.
   int feature_index;
@@ -70,6 +72,8 @@ struct StackedTableMetadata {
   // Process local batch size of the feature.
   int batch_size;
+
+  int stacked_table_vocab_size;
 };
 
 void SortAndGroupCooTensors(
diff --git a/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD b/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD
index e8146f87..e1b6c526 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD
+++ b/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD
@@ -42,6 +42,7 @@ pytype_strict_library(
        "//jax_tpu_embedding/sparsecore/lib/core:constants",
         pypi_requirement("jax"),
         pypi_requirement("jax/_src/lib"),
+        pypi_requirement("jax/extend"),
         pypi_requirement("numpy"),
     ],
 )
diff --git a/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py b/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py
index 99a92f91..678381dc 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py
+++ b/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py
@@ -20,13 +20,14 @@
 from jax._src import dispatch
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
+import jax.extend as jex
 from jax.interpreters import mlir
 import jax.numpy as jnp
 from jax_tpu_embedding.sparsecore.lib.core import constants
 import numpy as np
 
 # Define the sparse dense matmul primitive.
-tpu_sparse_dense_matmul_csr_with_mini_batching_primitive = core.Primitive(
+tpu_sparse_dense_matmul_csr_with_mini_batching_primitive = jex.core.Primitive(
     "sparse_dense_matmul_csr_with_mini_batching"
 )
diff --git a/jax_tpu_embedding/sparsecore/lib/nn/BUILD b/jax_tpu_embedding/sparsecore/lib/nn/BUILD
index 4cfb76ab..39b82014 100644
--- a/jax_tpu_embedding/sparsecore/lib/nn/BUILD
+++ b/jax_tpu_embedding/sparsecore/lib/nn/BUILD
@@ -32,11 +32,27 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "embedding_utils",
+    srcs = ["embedding_utils.py"],
+    visibility = ["//jax_tpu_embedding/sparsecore/lib/nn:__subpackages__"],
+    deps = [
+        ":embedding_spec",
+        ":table_stacking",
+        pypi_requirement("absl/logging"),
+        pypi_requirement("jax"),
+        pypi_requirement("jax:experimental"),
+        pypi_requirement("numpy"),
+        pypi_requirement("tree"),
+    ],
+)
+
 pytype_strict_library(
     name = "embedding",
     srcs = ["embedding.py"],
     deps = [
         ":embedding_spec",
+        ":embedding_utils",
         ":table_stacking",
         "//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc",
         "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr",
diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py
index 417401cc..1b7da028 100644
--- a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py
+++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py
@@ -26,11 +26,13 @@
 from jax_tpu_embedding.sparsecore.lib.core import input_preprocessing_cc
 from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_csr
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+from jax_tpu_embedding.sparsecore.lib.nn import embedding_utils
 from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
 from jax_tpu_embedding.sparsecore.lib.proto import embedding_spec_pb2
 import numpy as np
 import tree
 
+
 ArrayLike = jnp.ndarray | np.ndarray
 
 T: TypeAlias = TypeVar("T")
@@ -272,16 +274,6 @@ def auto_stack_tables(
   )
 
 
-def sharding_strategy_to_int(sharding_strategy: str) -> int:
-  if sharding_strategy == "MOD":
-    return 1
-  else:
-    raise ValueError(
-        f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is"
-        " supported."
-    )
-
-
 def preprocess_sparse_dense_matmul_input(
     features: Nested[ArrayLike],
     features_weights: Nested[ArrayLike],
@@ -348,60 +340,13 @@ def preprocess_sparse_dense_matmul_input(
       local_device_count,
       global_device_count,
       num_sc_per_device,
-      sharding_strategy_to_int(sharding_strategy),
+      embedding_utils.sharding_strategy_to_int(sharding_strategy),
       has_leading_dimension,
       static_buffer_size_multiplier,
       allow_id_dropping=allow_id_dropping,
   )
 
 
-def _get_activation_for_feature(
-    feature: embedding_spec.FeatureSpec,
-    activations: dict[str, jax.Array],
-    global_device_count: int,
-) -> jax.Array:
-  """Gets the activation slice for a given feature."""
-  assert feature.table_spec.stacked_table_spec is not None
-  if feature.id_transformation is None:
-    raise ValueError(
-        "FeatureIdTransformation cannot be None. It is None for"
-        f" {feature.name}",
-    )
-  per_device_offset = (
-      feature.id_transformation.row_offset // global_device_count
-  )
-  if feature.output_shape[-1] > feature.table_spec.embedding_dim:
-    raise ValueError(
-        f"Feature {feature.name} has output shape {feature.output_shape} and"
-        f" embedding dim {feature.table_spec.embedding_dim}. The output shape"
-        " must be at least same as the (original, unpadded)embedding dim."
-    )
-  return jax.lax.slice(
-      activations[feature.table_spec.stacked_table_spec.stack_name],
-      (per_device_offset, 0),
-      (
-          per_device_offset + feature.output_shape[0] // global_device_count,
-          feature.output_shape[-1],
-      ),
-  )
-
-
-def _unstack_embedding_activations(
-    activations: dict[str, jax.Array],
-    feature_specs: Nested[embedding_spec.FeatureSpec],
-    global_device_count: int,
-) -> Nested[jax.Array]:
-  """Unstacks the activations to match the feature specs."""
-
-  get_activation_for = functools.partial(
-      _get_activation_for_feature,
-      activations=activations,
-      global_device_count=global_device_count,
-  )
-
-  return jax.tree_util.tree_map(get_activation_for, feature_specs)
-
-
 @jax.named_call
 def tpu_sparse_dense_matmul(
     lhs_row_pointers: Mapping[str, jax.Array],
@@ -482,7 +427,7 @@ def tpu_sparse_dense_matmul(
   stacked_table_specs = get_stacked_table_specs(feature_specs)
   assert lhs_row_pointers.keys() == stacked_table_specs.keys()
 
-  sharding_strategy = _sharding_strategy_to_enum(sharding_strategy)
+  sharding_strategy = embedding_utils.sharding_strategy_to_enum(sharding_strategy)
 
   activations = {}
   for stacked_table_name in stacked_table_specs:
@@ -507,61 +452,11 @@ def tpu_sparse_dense_matmul(
       )
   )
 
-  return _unstack_embedding_activations(
+  return embedding_utils.unstack_embedding_activations(
       activations, feature_specs, global_device_count
   )
 
 
-def _sharding_strategy_to_enum(sharding_strategy: str) -> int:
-  """Converts the sharding strategy string to the enum."""
-  if sharding_strategy.upper() == "MOD":
-    return 1
-  else:
-    raise ValueError(
-        f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is"
-        " supported."
-    )
-
-
-def _stack_embedding_gradients(
-    activation_gradients: Nested[jax.Array],
-    feature_specs: Nested[embedding_spec.FeatureSpec],
-) -> Mapping[str, jax.Array]:
-  """Stacks the gradients for update to embedding variables."""
-  stacked_table_to_features = collections.defaultdict(list)
-  for gradient, feature in zip(
-      tree.flatten(activation_gradients), tree.flatten(feature_specs)
-  ):
-    assert feature.table_spec.stacked_table_spec is not None
-    if feature.id_transformation is None:
-      raise ValueError(
-          "FeatureIdTransformation cannot be None here. It is None for"
-          f" {feature.name}"
-      )
-    stacked_table_to_features[
-        feature.table_spec.stacked_table_spec.stack_name
-    ].append((feature, gradient))
-  stacked_table_to_gradients = collections.defaultdict(list)
-  for stacked_table_name, stacked_features in stacked_table_to_features.items():
-    stacked_features.sort(key=lambda x: x[0].id_transformation.row_offset)
-    for f, g in stacked_features:
-      # feature.table_spec.embedding_dim is the original table dim, before
-      # padding
-      gradient = g.reshape([-1, f.table_spec.embedding_dim])
-      # Add padding for extra cols
-      extra_cols = (
-          f.table_spec.setting_in_stack.padded_embedding_dim
-          - f.table_spec.embedding_dim
-      )
-      if extra_cols != 0:
-        gradient = jax.lax.pad(gradient, 0.0, [(0, 0, 0), (0, extra_cols, 0)])
-      stacked_table_to_gradients[stacked_table_name].append(gradient)
-  return {
-      t: jax.lax.concatenate(grads, dimension=0)
-      for t, grads in stacked_table_to_gradients.items()
-  }
-
-
 @jax.named_call
 def tpu_sparse_dense_matmul_grad(
     activation_gradients: Nested[jax.Array],
@@ -643,10 +538,10 @@ def tpu_sparse_dense_matmul_grad(
   stacked_table_specs = get_stacked_table_specs(feature_specs)
   assert lhs_row_pointers.keys() == stacked_table_specs.keys()
 
-  gradients = _stack_embedding_gradients(activation_gradients, feature_specs)
+  gradients = embedding_utils.stack_embedding_gradients(activation_gradients, feature_specs)
   assert lhs_row_pointers.keys() == gradients.keys()
 
-  sharding_strategy = _sharding_strategy_to_enum(sharding_strategy)
+  sharding_strategy = embedding_utils.sharding_strategy_to_enum(sharding_strategy)
 
   updated_embedding_variables = {}
   for stacked_table_name in stacked_table_specs:
diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding_utils.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding_utils.py
new file mode 100644
index 00000000..84630240
--- /dev/null
+++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding_utils.py
@@ -0,0 +1,138 @@
+# Copyright 2024 The JAX SC Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Internal utilities for embedding lookup and update."""
+
+import collections
+import functools
+from typing import Mapping, Sequence, TypeAlias, TypeVar, Union
+
+import jax
+import jax.numpy as jnp
+from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+import numpy as np
+import tree
+
+ArrayLike = jnp.ndarray | np.ndarray
+
+T: TypeAlias = TypeVar("T")
+Nested: TypeAlias = Union[T, Sequence[T], Mapping[str, T]]
+
+
+def sharding_strategy_to_int(sharding_strategy: str) -> int:
+  if sharding_strategy == "MOD":
+    return 1
+  else:
+    raise ValueError(
+        f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is"
+        " supported."
+    )
+
+
+def _get_activation_for_feature(
+    feature: embedding_spec.FeatureSpec,
+    activations: dict[str, jax.Array],
+    global_device_count: int,
+) -> jax.Array:
+  """Gets the activation slice for a given feature."""
+  assert feature.table_spec.stacked_table_spec is not None
+  if feature.id_transformation is None:
+    raise ValueError(
+        "FeatureIdTransformation cannot be None. It is None for"
+        f" {feature.name}",
+    )
+  per_device_offset = (
+      feature.id_transformation.row_offset // global_device_count
+  )
+  if feature.output_shape[-1] > feature.table_spec.embedding_dim:
+    raise ValueError(
+        f"Feature {feature.name} has output shape {feature.output_shape} and"
+        f" embedding dim {feature.table_spec.embedding_dim}. The output shape"
+        " must be at least same as the (original, unpadded)embedding dim."
+    )
+  return jax.lax.slice(
+      activations[feature.table_spec.stacked_table_spec.stack_name],
+      (per_device_offset, 0),
+      (
+          per_device_offset + feature.output_shape[0] // global_device_count,
+          feature.output_shape[-1],
+      ),
+  )
+
+
+def unstack_embedding_activations(
+    activations: dict[str, jax.Array],
+    feature_specs: Nested[embedding_spec.FeatureSpec],
+    global_device_count: int,
+) -> Nested[jax.Array]:
+  """Unstacks the activations to match the feature specs."""
+
+  get_activation_for = functools.partial(
+      _get_activation_for_feature,
+      activations=activations,
+      global_device_count=global_device_count,
+  )
+
+  return jax.tree_util.tree_map(get_activation_for, feature_specs)
+
+
+def sharding_strategy_to_enum(sharding_strategy: str) -> int:
+  """Converts the sharding strategy string to the enum."""
+  if sharding_strategy.upper() == "MOD":
+    return 1
+  else:
+    raise ValueError(
+        f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is"
+        " supported."
+    )
+
+
+def stack_embedding_gradients(
+    activation_gradients: Nested[jax.Array],
+    feature_specs: Nested[embedding_spec.FeatureSpec],
+) -> Mapping[str, jax.Array]:
+  """Stacks the gradients for update to embedding variables."""
+  stacked_table_to_features = collections.defaultdict(list)
+  for gradient, feature in zip(
+      tree.flatten(activation_gradients), tree.flatten(feature_specs)
+  ):
+    assert feature.table_spec.stacked_table_spec is not None
+    if feature.id_transformation is None:
+      raise ValueError(
+          "FeatureIdTransformation cannot be None here. It is None for"
+          f" {feature.name}"
+      )
+    stacked_table_to_features[
+        feature.table_spec.stacked_table_spec.stack_name
+    ].append((feature, gradient))
+  stacked_table_to_gradients = collections.defaultdict(list)
+  for stacked_table_name, stacked_features in stacked_table_to_features.items():
+    stacked_features.sort(key=lambda x: x[0].id_transformation.row_offset)
+    for f, g in stacked_features:
+      # feature.table_spec.embedding_dim is the original table dim, before
+      # padding
+      gradient = g.reshape([-1, f.table_spec.embedding_dim])
+      # Add padding for extra cols
+      extra_cols = (
+          f.table_spec.setting_in_stack.padded_embedding_dim
+          - f.table_spec.embedding_dim
+      )
+      if extra_cols != 0:
+        gradient = jax.lax.pad(gradient, 0.0, [(0, 0, 0), (0, extra_cols, 0)])
+      stacked_table_to_gradients[stacked_table_name].append(gradient)
+  return {
+      t: jax.lax.concatenate(grads, dimension=0)
+      for t, grads in stacked_table_to_gradients.items()
+  }
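
Usage sketch (appended for review, not part of the patch): a minimal example of how the relocated helpers and the jax.extend migration above are expected to be exercised. The behavior shown mirrors the code added in embedding_utils.py and sparse_dense_matmul_csr_with_mini_batching.py; the standalone driver itself is hypothetical and assumes the package and its dependencies are installed. Note the patch aliases the module as `import jax.extend as jex`; the sketch imports the `core` submodule explicitly so it runs on its own.

    # hypothetical_usage.py -- illustrative only, not part of this change.
    from jax.extend import core as jex_core
    from jax_tpu_embedding.sparsecore.lib.nn import embedding_utils

    # Both helpers map the "MOD" sharding strategy to the enum value 1.
    # sharding_strategy_to_enum() upper-cases its argument first, while
    # sharding_strategy_to_int() requires the exact string "MOD".
    assert embedding_utils.sharding_strategy_to_int("MOD") == 1
    assert embedding_utils.sharding_strategy_to_enum("mod") == 1

    # Any other strategy raises ValueError in both helpers.
    try:
      embedding_utils.sharding_strategy_to_int("DIV")
    except ValueError:
      pass

    # The mini-batching primitive is now built on the public jax.extend API
    # rather than the private jax core module, matching the import added above.
    p = jex_core.Primitive("sparse_dense_matmul_csr_with_mini_batching")
    assert p.name == "sparse_dense_matmul_csr_with_mini_batching"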