Commit 5c8b762

Add registry function and tests for tf.keras.layers.EinsumDense.

PiperOrigin-RevId: 511864469
tensorflower-gardener committed Feb 28, 2023
1 parent d7cd3f8 commit 5c8b762

Showing 6 changed files with 637 additions and 87 deletions.
8 changes: 8 additions & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

@@ -6,12 +6,20 @@ py_library(
     name = "gradient_clipping_utils",
     srcs = ["gradient_clipping_utils.py"],
     srcs_version = "PY3",
+    deps = [":layer_registry"],
 )
 
+py_library(
+    name = "einsum_utils",
+    srcs = ["einsum_utils.py"],
+    srcs_version = "PY3",
+)
+
 py_library(
     name = "layer_registry",
     srcs = ["layer_registry.py"],
     srcs_version = "PY3",
+    deps = [":einsum_utils"],
 )
 
 py_library(
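For orientation, the new einsum_utils target backs the EinsumDense entry in the layer registry. A minimal sketch of how the registry is consumed, assuming the make_default_layer_registry() constructor and lookup() method exposed by the registry module; the EinsumDense equation and output shape are illustrative, not taken from this diff:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

# Build the default registry; after this commit it should also carry a
# registry function for tf.keras.layers.EinsumDense.
registry = layer_registry.make_default_layer_registry()

# Look up the per-layer computation for an EinsumDense instance.
layer = tf.keras.layers.EinsumDense("ab,bc->ac", output_shape=4)
registry_fn = registry.lookup(layer)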
20 changes: 16 additions & 4 deletions tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

@@ -23,9 +23,12 @@
 
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 
 
-def get_registry_generator_fn(tape, layer_registry):
+def get_registry_generator_fn(
+    tape: tf.GradientTape, layer_registry: lr.LayerRegistry
+):
   """Creates the generator function for `compute_gradient_norms()`."""
   if layer_registry is None:
     # Needed for backwards compatibility.
@@ -53,7 +56,12 @@ def registry_generator_fn(layer_instance, args, kwargs):
   return registry_generator_fn
 
 
-def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
+def compute_gradient_norms(
+    input_model: tf.keras.Model,
+    x_batch: tf.Tensor,
+    y_batch: tf.Tensor,
+    layer_registry: lr.LayerRegistry,
+):
   """Computes the per-example loss gradient norms for given data.
 
   Applies a variant of the approach given in
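A usage sketch for the typed entry point above; the toy model, data shapes, and compiled loss are assumptions for illustration, not part of this diff:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry

model = tf.keras.Sequential(
    [tf.keras.layers.EinsumDense("ab,bc->ac", output_shape=1)]
)
model.compile(loss=tf.keras.losses.MeanSquaredError())

x_batch = tf.random.normal([8, 3])
y_batch = tf.random.normal([8, 1])
# One L2 gradient norm per example, shape [8].
norms = clip_grads.compute_gradient_norms(
    model, x_batch, y_batch, layer_registry.make_default_layer_registry()
)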
@@ -106,7 +114,7 @@ def compute_gradient_norms(input_model, x_batch, y_batch, layer_registry):
   return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1))
 
 
-def compute_clip_weights(l2_norm_clip, gradient_norms):
+def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor):
   """Computes the per-example loss/clip weights for clipping.
 
   When the sum of the per-example losses is replaced by a weighted sum, where
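The weights it describes are the usual clip factors min(1, C / ||g_i||). A division-safe sketch of that formula, under the assumption that it matches this function's contract (the function body is not shown in this diff):

import tensorflow as tf

def clip_weights_sketch(l2_norm_clip: float, gradient_norms: tf.Tensor):
  # C / max(C, ||g_i||) equals 1 when ||g_i|| <= C and C / ||g_i|| otherwise,
  # so weighting per-example losses by it bounds each per-example gradient
  # norm by C.
  return l2_norm_clip / tf.math.maximum(l2_norm_clip, gradient_norms)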
@@ -132,7 +140,11 @@ def compute_clip_weights(l2_norm_clip, gradient_norms):
 
 
 def compute_pred_and_clipped_gradients(
-    input_model, x_batch, y_batch, l2_norm_clip, layer_registry
+    input_model: tf.keras.Model,
+    x_batch: tf.Tensor,
+    y_batch: tf.Tensor,
+    l2_norm_clip: float,
+    layer_registry: lr.LayerRegistry,
 ):
   """Computes the per-example predictions and per-example clipped loss gradient.
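A closing sketch tying the pieces together, continuing the toy model, data, and imports from the earlier compute_gradient_norms sketch. It assumes the function returns the predictions and the clipped gradients in that order; the clip norm 1.0 is an arbitrary illustrative choice:

y_pred, clipped_grads = clip_grads.compute_pred_and_clipped_gradients(
    model, x_batch, y_batch, 1.0, layer_registry.make_default_layer_registry()
)
# Assuming clipped_grads aligns with model.trainable_variables, it can be
# passed straight to an optimizer, e.g.
# optimizer.apply_gradients(zip(clipped_grads, model.trainable_variables)).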