Skip to content

Commit

Permalink
Automated rollback of commit cff4768
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 470893869
  • Loading branch information
tensorflower-gardener committed Aug 30, 2022
1 parent cff4768 commit d5cc050
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 296 deletions.
5 changes: 0 additions & 5 deletions tensorflow_privacy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,9 @@
from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class

# Optimizers
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_gaussian_query_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_generic_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class

from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasAdagradOptimizer
Expand Down
21 changes: 3 additions & 18 deletions tensorflow_privacy/privacy/optimizers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@ py_library(
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
)

py_library(
name = "dp_optimizer_factory",
srcs = [
"dp_optimizer_keras.py",
],
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/dp_query",
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
],
)

py_library(
name = "dp_optimizer_vectorized",
srcs = [
Expand All @@ -44,10 +32,7 @@ py_library(
"dp_optimizer_keras.py",
],
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/dp_query",
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
],
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
)

py_library(
Expand Down Expand Up @@ -99,7 +84,7 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":dp_optimizer_keras",
":dp_optimizer_keras_vectorized",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras_vectorized",
],
)
224 changes: 73 additions & 151 deletions tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,21 @@
# limitations under the License.
# ==============================================================================
"""Differentially private version of Keras optimizer v2."""
from typing import Optional, Type
import warnings

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query


def _normalize(microbatch_gradient: tf.Tensor,
num_microbatches: float) -> tf.Tensor:
"""Normalizes `microbatch_gradient` by `num_microbatches`."""
return tf.truediv(microbatch_gradient,
tf.cast(num_microbatches, microbatch_gradient.dtype))
from tensorflow_privacy.privacy.dp_query import gaussian_query


def make_keras_generic_optimizer_class(
cls: Type[tf.keras.optimizers.Optimizer]):
"""Returns a differentially private (DP) subclass of `cls`.
def make_keras_optimizer_class(cls):
"""Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.Optimizer`.
Returns:
A DP-SGD subclass of `cls`.
"""

class DPOptimizerClass(cls): # pylint: disable=empty-docstring
Expand Down Expand Up @@ -145,23 +138,24 @@ class DPOptimizerClass(cls): # pylint: disable=empty-docstring

def __init__(
self,
dp_sum_query: dp_query.DPQuery,
num_microbatches: Optional[int] = None,
gradient_accumulation_steps: int = 1,
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
gradient_accumulation_steps=1,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initializes the DPOptimizerClass.
"""Initialize the DPOptimizerClass.
Args:
dp_sum_query: `DPQuery` object, specifying differential privacy
mechanism to use.
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch is
split. Default is `None` which means that number of microbatches is
equal to batch size (i.e. each microbatch contains exactly one
split. Default is `None` which means that number of microbatches
is equal to batch size (i.e. each microbatch contains exactly one
example). If `gradient_accumulation_steps` is greater than 1 and
`num_microbatches` is not `None` then the effective number of
microbatches is equal to `num_microbatches *
gradient_accumulation_steps`.
microbatches is equal to
`num_microbatches * gradient_accumulation_steps`.
gradient_accumulation_steps: If greater than 1 then optimizer will be
accumulating gradients for this number of optimizer steps before
applying them to update model weights. If this argument is set to 1
Expand All @@ -171,13 +165,13 @@ def __init__(
"""
super().__init__(*args, **kwargs)
self.gradient_accumulation_steps = gradient_accumulation_steps
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._num_microbatches = num_microbatches
self._dp_sum_query = dp_sum_query
self._was_dp_gradients_called = False
# We initialize the self.`_global_state` within the gradient functions
# (and not here) because tensors must be initialized within the graph.

self._dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
self._global_state = None
self._was_dp_gradients_called = False

def _create_slots(self, var_list):
super()._create_slots(var_list) # pytype: disable=attribute-error
Expand Down Expand Up @@ -241,62 +235,66 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""DP-SGD version of base class method."""

self._was_dp_gradients_called = True
if self._global_state is None:
self._global_state = self._dp_sum_query.initial_global_state()

# Compute loss.
if not callable(loss) and tape is None:
raise ValueError('`tape` is required when a `Tensor` loss is passed.')

tape = tape if tape is not None else tf.GradientTape()

with tape:
if callable(loss):
if callable(loss):
with tape:
if not callable(var_list):
tape.watch(var_list)

loss = loss()
if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [num_microbatches, -1]), axis=1)

if callable(var_list):
var_list = var_list()
if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [num_microbatches, -1]), axis=1)

if callable(var_list):
var_list = var_list()
else:
with tape:
if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [num_microbatches, -1]), axis=1)

var_list = tf.nest.flatten(var_list)

sample_params = (
self._dp_sum_query.derive_sample_params(self._global_state))

# Compute the per-microbatch losses using helpful jacobian method.
with tf.keras.backend.name_scope(self._name + '/gradients'):
jacobian_per_var = tape.jacobian(
jacobian = tape.jacobian(
microbatch_losses, var_list, unconnected_gradients='zero')

def process_microbatch(sample_state, microbatch_jacobians):
"""Process one microbatch (record) with privacy helper."""
sample_state = self._dp_sum_query.accumulate_record(
sample_params, sample_state, microbatch_jacobians)
return sample_state
# Clip gradients to given l2_norm_clip.
def clip_gradients(g):
return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]

sample_state = self._dp_sum_query.initial_sample_state(var_list)
for idx in range(num_microbatches):
microbatch_jacobians_per_var = [
jacobian[idx] for jacobian in jacobian_per_var
]
sample_state = process_microbatch(sample_state,
microbatch_jacobians_per_var)
clipped_gradients = tf.map_fn(clip_gradients, jacobian)

grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
final_grads = tf.nest.map_structure(_normalize, grad_sums,
[num_microbatches] * len(grad_sums))
def reduce_noise_normalize_batch(g):
# Sum gradients over all microbatches.
summed_gradient = tf.reduce_sum(g, axis=0)

return list(zip(final_grads, var_list))
# Add noise to summed gradients.
noise_stddev = self._l2_norm_clip * self._noise_multiplier
noise = tf.random.normal(
tf.shape(input=summed_gradient), stddev=noise_stddev)
noised_gradient = tf.add(summed_gradient, noise)

# Normalize by number of microbatches and return.
return tf.truediv(noised_gradient,
tf.cast(num_microbatches, tf.float32))

final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
clipped_gradients)

return list(zip(final_gradients, var_list))

def get_gradients(self, loss, params):
"""DP-SGD version of base class method."""
Expand Down Expand Up @@ -324,13 +322,17 @@ def process_microbatch(i, sample_state):
sample_state = self._dp_sum_query.initial_sample_state(params)
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)

grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))

final_grads = tf.nest.map_structure(
_normalize, grad_sums, [self._num_microbatches] * len(grad_sums))
def normalize(v):
try:
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
except TypeError:
return None

final_grads = tf.nest.map_structure(normalize, grad_sums)

return final_grads

Expand Down Expand Up @@ -366,87 +368,7 @@ def apply_gradients(self, *args, **kwargs):
return DPOptimizerClass


def make_gaussian_query_optimizer_class(cls):
"""Returns a differentially private optimizer using the `GaussianSumQuery`.
Args:
cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
"""

def return_gaussian_query_optimizer(
l2_norm_clip: float,
noise_multiplier: float,
num_microbatches: Optional[int] = None,
gradient_accumulation_steps: int = 1,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.
This function is a thin wrapper around
`make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
apply a `GaussianSumQuery` to any `DPOptimizerClass`.
When combined with stochastic gradient descent, this creates the canonical
DP-SGD algorithm of "Deep Learning with Differential Privacy"
(see https://arxiv.org/abs/1607.00133).
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch is
split. Default is `None` which means that number of microbatches is
equal to batch size (i.e. each microbatch contains exactly one example).
If `gradient_accumulation_steps` is greater than 1 and
`num_microbatches` is not `None` then the effective number of
microbatches is equal to `num_microbatches *
gradient_accumulation_steps`.
gradient_accumulation_steps: If greater than 1 then optimizer will be
accumulating gradients for this number of optimizer steps before
applying them to update model weights. If this argument is set to 1 then
updates will be applied on each optimizer step.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
return cls(
dp_sum_query=dp_sum_query,
num_microbatches=num_microbatches,
gradient_accumulation_steps=gradient_accumulation_steps,
*args,
**kwargs)

return return_gaussian_query_optimizer


def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
"""Returns a differentially private optimizer using the `GaussianSumQuery`.
For backwards compatibility, we create this symbol to match the previous
output of `make_keras_optimizer_class` but using the new logic.
Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.Optimizer`.
"""
warnings.warn(
'`make_keras_optimizer_class` will be depracated on 2023-02-23. '
'Please switch to `make_gaussian_query_optimizer_class` and the '
'generic optimizers (`make_keras_generic_optimizer_class`).')
return make_gaussian_query_optimizer_class(
make_keras_generic_optimizer_class(cls))


GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
DPKerasAdagradOptimizer = make_keras_optimizer_class(
tf.keras.optimizers.Adagrad)
GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
tf.keras.optimizers.Adam)
GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
tf.keras.optimizers.SGD)

# We keep the same names for backwards compatibility.
DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
GenericDPAdagradOptimizer)
DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
GenericDPAdamOptimizer)
DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)
DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam)
DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD)
Loading

0 comments on commit d5cc050

Please sign in to comment.