Add support EmbeddingFeature V2.
PiperOrigin-RevId: 583907738
TensorFlow Recommenders Authors committed Nov 29, 2023
1 parent a73df26 commit cae8e42
Showing 1 changed file with 135 additions and 78 deletions.
213 changes: 135 additions & 78 deletions tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py
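For context, here is a minimal usage sketch of the new embedding_feature argument (illustrative only, not part of the diff below). The table and feature names, the sizes, and the pre-built TPUStrategy object named strategy are assumptions for the example; when embedding_feature is omitted, the layer now reads the level from strategy.extended.tpu_hardware_feature.embedding_feature and, on V2 hardware, uses tf.tpu.experimental.embedding.TPUEmbeddingV2 when the installed TensorFlow provides it.

import tensorflow as tf
import tensorflow_recommenders as tfrs

# Hypothetical feature config and optimizer for the sketch.
table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=1000, dim=16, name="item_table")
feature_config = {
    "item_id": tf.tpu.experimental.embedding.FeatureConfig(table=table),
}
optimizer = tf.tpu.experimental.embedding.Adagrad(learning_rate=0.1)

with strategy.scope():  # strategy: an existing tf.distribute.TPUStrategy
  embedding_layer = tfrs.layers.embedding.TPUEmbedding(
      feature_config,
      optimizer,
      # New in this change: pass the feature level explicitly, or omit it to
      # use the level reported by the attached TPU hardware.
      embedding_feature=tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2,
  )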
@@ -14,7 +14,7 @@

"""Keras interface for TPU Embeddings in TF2."""

from typing import Iterable, Optional, Union, Any, Dict
from typing import Any, Dict, Iterable, Optional, Union

import tensorflow.compat.v2 as tf

@@ -51,9 +51,22 @@
}
_DUMMY_NAME = "tpu_embedding_helper_dummy"

_EMBEDDING_V2 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
_EMBEDDING_V1 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V1
_EMBEDDING_UNSUPPORTED = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED
EmbeddingFeature = tf.tpu.experimental.HardwareFeature.EmbeddingFeature

_EMBEDDING_V2 = EmbeddingFeature.V2
_EMBEDDING_V1 = EmbeddingFeature.V1
_EMBEDDING_UNSUPPORTED = EmbeddingFeature.UNSUPPORTED

TPUEmbeddingType = Union[
tf.tpu.experimental.embedding.TPUEmbedding,
tf.tpu.experimental.embedding.TPUEmbeddingV0,
tf.tpu.experimental.embedding.TPUEmbeddingForServing,
]

if hasattr(tf.tpu.experimental.embedding, "TPUEmbeddingV2"):
TPUEmbeddingType = (
TPUEmbeddingType | tf.tpu.experimental.embedding.TPUEmbeddingV2
)


def _normalize_and_prepare_optimizer(optimizer):
@@ -584,8 +597,7 @@ def __init__(
tf.tpu.experimental.embedding.FTRL]],
pipeline_execution_with_tensor_core: bool = False,
batch_size: Optional[int] = None,
embedding_feature: Optional[
tf.tpu.experimental.HardwareFeature.EmbeddingFeature] = None):
embedding_feature: Optional[EmbeddingFeature] = None):
"""A Keras layer for accelerated embedding lookups on TPU.
Args:
@@ -616,35 +628,42 @@ def __init__(

self._embedding_feature = None
if self._using_tpu:
self._embedding_feature = self._strategy.extended.tpu_hardware_feature.embedding_feature
self._embedding_feature = (
self._strategy.extended.tpu_hardware_feature.embedding_feature
)
# Override the embedding feature setting if passed.
if embedding_feature is not None:
if embedding_feature == _EMBEDDING_UNSUPPORTED:
self._embedding_feature = _EMBEDDING_UNSUPPORTED
if (embedding_feature != _EMBEDDING_UNSUPPORTED and
self._embedding_feature != embedding_feature):
if (
embedding_feature != _EMBEDDING_UNSUPPORTED
and self._embedding_feature != embedding_feature
):
raise ValueError(
"TPU only supports {} and {}, but got {} which is not supported."
.format(_EMBEDDING_UNSUPPORTED, self._embedding_feature,
embedding_feature))
.format(
_EMBEDDING_UNSUPPORTED,
self._embedding_feature,
embedding_feature,
)
)

# Create TPU embedding mid level APIs according to the embedding feature
# setting.
self._tpu_embedding = self._create_tpu_embedding_mid_level_api(
self._using_tpu, self._embedding_feature,
pipeline_execution_with_tensor_core)

self._using_tpu,
self._embedding_feature,
pipeline_execution_with_tensor_core,
)
self.batch_size = batch_size

self._tpu_call_id = 0

def _create_tpu_embedding_mid_level_api(
self, using_tpu: bool, embedding_feature: Optional[
tf.tpu.experimental.HardwareFeature.EmbeddingFeature],
pipeline_execution_with_tensor_core: bool
) -> Union[tf.tpu.experimental.embedding.TPUEmbedding,
tf.tpu.experimental.embedding.TPUEmbeddingV0,
tf.tpu.experimental.embedding.TPUEmbeddingForServing]:
self,
using_tpu: bool,
embedding_feature: Optional[EmbeddingFeature],
pipeline_execution_with_tensor_core: bool,
) -> TPUEmbeddingType:
"""Creates TPU Embedding mid level API instance based on settings.
Args:
@@ -675,7 +694,14 @@ def _create_tpu_embedding_mid_level_api(
self._feature_config, self._optimizer,
pipeline_execution_with_tensor_core)
elif embedding_feature == _EMBEDDING_V2:
raise NotImplementedError("Embedding feature v2 is not supported yet!")
if hasattr(tf.tpu.experimental.embedding, "TPUEmbeddingV2"):
return tf.tpu.experimental.embedding.TPUEmbeddingV2(
self._feature_config,
self._optimizer,
pipeline_execution_with_tensor_core,
)
else:
raise ValueError("TPUEmbeddingV2 is not supported in TF.")
else:
raise ValueError("Unknown embedding feature {}".format(embedding_feature))

@@ -693,7 +719,10 @@ def build(self, input_shape: Union[tf.TensorShape, Iterable[tf.TensorShape]]):
else:
self._tpu_embedding.build()

if self._embedding_feature == _EMBEDDING_V1:
if (
self._embedding_feature == _EMBEDDING_V1
or self._embedding_feature == _EMBEDDING_V2
):
# Note that self.tpu_embedding_helper_dummy matches _DUMMY_NAME above,
# or it will appear twice in the list of saveables. Note that the Python
# variable name should be _DUMMY_NAME too, as it is used to name internal
@@ -731,61 +760,85 @@ def _tpu_embedding_lookup(self, features: Any, weights: Any) -> Any:
A dict of looked up embedding tensors with keys matching those of
features_to_config_dict.
"""
# Each call to this function increments the _tpu_call_id by 1; this allows
# us to tag each of the main embedding ops with this call id so that we know
# during graph rewriting passes which ops correspond to the same layer call.
self._tpu_call_id += 1
name = "{}".format(self._tpu_call_id)

# Set training to true, even during eval. When name is set, this will
# trigger a pass that updates training based on whether there is a send
# gradients op with the same name.
self._tpu_embedding.enqueue(features, weights, training=True, name=name)

# The gradient trap is a trick used to ensure we can compute the gradients
# at the correct point of the model. By default GradientTape only tracks
# the calculations which descend from variables. e.g. if you call
# tape.gradient on something that does not come from a variable involved in
# the computation, it will fail.
# We need to call tpu_embedding.apply_gradients on the gradients computed
# at tpu_embedding.dequeue. Since tpu_embedding.dequeue has no inputs, we
# can't compute the gradient at its output. To get around that we wrap
# the dequeue in a function with a custom gradient. This function takes one
# input, throws it away and returns the result of the dequeue. If we pass a
# dummy variable to this function and compute the gradient at the dummy
# variable, then the custom gradient function will be called with the
# gradients that we need to pass to tpu_embedding.apply_gradients.
@tf.custom_gradient
def gradient_trap(dummy):
"""Register a gradient function for activation.
Its purpose is to send gradients back to TPU.
Args:
dummy: a variable to prevent this backward pass from being pruned.
Returns:
a tuple of list of activations and their gradient function.
"""
activations = self._tpu_embedding.dequeue(name=name)

def grad(*grad_wrt_activations):
"""Gradient function."""
# Since the output of the function is flattened, the gradients
# are also flattened. Hence we have to pack them back into the correct
# nested structure.
gradients = tf.nest.pack_sequence_as(self._feature_config,
grad_wrt_activations)
self._tpu_embedding.apply_gradients(gradients, name=name)

# This is the gradient for the input variable.
return tf.zeros_like(dummy)

# Custom gradient functions don't like nested structures of tensors, so we
# flatten them here.
return tf.nest.flatten(activations), grad

activations_with_trap = gradient_trap(getattr(self, _DUMMY_NAME))
if self._embedding_feature == _EMBEDDING_V2:

@tf.custom_gradient
def gradient_trap(dummy):
"""Register a gradient function for activation."""
activations, preserved_result = self._tpu_embedding(features, weights)

def grad(*grad_wrt_activations):
"""Gradient function."""
gradients = tf.nest.pack_sequence_as(
self._feature_config, grad_wrt_activations
)
self._tpu_embedding.apply_gradients(
gradients, preserved_outputs=preserved_result
)
return tf.zeros_like(dummy)

return tf.nest.flatten(activations), grad

activations_with_trap = gradient_trap(getattr(self, _DUMMY_NAME))
else:
# Each call to this function increments the _tpu_call_id by 1; this allows
# us to tag each of the main embedding ops with this call id so that we
# know during graph rewriting passes which ops correspond to the same
# layer call.
self._tpu_call_id += 1
name = "{}".format(self._tpu_call_id)

# Set training to true, even during eval. When name is set, this will
# trigger a pass that updates training based on whether there is a send
# gradients op with the same name.
self._tpu_embedding.enqueue(features, weights, training=True, name=name)

# The gradient trap is a trick used to ensure we can compute the gradients
# at the correct point of the model. By default GradientTape only tracks
# the calculations which descend from variables. e.g. if you call
# tape.gradient on something that does not come from a variable involved
# in the computation, it will fail.
# We need to call tpu_embedding.apply_gradients on the gradients computed
# at tpu_embedding.dequeue. Since tpu_embedding.dequeue has no inputs, we
# can't compute the gradient at its output. To get around that we wrap
# the dequeue in a function with a custom gradient. This function takes
# one input, throws it away and returns the result of the dequeue. If we
# pass a dummy variable to this function and compute the gradient at the
# dummy variable, then the custom gradient function will be called with
# the gradients that we need to pass to tpu_embedding.apply_gradients.
@tf.custom_gradient
def gradient_trap(dummy):
"""Register a gradient function for activation.
Its purpose is to send gradients back to TPU.
Args:
dummy: a variable to prevent this backward pass from being pruned.
Returns:
a tuple of list of activations and their gradient function.
"""
activations = self._tpu_embedding.dequeue(name=name)

def grad(*grad_wrt_activations):
"""Gradient function."""
# Since the output of the function is flattened, the gradients
# are also flattened. Hence we have to pack them back into the
# correct nested structure.
gradients = tf.nest.pack_sequence_as(
self._feature_config, grad_wrt_activations
)
self._tpu_embedding.apply_gradients(gradients, name=name)

# This is the gradient for the input variable.
return tf.zeros_like(dummy)

# Custom gradient functions don't like nested structures of tensors,
# so we flatten them here.
return tf.nest.flatten(activations), grad

activations_with_trap = gradient_trap(getattr(self, _DUMMY_NAME))

return tf.nest.pack_sequence_as(self._feature_config, activations_with_trap)

def call(
@@ -869,6 +922,10 @@ def embedding_tables(
`feature_config` passed to this layer's init.
"""
tables = self._tpu_embedding.embedding_tables
# TODO(ziyinh): handle stacked table here.
if self._embedding_feature == _EMBEDDING_V2:
return tables

# Use the table config map to map from the cloned configs back to the
# configs that were passed into the layer on init.
return {
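For reference, the following standalone sketch (illustrative only, not part of the commit) shows the gradient-trap pattern that the V1 lookup path above relies on: a dummy variable is fed through a tf.custom_gradient function so that differentiating with respect to the dummy invokes a custom backward function, even though the embedding dequeue itself has no tensor inputs. The table tensor, the gather, and the print call are stand-ins for the mid level API's dequeue and apply_gradients.

import tensorflow as tf

# Plays the role of the layer's tpu_embedding_helper_dummy weight.
dummy = tf.Variable(0.0, name="tpu_embedding_helper_dummy")
table = tf.random.normal([4, 2])  # stand-in for an embedding table

@tf.custom_gradient
def gradient_trap(dummy_in):
  # Stand-in for tpu_embedding.dequeue(): the activations do not depend on
  # dummy_in at all; the dummy only gives the tape something to track.
  activations = tf.gather(table, [0, 2])

  def grad(grad_wrt_activations):
    # The real layer would call tpu_embedding.apply_gradients(...) here.
    tf.print("gradient w.r.t. activations:", grad_wrt_activations)
    return tf.zeros_like(dummy_in)  # gradient for the dummy input

  return activations, grad

with tf.GradientTape() as tape:
  activations = gradient_trap(dummy)
  loss = tf.reduce_sum(activations)

# Differentiating w.r.t. the dummy routes the activation gradients into grad().
_ = tape.gradient(loss, dummy)

In the layer itself, the V2 path above skips the enqueue/dequeue pair and instead calls the TPUEmbeddingV2 object directly, passing its preserved outputs back to apply_gradients from inside the same trap.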
