Add registry function and tests for tf.keras.layers.EinsumDense.
PiperOrigin-RevId: 511864469
tensorflower-gardener committed Feb 28, 2023
1 parent d7cd3f8 commit 6512e44
Showing 6 changed files with 652 additions and 89 deletions.
8 changes: 8 additions & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -6,12 +6,20 @@ py_library(
     name = "gradient_clipping_utils",
     srcs = ["gradient_clipping_utils.py"],
     srcs_version = "PY3",
+    deps = [":layer_registry"],
 )
 
+py_library(
+    name = "einsum_utils",
+    srcs = ["einsum_utils.py"],
+    srcs_version = "PY3",
+)
+
 py_library(
     name = "layer_registry",
     srcs = ["layer_registry.py"],
     srcs_version = "PY3",
+    deps = [":einsum_utils"],
 )
 
 py_library(
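For context, `layer_registry.py` now depends on the new `einsum_utils` target because it registers a gradient-norm function for `tf.keras.layers.EinsumDense` (per the commit message). That registry function is not part of this excerpt, so the sketch below is only a rough illustration of the registration pattern. It assumes a `LayerRegistry.insert` method and the `(base_vars, transform, sqr_norm_fn)` contract implied by `registry_generator_fn` in `clip_grads.py`; neither is confirmed by the portion of the diff shown here.

# Hypothetical sketch, not code from this commit: registering a per-layer
# gradient-norm computation for EinsumDense.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr


def einsum_layer_computation(layer_instance, args):
  # Assumed contract: run the layer, return the intermediate output to watch
  # on the tape, a transform applied to that output, and a function mapping
  # gradients w.r.t. the watched output to per-example squared gradient norms.
  base_vars = layer_instance(*args)

  def sqr_norm_fn(grads):
    # Flatten all non-batch axes and sum the squares per example.
    flat = tf.reshape(grads, [tf.shape(grads)[0], -1])
    return tf.reduce_sum(tf.square(flat), axis=1)

  return base_vars, tf.identity, sqr_norm_fn


registry = lr.LayerRegistry()
registry.insert(tf.keras.layers.EinsumDense, einsum_layer_computation)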
30 changes: 24 additions & 6 deletions tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -21,11 +21,20 @@
 `compute_gradient_norms()` function).
 """
 
+from typing import Union, Iterable, Text, TypeAlias
+
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+
+InputTensor: TypeAlias = Union[
+    tf.Tensor, Iterable[tf.Tensor], dict[Text, tf.Tensor]
+]
 
 
-def get_registry_generator_fn(tape, layer_registry):
+def get_registry_generator_fn(
+    tape: tf.GradientTape, layer_registry: lr.LayerRegistry
+):
   """Creates the generator function for `compute_gradient_norms()`."""
   if layer_registry is None:
     # Needed for backwards compatibility.
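The new `InputTensor` alias documents the three input forms the clipping API accepts: a single tensor, an iterable of tensors, or a string-keyed dict of tensors, matching what multi-input Keras models take. A small illustration with arbitrary shapes:

import tensorflow as tf

x_single = tf.ones([32, 10])                     # tf.Tensor
x_multi = (tf.ones([32, 10]), tf.ones([32, 5]))  # Iterable[tf.Tensor]
x_named = {"context": tf.ones([32, 10]),         # dict[Text, tf.Tensor]
           "query": tf.ones([32, 5])}

In each case the first axis (here 32) is the batch dimension, as the docstrings below require.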
@@ -53,7 +62,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
   return registry_generator_fn
 
 
-def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
+def compute_gradient_norms(
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    layer_registry: lr.LayerRegistry,
+):
   """Computes the per-example loss gradient norms for given data.
 
   Applies a variant of the approach given in
@@ -62,7 +76,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers. The
       loss of the model *must* be a scalar loss.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
@@ -106,7 +120,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
 
-def compute_clip_weights(l2_norm_clip, gradient_norms):
+def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   """Computes the per-example loss/clip weights for clipping.
 
   When the sum of the per-example losses is replaced by a weighted sum, where
@@ -132,7 +146,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
 
 
 def compute_pred_and_clipped_gradients(
-    input_model, x_batch, y_batch, l2_norm_clip, layer_registry
+    input_model: tf.keras.Model,
+    x_batch: InputTensor,
+    y_batch: tf.Tensor,
+    l2_norm_clip: float,
+    layer_registry: lr.LayerRegistry,
 ):
   """Computes the per-example predictions and per-example clipped loss gradient.
@@ -147,7 +165,7 @@ def compute_pred_and_clipped_gradients(
   Args:
     input_model: The `tf.keras.Model` from which to obtain the layers.
-    x_batch: A `tf.Tensor` representing a batch of inputs to the model. The
+    x_batch: An `InputTensor` representing a batch of inputs to the model. The
       first axis must be the batch dimension.
     y_batch: A `tf.Tensor` representing a batch of output labels. The first axis
       must be the batch dimension. The number of examples should match the
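Taken together, the newly annotated functions form the fast clipping pipeline: compute per-example gradient norms, then turn them into per-example loss weights. A hypothetical usage sketch based only on the signatures above; the registry helper name is an assumption, not shown in this excerpt:

# Hypothetical usage sketch of the annotated API; the registry setup is assumed.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr

model = tf.keras.Sequential([
    tf.keras.layers.EinsumDense("ab,bc->ac", output_shape=16, activation="relu"),
    tf.keras.layers.Dense(1),
])
# The loss must be scalar, per the compute_gradient_norms docstring.
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

x_batch = tf.random.normal([32, 10])
y_batch = tf.random.uniform([32, 1])

registry = lr.make_default_layer_registry()  # assumed helper name
norms = clip_grads.compute_gradient_norms(model, x_batch, y_batch, registry)
weights = clip_grads.compute_clip_weights(1.0, norms)  # l2_norm_clip = 1.0

Reweighting the per-example losses by `weights` caps each example's gradient contribution at roughly `l2_norm_clip`, which is the clipping effect the docstrings describe.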
[Diffs for the remaining changed files were not loaded.]
