diff --git a/tensorflow_privacy/__init__.py b/tensorflow_privacy/__init__.py index 82a0c66d..538af679 100644 --- a/tensorflow_privacy/__init__.py +++ b/tensorflow_privacy/__init__.py @@ -61,14 +61,9 @@ from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class # Optimizers - from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdagradOptimizer - from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer - from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPSGDOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer - from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_gaussian_query_optimizer_class - from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_generic_optimizer_class from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasAdagradOptimizer diff --git a/tensorflow_privacy/privacy/optimizers/BUILD b/tensorflow_privacy/privacy/optimizers/BUILD index 37b238b2..e6ca5cee 100644 --- a/tensorflow_privacy/privacy/optimizers/BUILD +++ b/tensorflow_privacy/privacy/optimizers/BUILD @@ -18,18 +18,6 @@ py_library( deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"], ) -py_library( - name = "dp_optimizer_factory", - srcs = [ - "dp_optimizer_keras.py", - ], - srcs_version = "PY3", - deps = [ - "//tensorflow_privacy/privacy/dp_query", - "//tensorflow_privacy/privacy/dp_query:gaussian_query", - ], -) - py_library( name = "dp_optimizer_vectorized", srcs = [ @@ -44,10 +32,7 @@ py_library( "dp_optimizer_keras.py", ], srcs_version = "PY3", - deps = [ - "//tensorflow_privacy/privacy/dp_query", - "//tensorflow_privacy/privacy/dp_query:gaussian_query", - ], + deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"], ) py_library( @@ -99,7 +84,7 @@ py_test( python_version = "PY3", srcs_version = "PY3", deps = [ - ":dp_optimizer_keras", - ":dp_optimizer_keras_vectorized", + "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras", + "//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras_vectorized", ], ) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py index e9349f45..51aeac46 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py @@ -13,28 +13,21 @@ # limitations under the License. # ============================================================================== """Differentially private version of Keras optimizer v2.""" -from typing import Optional, Type -import warnings import tensorflow as tf -from tensorflow_privacy.privacy.dp_query import dp_query -from tensorflow_privacy.privacy.dp_query import gaussian_query - -def _normalize(microbatch_gradient: tf.Tensor, - num_microbatches: float) -> tf.Tensor: - """Normalizes `microbatch_gradient` by `num_microbatches`.""" - return tf.truediv(microbatch_gradient, - tf.cast(num_microbatches, microbatch_gradient.dtype)) +from tensorflow_privacy.privacy.dp_query import gaussian_query -def make_keras_generic_optimizer_class( - cls: Type[tf.keras.optimizers.Optimizer]): - """Returns a differentially private (DP) subclass of `cls`. +def make_keras_optimizer_class(cls): + """Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it. Args: cls: Class from which to derive a DP subclass. Should be a subclass of `tf.keras.optimizers.Optimizer`. + + Returns: + A DP-SGD subclass of `cls`. """ class DPOptimizerClass(cls): # pylint: disable=empty-docstring @@ -145,23 +138,24 @@ class DPOptimizerClass(cls): # pylint: disable=empty-docstring def __init__( self, - dp_sum_query: dp_query.DPQuery, - num_microbatches: Optional[int] = None, - gradient_accumulation_steps: int = 1, + l2_norm_clip, + noise_multiplier, + num_microbatches=None, + gradient_accumulation_steps=1, *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args **kwargs): - """Initializes the DPOptimizerClass. + """Initialize the DPOptimizerClass. Args: - dp_sum_query: `DPQuery` object, specifying differential privacy - mechanism to use. + l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients). + noise_multiplier: Ratio of the standard deviation to the clipping norm. num_microbatches: Number of microbatches into which each minibatch is - split. Default is `None` which means that number of microbatches is - equal to batch size (i.e. each microbatch contains exactly one + split. Default is `None` which means that number of microbatches + is equal to batch size (i.e. each microbatch contains exactly one example). If `gradient_accumulation_steps` is greater than 1 and `num_microbatches` is not `None` then the effective number of - microbatches is equal to `num_microbatches * - gradient_accumulation_steps`. + microbatches is equal to + `num_microbatches * gradient_accumulation_steps`. gradient_accumulation_steps: If greater than 1 then optimizer will be accumulating gradients for this number of optimizer steps before applying them to update model weights. If this argument is set to 1 @@ -171,13 +165,13 @@ def __init__( """ super().__init__(*args, **kwargs) self.gradient_accumulation_steps = gradient_accumulation_steps + self._l2_norm_clip = l2_norm_clip + self._noise_multiplier = noise_multiplier self._num_microbatches = num_microbatches - self._dp_sum_query = dp_sum_query - self._was_dp_gradients_called = False - # We initialize the self.`_global_state` within the gradient functions - # (and not here) because tensors must be initialized within the graph. - + self._dp_sum_query = gaussian_query.GaussianSumQuery( + l2_norm_clip, l2_norm_clip * noise_multiplier) self._global_state = None + self._was_dp_gradients_called = False def _create_slots(self, var_list): super()._create_slots(var_list) # pytype: disable=attribute-error @@ -241,62 +235,66 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): """DP-SGD version of base class method.""" self._was_dp_gradients_called = True - if self._global_state is None: - self._global_state = self._dp_sum_query.initial_global_state() - # Compute loss. if not callable(loss) and tape is None: raise ValueError('`tape` is required when a `Tensor` loss is passed.') - tape = tape if tape is not None else tf.GradientTape() - with tape: - if callable(loss): + if callable(loss): + with tape: if not callable(var_list): tape.watch(var_list) loss = loss() - if self._num_microbatches is None: - num_microbatches = tf.shape(input=loss)[0] - else: - num_microbatches = self._num_microbatches - microbatch_losses = tf.reduce_mean( - tf.reshape(loss, [num_microbatches, -1]), axis=1) - - if callable(var_list): - var_list = var_list() + if self._num_microbatches is None: + num_microbatches = tf.shape(input=loss)[0] + else: + num_microbatches = self._num_microbatches + microbatch_losses = tf.reduce_mean( + tf.reshape(loss, [num_microbatches, -1]), axis=1) + + if callable(var_list): + var_list = var_list() + else: + with tape: + if self._num_microbatches is None: + num_microbatches = tf.shape(input=loss)[0] + else: + num_microbatches = self._num_microbatches + microbatch_losses = tf.reduce_mean( + tf.reshape(loss, [num_microbatches, -1]), axis=1) var_list = tf.nest.flatten(var_list) - sample_params = ( - self._dp_sum_query.derive_sample_params(self._global_state)) - # Compute the per-microbatch losses using helpful jacobian method. with tf.keras.backend.name_scope(self._name + '/gradients'): - jacobian_per_var = tape.jacobian( + jacobian = tape.jacobian( microbatch_losses, var_list, unconnected_gradients='zero') - def process_microbatch(sample_state, microbatch_jacobians): - """Process one microbatch (record) with privacy helper.""" - sample_state = self._dp_sum_query.accumulate_record( - sample_params, sample_state, microbatch_jacobians) - return sample_state + # Clip gradients to given l2_norm_clip. + def clip_gradients(g): + return tf.clip_by_global_norm(g, self._l2_norm_clip)[0] - sample_state = self._dp_sum_query.initial_sample_state(var_list) - for idx in range(num_microbatches): - microbatch_jacobians_per_var = [ - jacobian[idx] for jacobian in jacobian_per_var - ] - sample_state = process_microbatch(sample_state, - microbatch_jacobians_per_var) + clipped_gradients = tf.map_fn(clip_gradients, jacobian) - grad_sums, self._global_state, _ = ( - self._dp_sum_query.get_noised_result(sample_state, - self._global_state)) - final_grads = tf.nest.map_structure(_normalize, grad_sums, - [num_microbatches] * len(grad_sums)) + def reduce_noise_normalize_batch(g): + # Sum gradients over all microbatches. + summed_gradient = tf.reduce_sum(g, axis=0) - return list(zip(final_grads, var_list)) + # Add noise to summed gradients. + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal( + tf.shape(input=summed_gradient), stddev=noise_stddev) + noised_gradient = tf.add(summed_gradient, noise) + + # Normalize by number of microbatches and return. + return tf.truediv(noised_gradient, + tf.cast(num_microbatches, tf.float32)) + + final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, + clipped_gradients) + + return list(zip(final_gradients, var_list)) def get_gradients(self, loss, params): """DP-SGD version of base class method.""" @@ -324,13 +322,17 @@ def process_microbatch(i, sample_state): sample_state = self._dp_sum_query.initial_sample_state(params) for idx in range(self._num_microbatches): sample_state = process_microbatch(idx, sample_state) - grad_sums, self._global_state, _ = ( self._dp_sum_query.get_noised_result(sample_state, self._global_state)) - final_grads = tf.nest.map_structure( - _normalize, grad_sums, [self._num_microbatches] * len(grad_sums)) + def normalize(v): + try: + return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32)) + except TypeError: + return None + + final_grads = tf.nest.map_structure(normalize, grad_sums) return final_grads @@ -366,87 +368,7 @@ def apply_gradients(self, *args, **kwargs): return DPOptimizerClass -def make_gaussian_query_optimizer_class(cls): - """Returns a differentially private optimizer using the `GaussianSumQuery`. - - Args: - cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`. - """ - - def return_gaussian_query_optimizer( - l2_norm_clip: float, - noise_multiplier: float, - num_microbatches: Optional[int] = None, - gradient_accumulation_steps: int = 1, - *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args - **kwargs): - """Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`. - - This function is a thin wrapper around - `make_keras_optimizer_class..DPOptimizerClass` which can be used to - apply a `GaussianSumQuery` to any `DPOptimizerClass`. - - When combined with stochastic gradient descent, this creates the canonical - DP-SGD algorithm of "Deep Learning with Differential Privacy" - (see https://arxiv.org/abs/1607.00133). - - Args: - l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients). - noise_multiplier: Ratio of the standard deviation to the clipping norm. - num_microbatches: Number of microbatches into which each minibatch is - split. Default is `None` which means that number of microbatches is - equal to batch size (i.e. each microbatch contains exactly one example). - If `gradient_accumulation_steps` is greater than 1 and - `num_microbatches` is not `None` then the effective number of - microbatches is equal to `num_microbatches * - gradient_accumulation_steps`. - gradient_accumulation_steps: If greater than 1 then optimizer will be - accumulating gradients for this number of optimizer steps before - applying them to update model weights. If this argument is set to 1 then - updates will be applied on each optimizer step. - *args: These will be passed on to the base class `__init__` method. - **kwargs: These will be passed on to the base class `__init__` method. - """ - dp_sum_query = gaussian_query.GaussianSumQuery( - l2_norm_clip, l2_norm_clip * noise_multiplier) - return cls( - dp_sum_query=dp_sum_query, - num_microbatches=num_microbatches, - gradient_accumulation_steps=gradient_accumulation_steps, - *args, - **kwargs) - - return return_gaussian_query_optimizer - - -def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]): - """Returns a differentially private optimizer using the `GaussianSumQuery`. - - For backwards compatibility, we create this symbol to match the previous - output of `make_keras_optimizer_class` but using the new logic. - - Args: - cls: Class from which to derive a DP subclass. Should be a subclass of - `tf.keras.optimizers.Optimizer`. - """ - warnings.warn( - '`make_keras_optimizer_class` will be depracated on 2023-02-23. ' - 'Please switch to `make_gaussian_query_optimizer_class` and the ' - 'generic optimizers (`make_keras_generic_optimizer_class`).') - return make_gaussian_query_optimizer_class( - make_keras_generic_optimizer_class(cls)) - - -GenericDPAdagradOptimizer = make_keras_generic_optimizer_class( +DPKerasAdagradOptimizer = make_keras_optimizer_class( tf.keras.optimizers.Adagrad) -GenericDPAdamOptimizer = make_keras_generic_optimizer_class( - tf.keras.optimizers.Adam) -GenericDPSGDOptimizer = make_keras_generic_optimizer_class( - tf.keras.optimizers.SGD) - -# We keep the same names for backwards compatibility. -DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class( - GenericDPAdagradOptimizer) -DPKerasAdamOptimizer = make_gaussian_query_optimizer_class( - GenericDPAdamOptimizer) -DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer) +DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam) +DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py index 8a667eb3..045b1188 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from absl.testing import parameterized import numpy as np import tensorflow as tf @@ -28,29 +29,36 @@ def _loss(self, val0, val1): return 0.5 * tf.reduce_sum( input_tensor=tf.math.squared_difference(val0, val1), axis=1) + # Parameters for testing: optimizer, num_microbatches, expected gradient for + # var0, expected gradient for var1. @parameterized.named_parameters( - ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), - ('DPAdam_2', dp_optimizer_keras.DPKerasAdamOptimizer, 2), - ('DPAdagrad _4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4), - ('DPGradientDescentVectorized_1', - dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), - ('DPAdamVectorized_2', - dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2), - ('DPAdagradVectorized_4', - dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4), - ('DPAdagradVectorized_None', - dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None), + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1, + [-2.5, -2.5], [-0.5]), + ('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, [-2.5, -2.5 + ], [-0.5]), + ('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4, + [-2.5, -2.5], [-0.5]), + ('DPGradientDescentVectorized 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, + [-2.5, -2.5], [-0.5]), + ('DPAdamVectorized 2', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2, + [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized 4', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, + [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized None', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None, + [-2.5, -2.5], [-0.5]), ) - def testBaselineWithCallableLossNoNoise(self, optimizer_class, - num_microbatches): + def testBaselineWithCallableLoss(self, cls, num_microbatches, expected_grad0, + expected_grad1): var0 = tf.Variable([1.0, 2.0]) var1 = tf.Variable([3.0]) data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]]) - expected_grad0 = [-2.5, -2.5] - expected_grad1 = [-0.5] - optimizer = optimizer_class( + opt = cls( l2_norm_clip=100.0, noise_multiplier=0.0, num_microbatches=num_microbatches, @@ -58,34 +66,40 @@ def testBaselineWithCallableLossNoNoise(self, optimizer_class, loss = lambda: self._loss(data0, var0) + self._loss(data1, var1) - grads_and_vars = optimizer._compute_gradients(loss, [var0, var1]) - + grads_and_vars = opt._compute_gradients(loss, [var0, var1]) self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0]) self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0]) + # Parameters for testing: optimizer, num_microbatches, expected gradient for + # var0, expected gradient for var1. @parameterized.named_parameters( - ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), - ('DPAdam_2', dp_optimizer_keras.DPKerasAdamOptimizer, 2), - ('DPAdagrad_4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4), - ('DPGradientDescentVectorized_1', - dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), - ('DPAdamVectorized_2', - dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2), - ('DPAdagradVectorized_4', - dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4), - ('DPAdagradVectorized_None', - dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None), + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1, + [-2.5, -2.5], [-0.5]), + ('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, [-2.5, -2.5 + ], [-0.5]), + ('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4, + [-2.5, -2.5], [-0.5]), + ('DPGradientDescentVectorized 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1, + [-2.5, -2.5], [-0.5]), + ('DPAdamVectorized 2', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer, 2, + [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized 4', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, 4, + [-2.5, -2.5], [-0.5]), + ('DPAdagradVectorized None', + dp_optimizer_keras_vectorized.VectorizedDPKerasAdagradOptimizer, None, + [-2.5, -2.5], [-0.5]), ) - def testBaselineWithTensorLossNoNoise(self, optimizer_class, - num_microbatches): + def testBaselineWithTensorLoss(self, cls, num_microbatches, expected_grad0, + expected_grad1): var0 = tf.Variable([1.0, 2.0]) var1 = tf.Variable([3.0]) data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]]) - expected_grad0 = [-2.5, -2.5] - expected_grad1 = [-0.5] - optimizer = optimizer_class( + opt = cls( l2_norm_clip=100.0, noise_multiplier=0.0, num_microbatches=num_microbatches, @@ -95,7 +109,7 @@ def testBaselineWithTensorLossNoNoise(self, optimizer_class, with tape: loss = self._loss(data0, var0) + self._loss(data1, var1) - grads_and_vars = optimizer._compute_gradients(loss, [var0, var1], tape=tape) + grads_and_vars = opt._compute_gradients(loss, [var0, var1], tape=tape) self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0]) self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0]) @@ -104,11 +118,11 @@ def testBaselineWithTensorLossNoNoise(self, optimizer_class, ('DPGradientDescentVectorized', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer), ) - def testClippingNorm(self, optimizer_class): + def testClippingNorm(self, cls): var0 = tf.Variable([0.0, 0.0]) data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) - optimizer = optimizer_class( + opt = cls( l2_norm_clip=1.0, noise_multiplier=0.0, num_microbatches=1, @@ -116,39 +130,37 @@ def testClippingNorm(self, optimizer_class): loss = lambda: self._loss(data0, var0) # Expected gradient is sum of differences. - grads_and_vars = optimizer._compute_gradients(loss, [var0]) + grads_and_vars = opt._compute_gradients(loss, [var0]) self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) @parameterized.named_parameters( - ('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0, + ('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0, 4.0, 1), - ('DPGradientDescent_4_1_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0, + ('DPGradientDescent 4 1 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0, 1.0, 4), - ('DPGradientDescentVectorized_2_4_1', + ('DPGradientDescentVectorized 2 4 1', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0, 1), - ('DPGradientDescentVectorized_4_1_4', + ('DPGradientDescentVectorized 4 1 4', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4.0, 1.0, 4), ) - def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier, + def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier, num_microbatches): - tf.random.set_seed(2) var0 = tf.Variable(tf.zeros([1000], dtype=tf.float32)) data0 = tf.Variable(tf.zeros([16, 1000], dtype=tf.float32)) - optimizer = optimizer_class( + opt = cls( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, num_microbatches=num_microbatches, learning_rate=2.0) loss = lambda: self._loss(data0, var0) - grads_and_vars = optimizer._compute_gradients(loss, [var0]) + grads_and_vars = opt._compute_gradients(loss, [var0]) grads = grads_and_vars[0][0].numpy() # Test standard deviation is close to l2_norm_clip * noise_multiplier. - self.assertNear( np.std(grads), l2_norm_clip * noise_multiplier / num_microbatches, 0.5) @@ -163,9 +175,9 @@ def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier, ('DPAdamVectorized', dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer), ) - def testRaisesOnNoCallOfComputeGradients(self, optimizer_class): + def testAssertOnNoCallOfComputeGradients(self, cls): """Tests that assertion fails when DP gradients are not computed.""" - optimizer = optimizer_class( + opt = cls( l2_norm_clip=100.0, noise_multiplier=0.0, num_microbatches=1, @@ -173,14 +185,14 @@ def testRaisesOnNoCallOfComputeGradients(self, optimizer_class): with self.assertRaises(AssertionError): grads_and_vars = tf.Variable([0.0]) - optimizer.apply_gradients(grads_and_vars) + opt.apply_gradients(grads_and_vars) # Expect no exception if _compute_gradients is called. var0 = tf.Variable([0.0]) data0 = tf.Variable([[0.0]]) loss = lambda: self._loss(data0, var0) - grads_and_vars = optimizer._compute_gradients(loss, [var0]) - optimizer.apply_gradients(grads_and_vars) + grads_and_vars = opt._compute_gradients(loss, [var0]) + opt.apply_gradients(grads_and_vars) class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): @@ -190,8 +202,8 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): the Estimator framework. """ - def _make_linear_model_fn(self, optimizer_class, l2_norm_clip, - noise_multiplier, num_microbatches, learning_rate): + def _make_linear_model_fn(self, opt_cls, l2_norm_clip, noise_multiplier, + num_microbatches, learning_rate): """Returns a model function for a linear regressor.""" def linear_model_fn(features, labels, mode): @@ -206,7 +218,7 @@ def linear_model_fn(features, labels, mode): vector_loss = 0.5 * tf.math.squared_difference(labels, preds) scalar_loss = tf.reduce_mean(input_tensor=vector_loss) - optimizer = optimizer_class( + optimizer = opt_cls( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, num_microbatches=num_microbatches, @@ -222,25 +234,26 @@ def linear_model_fn(features, labels, mode): return linear_model_fn + # Parameters for testing: optimizer, num_microbatches. @parameterized.named_parameters( - ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), - ('DPGradientDescent_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), - ('DPGradientDescent_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4), - ('DPGradientDescentVectorized_1', + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), + ('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), + ('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4), + ('DPGradientDescentVectorized 1', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), - ('DPGradientDescentVectorized_2', + ('DPGradientDescentVectorized 2', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2), - ('DPGradientDescentVectorized_4', + ('DPGradientDescentVectorized 4', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 4), - ('DPGradientDescentVectorized_None', + ('DPGradientDescentVectorized None', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, None), ) - def testBaselineNoNoise(self, optimizer_class, num_microbatches): + def testBaseline(self, cls, num_microbatches): """Tests that DP optimizers work with tf.estimator.""" linear_regressor = tf_estimator.Estimator( - model_fn=self._make_linear_model_fn(optimizer_class, 100.0, 0.0, - num_microbatches, 0.05)) + model_fn=self._make_linear_model_fn(cls, 100.0, 0.0, num_microbatches, + 0.05)) true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32) true_bias = np.array([6.0]).astype(np.float32) @@ -263,12 +276,13 @@ def train_input_fn(): self.assertAllClose( linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05) + # Parameters for testing: optimizer, num_microbatches. @parameterized.named_parameters( - ('DPGradientDescent_1', dp_optimizer_keras.DPKerasSGDOptimizer), - ('DPGradientDescentVectorized_1', - dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer), + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), + ('DPGradientDescentVectorized 1', + dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 1), ) - def testClippingNorm(self, optimizer_class): + def testClippingNorm(self, cls, num_microbatches): """Tests that DP optimizers work with tf.estimator.""" true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32) @@ -282,12 +296,8 @@ def train_input_fn(): (train_data, train_labels)).batch(1) unclipped_linear_regressor = tf_estimator.Estimator( - model_fn=self._make_linear_model_fn( - optimizer_class=optimizer_class, - l2_norm_clip=1.0e9, - noise_multiplier=0.0, - num_microbatches=1, - learning_rate=1.0)) + model_fn=self._make_linear_model_fn(cls, 1.0e9, 0.0, num_microbatches, + 1.0)) unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1) kernel_value = unclipped_linear_regressor.get_variable_value('dense/kernel') @@ -295,12 +305,8 @@ def train_input_fn(): global_norm = np.linalg.norm(np.concatenate((kernel_value, [bias_value]))) clipped_linear_regressor = tf_estimator.Estimator( - model_fn=self._make_linear_model_fn( - optimizer_class=optimizer_class, - l2_norm_clip=1.0, - noise_multiplier=0.0, - num_microbatches=1, - learning_rate=1.0)) + model_fn=self._make_linear_model_fn(cls, 1.0, 0.0, num_microbatches, + 1.0)) clipped_linear_regressor.train(input_fn=train_input_fn, steps=1) self.assertAllClose( @@ -315,29 +321,29 @@ def train_input_fn(): # Parameters for testing: optimizer, l2_norm_clip, noise_multiplier, # num_microbatches. @parameterized.named_parameters( - ('DPGradientDescent_2_4_1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0, + ('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0, 4.0, 1), - ('DPGradientDescent_3_2_4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0, + ('DPGradientDescent 3 2 4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0, 2.0, 4), - ('DPGradientDescent_8_6_8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0, + ('DPGradientDescent 8 6 8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0, 6.0, 8), - ('DPGradientDescentVectorized_2_4_1', + ('DPGradientDescentVectorized 2 4 1', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.0, 4.0, 1), - ('DPGradientDescentVectorized_3_2_4', + ('DPGradientDescentVectorized 3 2 4', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 3.0, 2.0, 4), - ('DPGradientDescentVectorized_8_6_8', + ('DPGradientDescentVectorized 8 6 8', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 8.0, 6.0, 8), ) - def testNoiseMultiplier(self, optimizer_class, l2_norm_clip, noise_multiplier, + def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier, num_microbatches): """Tests that DP optimizers work with tf.estimator.""" linear_regressor = tf_estimator.Estimator( model_fn=self._make_linear_model_fn( - optimizer_class, + cls, l2_norm_clip, noise_multiplier, num_microbatches, @@ -371,9 +377,9 @@ def train_input_fn(): ('DPAdamVectorized', dp_optimizer_keras_vectorized.VectorizedDPKerasAdamOptimizer), ) - def testRaisesOnNoCallOfGetGradients(self, optimizer_class): + def testAssertOnNoCallOfGetGradients(self, cls): """Tests that assertion fails when DP gradients are not computed.""" - optimizer = optimizer_class( + opt = cls( l2_norm_clip=100.0, noise_multiplier=0.0, num_microbatches=1, @@ -381,7 +387,7 @@ def testRaisesOnNoCallOfGetGradients(self, optimizer_class): with self.assertRaises(AssertionError): grads_and_vars = tf.Variable([0.0]) - optimizer.apply_gradients(grads_and_vars) + opt.apply_gradients(grads_and_vars) def testLargeBatchEmulationNoNoise(self): # Test for emulation of large batch training. @@ -402,7 +408,7 @@ def testLargeBatchEmulationNoNoise(self): x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32) loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1 - optimizer = dp_optimizer_keras.DPKerasSGDOptimizer( + opt = dp_optimizer_keras.DPKerasSGDOptimizer( l2_norm_clip=100.0, noise_multiplier=0.0, gradient_accumulation_steps=2, @@ -412,36 +418,35 @@ def testLargeBatchEmulationNoNoise(self): self.assertAllCloseAccordingToType([[1.0, 2.0]], var0) self.assertAllCloseAccordingToType([3.0], var1) - optimizer.minimize(loss1, [var0, var1]) + opt.minimize(loss1, [var0, var1]) # After first call to optimizer values didn't change self.assertAllCloseAccordingToType([[1.0, 2.0]], var0) self.assertAllCloseAccordingToType([3.0], var1) - optimizer.minimize(loss2, [var0, var1]) + opt.minimize(loss2, [var0, var1]) # After second call to optimizer updates were applied self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0) self.assertAllCloseAccordingToType([2.0], var1) - optimizer.minimize(loss2, [var0, var1]) + opt.minimize(loss2, [var0, var1]) # After third call to optimizer values didn't change self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0) self.assertAllCloseAccordingToType([2.0], var1) - optimizer.minimize(loss2, [var0, var1]) + opt.minimize(loss2, [var0, var1]) # After fourth call to optimizer updates were applied again self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0) self.assertAllCloseAccordingToType([1.0], var1) @parameterized.named_parameters( - ('DPKerasSGDOptimizer_1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), - ('DPKerasSGDOptimizer_2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), - ('DPKerasSGDOptimizer_4', dp_optimizer_keras.DPKerasSGDOptimizer, 4), - ('DPKerasAdamOptimizer_2', dp_optimizer_keras.DPKerasAdamOptimizer, 1), - ('DPKerasAdagradOptimizer_2', dp_optimizer_keras.DPKerasAdagradOptimizer, + ('DPKerasSGDOptimizer 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), + ('DPKerasSGDOptimizer 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), + ('DPKerasSGDOptimizer 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4), + ('DPKerasAdamOptimizer 2', dp_optimizer_keras.DPKerasAdamOptimizer, 1), + ('DPKerasAdagradOptimizer 2', dp_optimizer_keras.DPKerasAdagradOptimizer, 2), ) - def testLargeBatchEmulation(self, optimizer_class, - gradient_accumulation_steps): + def testLargeBatchEmulation(self, cls, gradient_accumulation_steps): # Tests various optimizers with large batch emulation. # Uses clipping and noise, thus does not test specific values # of the variables and only tests how often variables are updated. @@ -450,7 +455,7 @@ def testLargeBatchEmulation(self, optimizer_class, x = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32) loss = lambda: tf.matmul(var0, x, transpose_b=True) + var1 - optimizer = optimizer_class( + opt = cls( l2_norm_clip=100.0, noise_multiplier=0.0, gradient_accumulation_steps=gradient_accumulation_steps, @@ -459,7 +464,7 @@ def testLargeBatchEmulation(self, optimizer_class, for _ in range(gradient_accumulation_steps): self.assertAllCloseAccordingToType([[1.0, 2.0]], var0) self.assertAllCloseAccordingToType([3.0], var1) - optimizer.minimize(loss, [var0, var1]) + opt.minimize(loss, [var0, var1]) self.assertNotAllClose([[1.0, 2.0]], var0) self.assertNotAllClose([3.0], var1) @@ -496,19 +501,19 @@ def call(self, inputs, training=None): return sequence_output, pooled_output -def keras_embedding_model_fn(optimizer_class, +def keras_embedding_model_fn(opt_cls, l2_norm_clip: float, noise_multiplier: float, num_microbatches: int, learning_rate: float, - use_sequence_output: bool = False, + use_seq_output: bool = False, unconnected_gradients_to_zero: bool = False): """Construct a simple embedding model with a classification layer.""" # Every sample has 4 tokens (sequence length=4). x = tf.keras.layers.Input(shape=(4,), dtype=tf.float32, name='input') sequence_output, pooled_output = SimpleEmbeddingModel()(x) - if use_sequence_output: + if use_seq_output: embedding = sequence_output else: embedding = pooled_output @@ -517,7 +522,7 @@ def keras_embedding_model_fn(optimizer_class, embedding) model = tf.keras.Model(inputs=x, outputs=probs, name='model') - optimizer = optimizer_class( + optimizer = opt_cls( l2_norm_clip=l2_norm_clip, noise_multiplier=noise_multiplier, num_microbatches=num_microbatches, @@ -557,7 +562,7 @@ class DPVectorizedOptimizerUnconnectedNodesTest(tf.test.TestCase, @parameterized.named_parameters( ('DPSGDVectorized_SeqOutput_UnconnectedGradients', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),) - def testSeqOutputUnconnectedGradientsAsNoneFails(self, optimizer_class): + def testSeqOutputUnconnectedGradientsAsNoneFails(self, cls): """Tests that DP vectorized optimizers with 'None' unconnected gradients fail. Sequence models that have unconnected gradients (with @@ -569,16 +574,16 @@ def testSeqOutputUnconnectedGradientsAsNoneFails(self, optimizer_class): These tests test the various combinations of this flag and the model. Args: - optimizer_class: The DP optimizer class to test. + cls: The DP optimizer class to test. """ embedding_model = keras_embedding_model_fn( - optimizer_class, + cls, l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1, learning_rate=1.0, - use_sequence_output=True, + use_seq_output=True, unconnected_gradients_to_zero=False) train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32) @@ -600,16 +605,16 @@ def train_data_input_fn(): @parameterized.named_parameters( ('DPSGDVectorized_PooledOutput_UnconnectedGradients', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer),) - def testPooledOutputUnconnectedGradientsAsNonePasses(self, optimizer_class): + def testPooledOutputUnconnectedGradientsAsNonePasses(self, cls): """Tests that DP vectorized optimizers with 'None' unconnected gradients fail.""" embedding_model = keras_embedding_model_fn( - optimizer_class, + cls, l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1, learning_rate=1.0, - use_sequence_output=False, + use_seq_output=False, unconnected_gradients_to_zero=False) train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32) @@ -633,17 +638,16 @@ def train_data_input_fn(): ('DPSGDVectorized_PooledOutput_UnconnectedGradientsAreZero', dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, False), ) - def testUnconnectedGradientsAsZeroPasses(self, optimizer_class, - use_sequence_output): + def testUnconnectedGradientsAsZeroPasses(self, cls, use_seq_output): """Tests that DP vectorized optimizers with 'Zero' unconnected gradients pass.""" embedding_model = keras_embedding_model_fn( - optimizer_class, + cls, l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1, learning_rate=1.0, - use_sequence_output=use_sequence_output, + use_seq_output=use_seq_output, unconnected_gradients_to_zero=True) train_data = np.random.randint(0, 10, size=(1000, 4), dtype=np.int32) @@ -660,6 +664,5 @@ def train_data_input_fn(): # other exceptions are errors. self.fail('ValueError raised by model.fit().') - if __name__ == '__main__': tf.test.main()