Skip to content

Commit

Permalink
Minor bug fix to MultiHeadAttention registry function.
Browse files Browse the repository at this point in the history
Covers the case where a layer arg can be passed as a layer kwarg.

PiperOrigin-RevId: 520394717
  • Loading branch information
tensorflower-gardener committed Mar 29, 2023
1 parent abb0c3f commit dd43418
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Utility functions that help in the computation of per-example gradient norms."""

from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Text, Tuple, Union

from absl import logging
import tensorflow as tf
Expand All @@ -36,19 +36,6 @@ def has_internal_compute_graph(input_object: Any):
)


def _get_internal_layers(
input_layer: tf.keras.layers.Layer,
) -> List[tf.keras.layers.Layer]:
"""Returns a list of layers that are nested within a given layer."""
internal_layers = []
if isinstance(input_layer, tf.keras.Model) and hasattr(input_layer, 'layers'):
for layer in input_layer.layers:
internal_layers.extend(_get_internal_layers(layer))
else:
internal_layers.append(input_layer)
return internal_layers


def model_forward_pass(
input_model: tf.keras.Model,
inputs: PackedTensors,
Expand Down Expand Up @@ -114,18 +101,10 @@ def generator_fn(layer_instance, args, kwargs):
generator_outputs_list.extend(node_generator_outputs)
else:
# Otherwise, we parse the node directly.
node_layers = _get_internal_layers(node.layer)
for layer in node_layers:
node_layer_outputs, layer_generator_outputs = generator_fn(
layer, args, kwargs
)
generator_outputs_list.append(layer_generator_outputs)
args = (
node_layer_outputs
if isinstance(node_layer_outputs, tuple)
else (node_layer_outputs,)
)
kwargs = {}
node_layer_outputs, layer_generator_outputs = generator_fn(
node.layer, args, kwargs
)
generator_outputs_list.append(layer_generator_outputs)

# Update the current dictionary of inputs for the next node.
for x_id, y in zip(
Expand Down Expand Up @@ -163,9 +142,8 @@ def all_trainable_layers_are_registered(
False otherwise.
"""
for layer in input_model.layers:
for sublayer in _get_internal_layers(layer):
if not layer_registry.is_elem(sublayer) and sublayer.trainable_variables:
return False
if not layer_registry.is_elem(layer) and layer.trainable_variables:
return False
return True


Expand Down Expand Up @@ -213,17 +191,53 @@ def add_noise(g):

def generate_model_outputs_using_core_keras_layers(
input_model: tf.keras.Model,
custom_layer_set: Optional[Set[type]] = None, # pylint: disable=g-bare-generic
) -> PackedTensors:
"""Returns the model outputs generated by only core Keras layers."""
cust_obj_dict = dict.copy(tf.keras.utils.get_custom_objects())
cust_hash_set = set([hash(v) for v in cust_obj_dict.values()])
"""Returns the model outputs generated by only core Keras layers.
Args:
input_model: A `tf.keras.Model` instance to obtain outputs from.
custom_layer_set: An optional `set` of custom layers to expand. If `None`,
then this is the set of all registered custom Keras layers.
Returns:
A `tf.Tensor` that is the result of `input_model(input_model.inputs)`
using only Keras layers that are not in `custom_layer_set`.
"""
# Set up helper variables and functions.
custom_layer_set = (
custom_layer_set or tf.keras.utils.get_custom_objects().values()
)

def _is_core(layer_instance):
return type(layer_instance) not in custom_layer_set

def generator_fn(layer_instance, args, kwargs):
if hash(layer_instance.__class__) in cust_hash_set:
# Using `.call()` does not register the layer in the compute graph of
# a forward pass.
return layer_instance.call(*args, **kwargs), None
else:
return layer_instance(*args, **kwargs), None
# Using `.call()` does not register the layer in the compute graph of
# a forward pass.
layer_outputs = (
layer_instance(*args, **kwargs)
if _is_core(layer_instance)
else layer_instance.call(*args, **kwargs)
)
return layer_outputs, None

# Return early if all the existing layers contain only core layers.
if all(_is_core(layer) for layer in input_model.layers):
return model_forward_pass(input_model, input_model.inputs)[0]

return model_forward_pass(input_model, input_model.inputs, generator_fn)[0]
# Do a forward pass to expand the outermost layers.
candidate_outputs, _ = model_forward_pass(
input_model, input_model.inputs, generator_fn
)

# The following recursion is inefficient because it recursively builds `n`
# Keras model graphs, where `n` is the number of recursive calls. However,
# it appears to be the only valid approach without accessing Keras's internal
# functions (e.g., `keras.engine.functional._map_graph_network()`).
cleaned_model = tf.keras.Model(
inputs=input_model.inputs, outputs=candidate_outputs
)
return generate_model_outputs_using_core_keras_layers(
cleaned_model, custom_layer_set
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,72 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from absl.testing import parameterized
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils


# ==============================================================================
# Helper functions and classes.
# ==============================================================================
@tf.keras.utils.register_keras_serializable('gradient_clipping_utils_test')
class DoubleDense(tf.keras.layers.Layer):
"""Generates two dense layers nested together."""

def __init__(self, units: int):
super().__init__()
self.dense1 = tf.keras.layers.Dense(units, name='DDense_ext_1')
self.dense2 = tf.keras.layers.Dense(1, name='DDense_ext_2')

def call(self, inputs: Any):
x = self.dense1(inputs)
return self.dense2(x)


@tf.keras.utils.register_keras_serializable('gradient_clipping_utils_test')
class TripleDense(tf.keras.layers.Layer):
"""Generates three dense layers nested together."""

def __init__(self, units: int):
super().__init__()
self.dense1 = tf.keras.layers.Dense(units, name='TDense_ext_1')
self.dense2 = tf.keras.layers.Dense(units, name='TDense_ext_2')
self.dense3 = tf.keras.layers.Dense(1, name='TDense_ext_3')

def call(self, inputs: Any):
x1 = self.dense1(inputs)
x2 = self.dense2(x1)
return self.dense3(x2)


def get_reduced_model(sample_inputs, hidden_layer_list, new_custom_layers=None):
"""Reduces a set of layers to only core Keras layers in a model."""
sample_outputs = sample_inputs
for l in hidden_layer_list:
sample_outputs = l(sample_outputs)
custom_model = tf.keras.Model(inputs=sample_inputs, outputs=sample_outputs)
if new_custom_layers:
reduced_outputs = (
gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
custom_model,
custom_layer_set=new_custom_layers,
)
)
else:
reduced_outputs = (
gradient_clipping_utils.generate_model_outputs_using_core_keras_layers(
custom_model
)
)
return tf.keras.Model(inputs=custom_model.inputs, outputs=reduced_outputs)


# ==============================================================================
# Main tests.
# ==============================================================================
class ModelForwardPassTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.product(
Expand Down Expand Up @@ -75,5 +135,46 @@ def test_outputs_are_consistent(
self.assertAllClose(computed_outputs, true_outputs)


class GenerateOutputsUsingCoreKerasLayers(
tf.test.TestCase, parameterized.TestCase
):

def test_single_custom_layer_is_reduced(self):
num_units = 5
num_dims = 3
reduced_model = get_reduced_model(
tf.keras.Input(num_dims),
[DoubleDense(num_units)],
)
# Ignore the input layer.
for l in reduced_model.layers[1:]:
self.assertIsInstance(l, tf.keras.layers.Dense)

def test_two_distinct_custom_layers_are_reduced(self):
num_units = 5
num_dims = 3
reduced_model = get_reduced_model(
tf.keras.Input(num_dims),
[DoubleDense(num_units), TripleDense(num_units)],
)
# Ignore the input layer.
for l in reduced_model.layers[1:]:
self.assertIsInstance(l, tf.keras.layers.Dense)

def test_new_custom_layer_spec(self):
num_units = 5
num_dims = 3
reduced_model = get_reduced_model(
tf.keras.Input(num_dims),
[DoubleDense(num_units), TripleDense(num_units)],
new_custom_layers=set([DoubleDense]),
)
# Ignore the input layer.
for l in reduced_model.layers[1:]:
self.assertTrue(
isinstance(l, tf.keras.layers.Dense) or isinstance(l, TripleDense)
)


if __name__ == '__main__':
tf.test.main()

0 comments on commit dd43418

Please sign in to comment.