From 345fe5becd2d8ee12d0d0056453229a66c40865e Mon Sep 17 00:00:00 2001 From: Walid Krichene Date: Fri, 19 May 2023 17:21:54 -0700 Subject: [PATCH] Support fast clipping in DPAM. PiperOrigin-RevId: 533589057 --- .../fast_gradient_clipping/clip_grads.py | 83 ++++---- .../fast_gradient_clipping/clip_grads_test.py | 178 +++++++++--------- .../fast_gradient_clipping/layer_registry.py | 17 +- .../privacy/keras_models/dp_keras_model.py | 40 ++-- .../keras_models/dp_keras_model_test.py | 176 +++++++++++------ 5 files changed, 283 insertions(+), 211 deletions(-) diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py index c11a93bd..c74a2505 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py @@ -69,9 +69,10 @@ def registry_generator_fn(layer_instance, args, kwargs): def compute_gradient_norms( input_model: tf.keras.Model, + layer_registry: lr.LayerRegistry, x_batch: InputTensor, y_batch: tf.Tensor, - layer_registry: lr.LayerRegistry, + weight_batch: Optional[tf.Tensor] = None, per_example_loss_fn: Optional[LossFn] = None, num_microbatches: Optional[lr.BatchSize] = None, trainable_vars: Optional[List[tf.Variable]] = None, @@ -84,15 +85,16 @@ def compute_gradient_norms( Args: input_model: The `tf.keras.Model` from which to obtain the layers from. The loss of the model *must* be a scalar loss. + layer_registry: A `LayerRegistry` instance containing functions that help + compute gradient norms quickly. See + `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for + more details. x_batch: An `InputTensor` representing a batch of inputs to the model. The first axis must be the batch dimension. y_batch: A `tf.Tensor` representing a batch of output labels. The first axis must be the batch dimension. The number of examples should match the number of examples in `x_batch`. - layer_registry: A `LayerRegistry` instance containing functions that help - compute gradient norms quickly. See - `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for - more details. + weight_batch: Optional batch of weights, passed to the loss function. per_example_loss_fn: takes as input predictions, labels and weights, and outputs a vector of per-example losses. If None, derived from `input_model.loss` by disabling its reduction. @@ -108,8 +110,9 @@ def compute_gradient_norms( variables are included. Returns: - A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th - per-example loss function. + A scalar vector, whose i-th entry is the norm of the gradient of the i-th + example loss (when num_microbatches is None) or the norm of the gradient of + the i-th microbatch loss (define as a mean over the microbatch). 
""" tape = tf.GradientTape(persistent=True, watch_accessed_variables=False) registry_generator_fn = get_registry_generator_fn( @@ -127,7 +130,7 @@ def compute_gradient_norms( loss_config = input_model.loss.get_config() loss_config['reduction'] = tf.keras.losses.Reduction.NONE per_example_loss_fn = input_model.loss.from_config(loss_config) - losses = per_example_loss_fn(y_batch, model_outputs) + losses = per_example_loss_fn(y_batch, model_outputs, weight_batch) if losses.shape is None: raise NotImplementedError( "The unreduced (or per-example) loss's shape cannot be `None`" @@ -140,7 +143,7 @@ def compute_gradient_norms( ) if num_microbatches is not None: losses = tf.reduce_mean( - lr.add_microbatch_axis(losses, num_microbatches), axis=1 + lr.maybe_add_microbatch_axis(losses, num_microbatches), axis=1 ) summed_loss = tf.reduce_sum(losses) # Unwrap the generator outputs so that the next loop avoids duplicating @@ -165,8 +168,10 @@ def compute_gradient_norms( vars_list, unconnected_gradients=tf.UnconnectedGradients.ZERO, ) + if not grads_list: + raise ValueError('Empty gradient list.') sqr_norm_list = [] - for grads, f in zip(grads_list, sqr_norm_fns_list): + for grads, f in zip(grads_list, sqr_norm_fns_list, strict=True): sqr_norm_list.append(f(grads)) sqr_norm_tsr = tf.stack(sqr_norm_list, axis=1) return tf.sqrt(tf.reduce_sum(sqr_norm_tsr, axis=1)) @@ -199,10 +204,11 @@ def compute_clip_weights(l2_norm_clip: float, gradient_norms: tf.Tensor): def compute_clipped_gradients_and_outputs( input_model: tf.keras.Model, - x_batch: InputTensor, - y_batch: tf.Tensor, l2_norm_clip: float, layer_registry: lr.LayerRegistry, + x_batch: InputTensor, + y_batch: tf.Tensor, + weight_batch: Optional[tf.Tensor] = None, num_microbatches: Optional[lr.BatchSize] = None, clipping_loss: Optional[LossFn] = None, ) -> Tuple[List[tf.Tensor], tf.Tensor, tf.Tensor]: @@ -218,11 +224,6 @@ def compute_clipped_gradients_and_outputs( Args: input_model: The `tf.keras.Model` from which to obtain the layers from. - x_batch: An `InputTensor` representing a batch of inputs to the model. The - first axis must be the batch dimension. - y_batch: A `tf.Tensor` representing a batch of output labels. The first axis - must be the batch dimension. The number of examples should match the - number of examples in `x_batch`. l2_norm_clip: A `float` indicating the norm to which per-example gradients will be clipped. That is, all gradients of the per-example loss functions will have norm at most `l2_norm_clip`. @@ -232,6 +233,15 @@ def compute_clipped_gradients_and_outputs( `output` is the pre-activator tensor, `sqr_grad_norms` is related to the squared norms of a layer's pre-activation tensor, and `vars` are relevant trainable weights (see `layer_registry_factories.py` for examples). + x_batch: An `InputTensor` representing a batch of inputs to the model. The + first axis must be the batch dimension. + y_batch: A `tf.Tensor` representing a batch of output labels. The first axis + must be the batch dimension. The number of examples should match the + number of examples in `x_batch`. + weight_batch: Optional vector of weights, passed to the loss function. Must + be of size [batch_size]. In case of microbatching, this will be reshaped + to [num_microbatches, batch_size/num_microbatches] before passing it to + the loss. num_microbatches: An optional number or scalar `tf.Tensor` for the number of microbatches. 
If not None, indicates that the loss is grouped into num_microbatches (in this case, the batch dimension needs to be a multiple @@ -243,11 +253,10 @@ def compute_clipped_gradients_and_outputs( the value of the clipped loss does not reflect the true loss. Returns: - A `tuple` `(grad, y_pred, clipping_loss_value)`. The first element is the - clipped gradient of the loss function, the second is the result of - applying `input_model` to `x_batch`, and the third is loss value of - `input_model`, weighted by the loss weights generated by a specific - `compute_clip_weights()` call. + clipped_grad: the clipped gradient of the loss function + y_pred: the result of applying `input_model` to `x_batch` + clipping_loss_value: the loss value weighted in such a way that its gradient + is `clipped_grad`. """ if input_model.loss.reduction == 'none': raise NotImplementedError( @@ -258,13 +267,25 @@ def compute_clipped_gradients_and_outputs( clipping_loss = input_model.compiled_loss gradient_norms = compute_gradient_norms( input_model, + layer_registry, x_batch, y_batch, - layer_registry, + weight_batch, num_microbatches=num_microbatches, trainable_vars=input_model.trainable_variables, ) - loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms) + clip_weights = compute_clip_weights(l2_norm_clip, gradient_norms) + if weight_batch is not None: + if num_microbatches is None: + clip_weights = clip_weights * weight_batch # shape [num_microbatches] + else: + # In this case, weight_batch is of shape [batch_size], we first reshape to + # [num_microbatches, microbatch_size] then multiply by the clip_weights + # (which is of shape [num_microbatches]) + weight_batch = lr.maybe_add_microbatch_axis( + weight_batch, num_microbatches + ) + clip_weights = clip_weights[:, tf.newaxis] * weight_batch with tf.GradientTape() as tape: # WARNING: When num_microbatches is not None, we need to be sure that # `compute_loss` always computes the mean over the microbatches @@ -274,17 +295,9 @@ def compute_clipped_gradients_and_outputs( # is not defined in the contract so may not hold, especially for # custom losses. y_pred = input_model(x_batch, training=True) - loss_y_batch = ( - y_batch - if num_microbatches is None - else lr.add_microbatch_axis(y_batch, num_microbatches) - ) - loss_y_pred = ( - y_pred - if num_microbatches is None - else lr.add_microbatch_axis(y_pred, num_microbatches) - ) - clipping_loss_value = clipping_loss(loss_y_batch, loss_y_pred, loss_weights) + mb_y_batch = lr.maybe_add_microbatch_axis(y_batch, num_microbatches) + mb_y_pred = lr.maybe_add_microbatch_axis(y_pred, num_microbatches) + clipping_loss_value = clipping_loss(mb_y_batch, mb_y_pred, clip_weights) clipped_grads = tape.gradient( clipping_loss_value, input_model.trainable_variables, diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py index 8d5ffcc4..14834484 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
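For readers following the clip_grads.py hunks above: the new weight handling in compute_clipped_gradients_and_outputs folds the per-example sample weights into the clip weights before they reach the clipping loss. Below is a minimal sketch of that combination, not part of the patch; the clip-weight formula used here (l2_norm_clip / max(norm, l2_norm_clip)) is only an assumption that is consistent with the compute_clip_weights test further down.

import tensorflow as tf

l2_norm_clip = 1.0
weight_batch = tf.constant([1.0, 0.5, 0.25, 2.0])     # per-example weights

# Case 1: no microbatching -- one clip weight per example.
gradient_norms = tf.constant([0.5, 2.0, 4.0, 8.0])    # per-example norms
clip_weights = l2_norm_clip / tf.maximum(gradient_norms, l2_norm_clip)
combined = clip_weights * weight_batch                 # shape [4]

# Case 2: num_microbatches = 2 -- one clip weight per microbatch, broadcast
# over the reshaped [num_microbatches, microbatch_size] weight matrix.
mb_norms = tf.constant([1.5, 6.0])                     # per-microbatch norms
mb_clip_weights = l2_norm_clip / tf.maximum(mb_norms, l2_norm_clip)
mb_weights = tf.reshape(weight_batch, [2, -1])         # shape [2, 2]
combined_mb = mb_clip_weights[:, tf.newaxis] * mb_weights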
-import itertools from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union from absl.testing import parameterized import tensorflow as tf - from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry @@ -71,17 +69,23 @@ def sqr_norm_fn(base_vars): return [vars1, vars2], outputs, sqr_norm_fn -def test_loss_fn(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: - x = tf.reshape(x, (tf.shape(x)[0], -1)) - y = tf.reshape(y, (tf.shape(y)[0], -1)) +def test_loss_fn( + x: tf.Tensor, y: tf.Tensor, weights: Optional[tf.Tensor] = None +) -> tf.Tensor: # Define a loss function which is unlikely to be coincidently defined. - return 3.14 * tf.reduce_sum(tf.square(x - y), axis=1) + if weights is None: + weights = 1.0 + loss = 3.14 * tf.reduce_sum( + tf.cast(weights, tf.float32) * tf.square(x - y), axis=1 + ) + return loss def compute_true_gradient_norms( input_model: tf.keras.Model, x_batch: tf.Tensor, y_batch: tf.Tensor, + weight_batch: Optional[tf.Tensor], per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], num_microbatches: Optional[int], trainable_vars: Optional[tf.Variable] = None, @@ -93,7 +97,7 @@ def compute_true_gradient_norms( per_example_loss_fn = input_model.loss.from_config(loss_config) with tf.GradientTape(persistent=True) as tape: y_pred = input_model(x_batch) - loss = per_example_loss_fn(y_batch, y_pred) + loss = per_example_loss_fn(y_batch, y_pred, weight_batch) if num_microbatches is not None: loss = tf.reduce_mean( tf.reshape( @@ -123,7 +127,8 @@ def get_computed_and_true_norms( per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]], num_microbatches: Optional[int], is_eager: bool, - x_input: tf.Tensor, + x_batch: tf.Tensor, + weight_batch: Optional[tf.Tensor] = None, rng_seed: int = 777, registry: layer_registry.LayerRegistry = None, partial: bool = False, @@ -146,10 +151,11 @@ def get_computed_and_true_norms( per_example_loss_fn: If not None, used as vectorized per example loss function. num_microbatches: The number of microbatches. None or an integer. - is_eager: A `bool` that is `True` if the model should be run eagerly. - x_input: `tf.Tensor` inputs to be tested. - rng_seed: An `int` used to initialize model weights. - registry: A `layer_registry.LayerRegistry` instance. + is_eager: whether the model should be run eagerly. + x_batch: inputs to be tested. + weight_batch: optional weights passed to the loss. + rng_seed: used as a seed for random initialization. + registry: required for fast clipping. partial: Whether to compute the gradient norm with respect to a partial set of varibles. If True, only consider the variables in the first layer. 
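The compute_true_gradient_norms helper above is the brute-force reference that the fast-clipping path is checked against. Below is a self-contained sketch of the same idea, per-example gradient norms via tape.jacobian with an optional weight vector; it is illustrative only, and none of the names in it come from the patch.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
x = tf.random.normal([4, 3])
y = tf.ones([4, 1])
w = tf.constant([1.0, 0.5, 2.0, 1.0])  # optional per-example weights

with tf.GradientTape(persistent=True) as tape:
  pred = model(x)
  # Unreduced, weighted squared error, analogous to test_loss_fn above.
  losses = w * tf.reduce_sum(tf.square(pred - y), axis=1)

sqr_norms = tf.zeros([4])
for v in model.trainable_variables:
  jac = tape.jacobian(losses, v)                  # shape [4, *v.shape]
  sqr_norms += tf.reduce_sum(tf.square(tf.reshape(jac, [4, -1])), axis=1)
true_norms = tf.sqrt(sqr_norms)                   # one norm per example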
@@ -175,13 +181,14 @@ def get_computed_and_true_norms( trainable_vars = l.trainable_variables if trainable_vars: break - y_pred = model(x_input) + y_pred = model(x_batch) y_batch = tf.ones_like(y_pred) tf.keras.utils.set_random_seed(rng_seed) computed_norms = clip_grads.compute_gradient_norms( - model, - x_input, - y_batch, + input_model=model, + x_batch=x_batch, + y_batch=y_batch, + weight_batch=weight_batch, layer_registry=registry, per_example_loss_fn=per_example_loss_fn, num_microbatches=num_microbatches, @@ -190,8 +197,9 @@ def get_computed_and_true_norms( tf.keras.utils.set_random_seed(rng_seed) true_norms = compute_true_gradient_norms( model, - x_input, + x_batch, y_batch, + weight_batch, per_example_loss_fn, num_microbatches, trainable_vars=trainable_vars, @@ -309,24 +317,16 @@ def make_weighted_bow_model(layer_generator, input_dims, output_dim): # ============================================================================== # Factory functions. # ============================================================================== -def get_nd_test_tensors(n: int): - """Returns a list of candidate tests for a given dimension n.""" - return [ - tf.zeros((n,), dtype=tf.float64), - tf.convert_to_tensor(range(n), dtype_hint=tf.float64), - ] - - def get_nd_test_batches(n: int): - """Returns a list of candidate input batches of dimension n.""" - result = [] - tensors = get_nd_test_tensors(n) - for batch_size in range(1, len(tensors) + 1, 1): - combinations = list( - itertools.combinations(get_nd_test_tensors(n), batch_size) - ) - result = result + [tf.stack(ts, axis=0) for ts in combinations] - return result + """Returns a list of input batches of dimension n.""" + # The first two batches have a single element, the last batch has 2 elements. + x0 = tf.zeros([1, n], dtype=tf.float64) + x1 = tf.constant([range(n)], dtype=tf.float64) + x2 = tf.concat([x0, x1], axis=0) + w0 = tf.constant([1], dtype=tf.float64) + w1 = tf.constant([2], dtype=tf.float64) + w2 = tf.constant([0.5, 0.5], dtype=tf.float64) + return [x0, x1, x2], [w0, w1, w2] def get_dense_layer_generators(): @@ -366,11 +366,14 @@ class ClipGradsDirectTest(tf.test.TestCase, parameterized.TestCase): ) def test_clip_weights(self, input_dim, clip_value): tol = 1e-6 - for t in get_nd_test_tensors(input_dim): - self.assertIsNone(clip_grads.compute_clip_weights(None, t)) + ts, _ = get_nd_test_batches(input_dim) + for t in ts: weights = clip_grads.compute_clip_weights(clip_value, t) self.assertAllLessEqual(t * weights, clip_value + tol) + def test_clip_weights_none(self): + self.assertIsNone(clip_grads.compute_clip_weights(None, tf.ones(3))) + class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): @@ -383,6 +386,7 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase): num_microbatches=[None, 1, 2], is_eager=[True, False], partial=[True, False], + weighted=[True, False], ) def test_gradient_norms_on_various_models( self, @@ -394,21 +398,16 @@ def test_gradient_norms_on_various_models( num_microbatches, is_eager, partial, + weighted, ): model_generator = get_dense_model_generators()[model_name] layer_generator = get_dense_layer_generators()[layer_name] - x_batches = get_nd_test_batches(input_dim) + x_batches, weight_batches = get_nd_test_batches(input_dim) default_registry = layer_registry.make_default_layer_registry() - for x_batch in x_batches: - if ( - num_microbatches is not None - and x_batch.shape[0] % num_microbatches != 0 - ): + for x_batch, weight_batch in zip(x_batches, weight_batches): + batch_size 
= x_batch.shape[0] + if num_microbatches is not None and batch_size % num_microbatches != 0: continue - if model_name == 'tower1': - x_input = [x_batch, x_batch] - else: - x_input = x_batch (computed_norms, true_norms) = get_computed_and_true_norms( model_generator, layer_generator, @@ -417,10 +416,13 @@ def test_gradient_norms_on_various_models( per_example_loss_fn, num_microbatches, is_eager, - x_input, + x_batch=[x_batch, x_batch] if model_name == 'tower1' else x_batch, + weight_batch=weight_batch if weighted else None, registry=default_registry, partial=partial, ) + expected_size = num_microbatches or batch_size + self.assertEqual(computed_norms.shape[0], expected_size) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) @@ -471,31 +473,30 @@ def test_gradient_norms_on_various_models( is_eager, partial, ): + batch_size = x_batch.shape[0] + # The following are invalid test combinations, and are skipped. if ( - num_microbatches is not None - and x_batch.shape[0] % num_microbatches != 0 + num_microbatches is not None and batch_size % num_microbatches != 0 + ) or ( + model_name == 'weighted_bow1' and isinstance(x_batch, tf.RaggedTensor) ): return - valid_test_input = ( - not isinstance(x_batch, tf.RaggedTensor) - and model_name == 'weighted_bow1' - ) or (model_name != 'weighted_bow1') - if valid_test_input: - default_registry = layer_registry.make_default_layer_registry() - model_generator = get_embedding_model_generators()[model_name] - (computed_norms, true_norms) = get_computed_and_true_norms( - model_generator=model_generator, - layer_generator=None, - input_dims=x_batch.shape[1:], - output_dim=output_dim, - per_example_loss_fn=per_example_loss_fn, - num_microbatches=num_microbatches, - is_eager=is_eager, - x_input=x_batch, - registry=default_registry, - partial=partial, - ) - self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) + default_registry = layer_registry.make_default_layer_registry() + model_generator = get_embedding_model_generators()[model_name] + (computed_norms, true_norms) = get_computed_and_true_norms( + model_generator=model_generator, + layer_generator=None, + input_dims=x_batch.shape[1:], + output_dim=output_dim, + per_example_loss_fn=per_example_loss_fn, + num_microbatches=num_microbatches, + is_eager=is_eager, + x_batch=x_batch, + registry=default_registry, + partial=partial, + ) + self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size) + self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): @@ -507,6 +508,7 @@ class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase): num_microbatches=[None, 2], is_eager=[True, False], partial=[True, False], + weighted=[True, False], ) def test_gradient_norms_on_various_models( self, @@ -516,15 +518,14 @@ def test_gradient_norms_on_various_models( num_microbatches, is_eager, partial, + weighted, ): registry = layer_registry.make_default_layer_registry() registry.insert(DoubleDense, double_dense_layer_computation) - x_batches = get_nd_test_batches(input_dim) - for x_batch in x_batches: - if ( - num_microbatches is not None - and x_batch.shape[0] % num_microbatches != 0 - ): + x_batches, weight_batches = get_nd_test_batches(input_dim) + for x_batch, weight_batch in zip(x_batches, weight_batches): + batch_size = x_batch.shape[0] + if num_microbatches is not None and batch_size % num_microbatches != 0: continue (computed_norms, true_norms) = get_computed_and_true_norms( 
model_generator=make_two_layer_sequential_model, @@ -534,10 +535,12 @@ def test_gradient_norms_on_various_models( per_example_loss_fn=per_example_loss_fn, num_microbatches=num_microbatches, is_eager=is_eager, - x_input=x_batch, + x_batch=x_batch, + weight_batch=weight_batch if weighted else None, registry=registry, partial=partial, ) + self.assertEqual(computed_norms.shape[0], num_microbatches or batch_size) self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2) @@ -574,17 +577,14 @@ def test_clipped_gradients_on_different_losses( ) # Stop early for efficiency. if reduction == 'none': - self.assertRaises( - NotImplementedError, - # function tested - clip_grads.compute_clipped_gradients_and_outputs, - # function args - self._model, - x_batch, - y_batch, - l2_norm_clip, - layer_registry.make_default_layer_registry(), - ) + with self.assertRaises(NotImplementedError): + clip_grads.compute_clipped_gradients_and_outputs( + self._model, + l2_norm_clip, + layer_registry.make_default_layer_registry(), + x_batch, + y_batch, + ) return # NOTE: losses from this point are scalar losses. with tf.GradientTape() as tape: @@ -593,10 +593,10 @@ def test_clipped_gradients_on_different_losses( true_grads = tape.gradient(loss_value, self._model.trainable_variables) clipped_grads, _, _ = clip_grads.compute_clipped_gradients_and_outputs( self._model, - x_batch, - y_batch, l2_norm_clip, layer_registry.make_default_layer_registry(), + x_batch, + y_batch, ) # Computes the L2 norm manually. diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py index 69f499d0..838b4f7f 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/layer_registry.py @@ -105,26 +105,23 @@ def insert( # ============================================================================== # Utilities # ============================================================================== -def add_microbatch_axis( +def maybe_add_microbatch_axis( x: tf.Tensor, num_microbatches: Optional[BatchSize], ) -> tf.Tensor: """Adds the microbatch axis. - Reshape the input tensor to replace the first(batch) dimension with the - shape [num_microbatches, batch_size / num_microbatches]. The batch size - must be a multiple of num_microbatches (unless it is None, meaning - num_microbatches is the same as the batch size). - Args: x: the input tensor. - num_microbatches: None or a numeric value or a scalar `tf.Tensor`. + num_microbatches: If None, x is returned unchanged. Otherwise, must divide + the batch size. Returns: - The reshaped input tensor. + The input tensor x, reshaped from [batch_size, ...] to + [num_microbatches, batch_size / num_microbatches, ...]. 
""" if num_microbatches is None: - return tf.expand_dims(x, 1) + return x with tf.control_dependencies( [tf.assert_equal(tf.math.floormod(tf.shape(x)[0], num_microbatches), 0)] ): @@ -193,7 +190,7 @@ def sqr_norm_fn(base_vars_grads): def _compute_gramian(x): if num_microbatches is not None: - x_microbatched = add_microbatch_axis(x, num_microbatches) + x_microbatched = maybe_add_microbatch_axis(x, num_microbatches) return tf.matmul(x_microbatched, x_microbatched, transpose_b=True) else: # Special handling for better efficiency diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py index 176d7be3..80132df4 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py @@ -15,7 +15,6 @@ from absl import logging import tensorflow as tf - from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr @@ -179,15 +178,19 @@ def _reduce_per_example_grads(self, stacked_grads): tf.shape(stacked_grads)[0], summed_grads.dtype ) - def _compute_per_example_grads(self, data): + def _compute_per_example_grads(self, microbatched_data): if self._clipping_loss is None: self._make_clipping_loss() - microbatched_x, microbatched_y = data + microbatched_x, microbatched_y, microbatched_weights = ( + tf.keras.utils.unpack_x_y_sample_weight(microbatched_data) + ) with tf.GradientTape() as tape: microbatched_y_pred = self(microbatched_x, training=True) # NOTE: `self._clipping_loss` does not include any regularization terms. microbatched_loss = self._clipping_loss( - microbatched_y, microbatched_y_pred + microbatched_y, + microbatched_y_pred, + sample_weight=microbatched_weights, ) grads_list = tape.gradient(microbatched_loss, self.trainable_variables) clipped_grads = self._process_per_example_grads(grads_list) @@ -232,12 +235,8 @@ def train_step(self, data): self._make_clipping_loss() output_metrics = {} x, y, weights = tf.keras.utils.unpack_x_y_sample_weight(data) - if weights is not None: - raise NotImplementedError( - 'DPModel does not currently support weighted losses.' - ) batch_size = tf.shape(y)[0] - eff_num_microbatches = self._num_microbatches or batch_size + num_microbatches = self._num_microbatches or batch_size # Branch based on gradient clipping algorithm. if self._enable_fast_peg_computation: @@ -251,13 +250,14 @@ def train_step(self, data): # microbatches is done here. clipped_grads, y_pred, clipping_loss = ( clip_grads.compute_clipped_gradients_and_outputs( - self, - x, - y, - self._l2_norm_clip, - self._layer_registry, - self._num_microbatches, - self._clipping_loss, + input_model=self, + x_batch=x, + y_batch=y, + weight_batch=weights, + l2_norm_clip=self._l2_norm_clip, + layer_registry=self._layer_registry, + num_microbatches=self._num_microbatches, + clipping_loss=self._clipping_loss, ) ) output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss @@ -265,7 +265,7 @@ def train_step(self, data): grads = gradient_clipping_utils.add_aggregate_noise( self, clipped_grads, - eff_num_microbatches, + num_microbatches, self._l2_norm_clip, self._noise_multiplier, ) @@ -276,7 +276,7 @@ def train_step(self, data): # Computes per-example clipped gradients directly. This is called # if at least one of the layers cannot use the "fast" gradient clipping # algorithm. 
- reshape_fn = lambda z: lr.add_microbatch_axis(z, eff_num_microbatches) + reshape_fn = lambda z: lr.maybe_add_microbatch_axis(z, num_microbatches) microbatched_data = tf.nest.map_structure(reshape_fn, data) clipped_grads = tf.vectorized_map( self._compute_per_example_grads, @@ -305,7 +305,9 @@ def train_step(self, data): output_metrics[_PRIVATIZED_LOSS_NAME] += summed_regularization_loss # Log the true loss, including regularization losses. - self.compiled_loss(y, y_pred, regularization_losses=self.losses) + self.compiled_loss( + y, y_pred, sample_weight=weights, regularization_losses=self.losses + ) # Forward the private gradients to the optimizer and return the results. self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py index d4dc7244..23acb371 100644 --- a/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py +++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model_test.py @@ -21,8 +21,11 @@ def get_data(): # Data is for hidden weights of [3, 1] and bias of 2. - # With mean squared loss, we expect loss = 15^2 = 225, gradients of - # weights = [90, 120], and gradient of bias = 30. + # Loss is (w.x + b - y)^2, model is initialized at (w, b) = (0, 0). + # y = 15 + # Loss: y^2 = 15^2 = 225 + # Gradient w.r.t. w = -2yx = [90, 120] + # Gradient w.r.t. b = -2y = 30 data = np.array([[3, 4]]) labels = np.matmul(data, [[3], [1]]) + 2 return data, labels @@ -41,8 +44,10 @@ def testBaseline(self): layers=[ tf.keras.layers.InputLayer(input_shape=(2,)), tf.keras.layers.Dense( - 1, kernel_initializer='zeros', bias_initializer='zeros') - ]) + 1, kernel_initializer='zeros', bias_initializer='zeros' + ), + ], + ) optimizer = tf.keras.optimizers.SGD(learning_rate=0.01) loss = tf.keras.losses.MeanSquaredError() @@ -58,101 +63,149 @@ def testBaseline(self): @parameterized.product( l2_norm_clip=(10.0, 40.0, 200.0), fast_clipping=(True, False), + sequential=(True, False), + weighted=(False, True), ) - def testClippingNorm(self, l2_norm_clip, fast_clipping): + def testClippingNorm(self, l2_norm_clip, fast_clipping, sequential, weighted): """Tests that clipping norm works.""" train_data, train_labels = get_data() # Simple linear model returns w * x + b. 
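A quick arithmetic check of the numbers quoted in the get_data docstring above and of the clipping scale that testClippingNorm expects. Note the sign convention: at zero initialization the gradient of the squared loss is 2(pred - y)x = [-90, -120] for the weights and 2(pred - y) = -30 for the bias; the test works with the magnitudes, and the SGD step flips the sign again. The snippet below is an illustrative sketch, not part of the patch.

import numpy as np

x = np.array([3.0, 4.0])
y = 15.0                                   # 3*3 + 4*1 + 2
pred = 0.0                                 # w.x + b at (w, b) = (0, 0)
loss = (pred - y) ** 2                     # 225.0
grad_w = 2 * (pred - y) * x                # [-90, -120]
grad_b = 2 * (pred - y)                    # -30
# In the weighted variant of the test, grad_w and grad_b are additionally
# scaled by the example weight 0.18 before clipping.
norm = np.linalg.norm(np.concatenate([grad_w, [grad_b]]))  # ~152.97
for l2_norm_clip in (10.0, 40.0, 200.0):
  scale = min(1.0, l2_norm_clip / norm)
  update = -0.01 * scale * np.concatenate([grad_w, [grad_b]])  # lr = 0.01
  print(l2_norm_clip, scale, update)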
- model = dp_keras_model.DPSequential( - l2_norm_clip=l2_norm_clip, - noise_multiplier=0.0, - layer_registry=layer_registry.make_default_layer_registry() - if fast_clipping - else None, - layers=[ - tf.keras.layers.InputLayer(input_shape=(2,)), - tf.keras.layers.Dense( - 1, kernel_initializer='zeros', bias_initializer='zeros' - ), - ], + layer = tf.keras.layers.Dense( + 1, kernel_initializer='zeros', bias_initializer='zeros' ) + if sequential: + model = dp_keras_model.DPSequential( + l2_norm_clip=l2_norm_clip, + noise_multiplier=0.0, + layer_registry=layer_registry.make_default_layer_registry() + if fast_clipping + else None, + layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer], + ) + else: + inputs = tf.keras.Input(shape=(2,), dtype=tf.float32) + outputs = layer(inputs) + model = dp_keras_model.DPModel( + l2_norm_clip=l2_norm_clip, + noise_multiplier=0.0, + layer_registry=layer_registry.make_default_layer_registry() + if fast_clipping + else None, + inputs=inputs, + outputs=outputs, + ) learning_rate = 0.01 optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) loss = tf.keras.losses.MeanSquaredError() model.compile(optimizer=optimizer, loss=loss) - expected_loss = loss(train_labels, model(train_data)) - results = model.fit(train_data, train_labels, epochs=1, batch_size=1) - - model_weights = model.get_weights() - unclipped_gradient = np.sqrt(90**2 + 120**2 + 30**2) - scale = min(1.0, l2_norm_clip / unclipped_gradient) - expected_weights = np.array([[90], [120]]) * scale * learning_rate - expected_bias = np.array([30]) * scale * learning_rate + weights = None + data = tf.data.Dataset.from_tensors((train_data, train_labels)) + expected_grad_w = np.array([90.0, 120.0]) + expected_grad_b = np.array([30.0]) + if weighted: + # Apply a weight to the (single) example. + weights = [0.18] + data = tf.data.Dataset.from_tensors((train_data, train_labels, weights)) + expected_grad_w *= 0.18 + expected_grad_b *= 0.18 + + unclipped_norm = np.linalg.norm( + np.concatenate([expected_grad_w, expected_grad_b]) + ) + scale = min(1.0, l2_norm_clip / unclipped_norm) + expected_weights = expected_grad_w * scale * learning_rate + expected_bias = expected_grad_b * scale * learning_rate + expected_loss = loss(train_labels, model(train_data), weights) + results = model.fit(data, epochs=1, batch_size=1) + weights, bias = model.get_weights() # Check parameters are as expected, taking into account the learning rate. - self.assertAllClose(model_weights[0], expected_weights) - self.assertAllClose(model_weights[1], expected_bias) + self.assertAllClose(np.squeeze(weights), expected_weights) + self.assertAllClose(bias, expected_bias) # Check the value of the loss. 
actual_loss = results.history['loss'][0] self.assertAllClose(expected_loss, actual_loss) - def _compute_expected_gradients(self, data, labels, w, l2_norm_clip, - num_microbatches): + def _compute_expected_gradients( + self, + data, + labels, + weights, + w0, + l2_norm_clip, + num_microbatches, + ): + if weights is None: + weights = np.array([1], dtype=np.float32) batch_size = data.shape[0] if num_microbatches is None: num_microbatches = batch_size - - preds = np.matmul(data, np.expand_dims(w, axis=1)) - - grads = 2 * data * (preds - labels) - - grads = np.reshape(grads, - [num_microbatches, batch_size // num_microbatches, -1]) - + preds = np.matmul(data, w0[:, np.newaxis]) + grads = 2 * data * (preds - labels) * weights[:, np.newaxis] + grads = np.reshape( + grads, [num_microbatches, batch_size // num_microbatches, -1] + ) mb_grads = np.mean(grads, axis=1) mb_grad_norms = np.linalg.norm(mb_grads, axis=1) - scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0) - mb_grads = mb_grads * scale[:, np.newaxis] - final_grads = np.mean(mb_grads, axis=0) return final_grads @parameterized.product( num_microbatches=(None, 1, 2, 4), - fast_clipping=(False, True), + fast_clipping=(True, False), + sequential=(False, True), + weighted=(True, False), ) - def testMicrobatches(self, num_microbatches, fast_clipping): + def testMicrobatches( + self, num_microbatches, fast_clipping, sequential, weighted + ): l2_norm_clip = 1.0 train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]) - w = np.zeros((2)) train_labels = np.array([[1.0], [3.0], [-2.0], [-4.0]]) + if weighted: + train_weights = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32) + dataset = tf.data.Dataset.from_tensors( + (train_data, train_labels, train_weights) + ) + else: + train_weights = None + dataset = tf.data.Dataset.from_tensors((train_data, train_labels)) learning_rate = 1.0 optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate) loss = tf.keras.losses.MeanSquaredError() # Simple linear model returns w * x. 
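For reference, the microbatch math that _compute_expected_gradients above implements: average the per-example gradients within each microbatch, clip each microbatch mean to norm at most l2_norm_clip, then average across microbatches. A NumPy restatement with the same test data (illustrative only):

import numpy as np

data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
labels = np.array([[1.0], [3.0], [-2.0], [-4.0]])
weights = np.array([0.1, 0.2, 0.3, 0.4])
w0 = np.zeros(2)
l2_norm_clip, num_microbatches = 1.0, 2

preds = data @ w0[:, np.newaxis]                    # all zeros at init
per_example = 2 * data * (preds - labels) * weights[:, np.newaxis]
mb = per_example.reshape(num_microbatches, -1, 2).mean(axis=1)
mb_norms = np.linalg.norm(mb, axis=1)
mb_clipped = mb * np.minimum(l2_norm_clip / mb_norms, 1.0)[:, np.newaxis]
expected_grad = mb_clipped.mean(axis=0)             # what one SGD step applies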
- model = dp_keras_model.DPSequential( - l2_norm_clip=l2_norm_clip, - noise_multiplier=0.0, - num_microbatches=num_microbatches, - layer_registry=layer_registry.make_default_layer_registry() - if fast_clipping - else None, - layers=[ - tf.keras.layers.InputLayer(input_shape=(2,)), - tf.keras.layers.Dense( - 1, use_bias=False, kernel_initializer='zeros' - ), - ], + layer = tf.keras.layers.Dense(1, use_bias=False, kernel_initializer='zeros') + registry = ( + layer_registry.make_default_layer_registry() if fast_clipping else None ) + if sequential: + model = dp_keras_model.DPSequential( + l2_norm_clip=l2_norm_clip, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + layer_registry=registry, + layers=[tf.keras.layers.InputLayer(input_shape=(2,)), layer], + ) + else: + inputs = tf.keras.Input(shape=(2,), dtype=tf.float32) + outputs = layer(inputs) + model = dp_keras_model.DPModel( + l2_norm_clip=l2_norm_clip, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + layer_registry=registry, + inputs=inputs, + outputs=outputs, + ) model.compile(optimizer=optimizer, loss=loss) - model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False) + model.fit(dataset, epochs=1, batch_size=4, shuffle=False) model_weights = np.squeeze(model.get_weights()) @@ -163,7 +216,14 @@ def testMicrobatches(self, num_microbatches, fast_clipping): ) expected_grads = self._compute_expected_gradients( - train_data, train_labels, w, l2_norm_clip, effective_num_microbatches + train_data, + train_labels, + train_weights, + np.zeros( + 2, + ), + l2_norm_clip, + effective_num_microbatches, ) expected_weights = np.squeeze(-learning_rate * expected_grads) self.assertAllClose(model_weights, expected_weights)
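Finally, an end-to-end usage sketch of what this patch enables: fitting a DP Keras model on batches that carry per-example sample weights, with fast clipping enabled through the layer registry. The constructor arguments mirror the tests above; the data values are arbitrary.

import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.keras_models import dp_keras_model

x = np.random.normal(size=(8, 2)).astype(np.float32)
y = (x @ np.array([[3.0], [1.0]]) + 2.0).astype(np.float32)
sample_weights = np.linspace(0.5, 1.5, 8).astype(np.float32)

model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    num_microbatches=2,
    layer_registry=layer_registry.make_default_layer_registry(),
    layers=[
        tf.keras.layers.InputLayer(input_shape=(2,)),
        tf.keras.layers.Dense(1),
    ],
)
model.compile(optimizer=tf.keras.optimizers.SGD(0.01),
              loss=tf.keras.losses.MeanSquaredError())

# Weighted examples now flow through the clipped-gradient path.
dataset = tf.data.Dataset.from_tensor_slices((x, y, sample_weights)).batch(4)
model.fit(dataset, epochs=1)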