Fix to restore from SC checkpoints #698

Open · wants to merge 1 commit into base: main
41 changes: 31 additions & 10 deletions tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py
@@ -53,7 +53,9 @@
 
 _EMBEDDING_V2 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
 _EMBEDDING_V1 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V1
-_EMBEDDING_UNSUPPORTED = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED
+_EMBEDDING_UNSUPPORTED = (
+    tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED
+)
 
 
 def _normalize_and_prepare_optimizer(optimizer):
@@ -576,16 +578,24 @@ def tpu_step(inputs):
 
   def __init__(
       self,
-      feature_config: Union[tf.tpu.experimental.embedding.FeatureConfig,
-                            Iterable],  # pylint:disable=g-bare-generic
-      optimizer: Optional[Union[tf.tpu.experimental.embedding.SGD,
-                                tf.tpu.experimental.embedding.Adagrad,
-                                tf.tpu.experimental.embedding.Adam,
-                                tf.tpu.experimental.embedding.FTRL]],
+      feature_config: Union[
+          tf.tpu.experimental.embedding.FeatureConfig, Iterable  # pylint:disable=g-bare-generic
+      ],
+      optimizer: Optional[
+          Union[
+              tf.tpu.experimental.embedding.SGD,
+              tf.tpu.experimental.embedding.Adagrad,
+              tf.tpu.experimental.embedding.Adam,
+              tf.tpu.experimental.embedding.FTRL,
+          ]
+      ],
       pipeline_execution_with_tensor_core: bool = False,
       batch_size: Optional[int] = None,
       embedding_feature: Optional[
-          tf.tpu.experimental.HardwareFeature.EmbeddingFeature] = None):
+          tf.tpu.experimental.HardwareFeature.EmbeddingFeature
+      ] = None,
+      experimental_sparsecore_restore_info: Optional[Dict[str, Any]] = None,
+  ):
     """A Keras layer for accelerated embedding lookups on TPU.
 
     Args:
@@ -605,6 +615,9 @@ def __init__(
         compatibility.
       embedding_feature: EmbeddingFeature enum, indicating which version of TPU
         hardware the layer should run on.
+      experimental_sparsecore_restore_info: Information from the sparse core
+        training required to restore for serving (like the number of TPU
+        devices used).
     """
     super().__init__()
     self._feature_config, self._table_config_map = (
@@ -616,7 +629,9 @@ def __init__(
 
     self._embedding_feature = None
     if self._using_tpu:
-      self._embedding_feature = self._strategy.extended.tpu_hardware_feature.embedding_feature
+      self._embedding_feature = (
+          self._strategy.extended.tpu_hardware_feature.embedding_feature
+      )
       # Override the embedding feature setting if passed.
       if embedding_feature is not None:
         if embedding_feature == _EMBEDDING_UNSUPPORTED:
@@ -630,6 +645,9 @@
 
     # Create TPU embedding mid level APIs according to the embedding feature
     # setting.
+    self._experimental_sparsecore_restore_info = (
+        experimental_sparsecore_restore_info
+    )
     self._tpu_embedding = self._create_tpu_embedding_mid_level_api(
         self._using_tpu, self._embedding_feature,
         pipeline_execution_with_tensor_core)
@@ -666,7 +684,10 @@ def _create_tpu_embedding_mid_level_api(
"""
if not using_tpu or embedding_feature is None:
return tf.tpu.experimental.embedding.TPUEmbeddingForServing(
self._feature_config, self._optimizer)
self._feature_config,
self._optimizer,
experimental_sparsecore_restore_info=self._experimental_sparsecore_restore_info,
)
if embedding_feature == _EMBEDDING_UNSUPPORTED:
return tf.tpu.experimental.embedding.TPUEmbeddingV0(
self._feature_config, self._optimizer)
Expand Down
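
For reviewers, a minimal usage sketch of the new argument (illustrative only, not part of this diff). It assumes a hypothetical restore-info dict and checkpoint path, since the exact keys expected from SparseCore training are not shown in this PR; `TPUEmbedding` is the layer defined in the modified file.

# Illustrative sketch, not part of this PR. The restore-info contents and
# the checkpoint path below are hypothetical placeholders.
import tensorflow as tf
import tensorflow_recommenders as tfrs

feature_config = tf.tpu.experimental.embedding.FeatureConfig(
    table=tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=1000, dim=16, name="video"),
    name="watched")

# Hypothetical metadata recorded during SparseCore training, e.g. how many
# TPU devices the embedding tables were sharded across.
restore_info = {"num_tpu_devices": 4}

layer = tfrs.layers.embedding.TPUEmbedding(
    feature_config=feature_config,
    optimizer=None,  # no optimizer is needed for a serving-only restore
    experimental_sparsecore_restore_info=restore_info,
)

# Off TPU, the layer now constructs TPUEmbeddingForServing with the restore
# info (see the last hunk above), so variables saved from SparseCore
# training can be mapped back for CPU serving.
checkpoint = tf.train.Checkpoint(embedding=layer)
checkpoint.restore("/path/to/sparsecore/checkpoint").expect_partial()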