diff --git a/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py b/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py index 66d8952..95dd40f 100644 --- a/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py +++ b/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py @@ -53,7 +53,9 @@ _EMBEDDING_V2 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2 _EMBEDDING_V1 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V1 -_EMBEDDING_UNSUPPORTED = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED +_EMBEDDING_UNSUPPORTED = ( + tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED +) def _normalize_and_prepare_optimizer(optimizer): @@ -576,16 +578,24 @@ def tpu_step(inputs): def __init__( self, - feature_config: Union[tf.tpu.experimental.embedding.FeatureConfig, - Iterable], # pylint:disable=g-bare-generic - optimizer: Optional[Union[tf.tpu.experimental.embedding.SGD, - tf.tpu.experimental.embedding.Adagrad, - tf.tpu.experimental.embedding.Adam, - tf.tpu.experimental.embedding.FTRL]], + feature_config: Union[ + tf.tpu.experimental.embedding.FeatureConfig, Iterable # pylint:disable=g-bare-generic + ], + optimizer: Optional[ + Union[ + tf.tpu.experimental.embedding.SGD, + tf.tpu.experimental.embedding.Adagrad, + tf.tpu.experimental.embedding.Adam, + tf.tpu.experimental.embedding.FTRL, + ] + ], pipeline_execution_with_tensor_core: bool = False, batch_size: Optional[int] = None, embedding_feature: Optional[ - tf.tpu.experimental.HardwareFeature.EmbeddingFeature] = None): + tf.tpu.experimental.HardwareFeature.EmbeddingFeature + ] = None, + experimental_sparsecore_restore_info: Optional[Dict[str, Any]] = None, + ): """A Keras layer for accelerated embedding lookups on TPU. Args: @@ -605,6 +615,9 @@ def __init__( compatibility. embedding_feature: EmbeddingFeature enum, inidicating which version of TPU hardware the layer should run on. + experimental_sparsecore_restore_info: Information from the sparse core + training required to restore for serving (like number of TPU devices + used.) """ super().__init__() self._feature_config, self._table_config_map = ( @@ -616,7 +629,9 @@ def __init__( self._embedding_feature = None if self._using_tpu: - self._embedding_feature = self._strategy.extended.tpu_hardware_feature.embedding_feature + self._embedding_feature = ( + self._strategy.extended.tpu_hardware_feature.embedding_feature + ) # Override the embedding feature setting if passed. if embedding_feature is not None: if embedding_feature == _EMBEDDING_UNSUPPORTED: @@ -630,6 +645,9 @@ def __init__( # Create TPU embedding mid level APIs according to the embedding feature # setting. + self._experimental_sparsecore_restore_info = ( + experimental_sparsecore_restore_info + ) self._tpu_embedding = self._create_tpu_embedding_mid_level_api( self._using_tpu, self._embedding_feature, pipeline_execution_with_tensor_core) @@ -666,7 +684,10 @@ def _create_tpu_embedding_mid_level_api( """ if not using_tpu or embedding_feature is None: return tf.tpu.experimental.embedding.TPUEmbeddingForServing( - self._feature_config, self._optimizer) + self._feature_config, + self._optimizer, + experimental_sparsecore_restore_info=self._experimental_sparsecore_restore_info, + ) if embedding_feature == _EMBEDDING_UNSUPPORTED: return tf.tpu.experimental.embedding.TPUEmbeddingV0( self._feature_config, self._optimizer)