Automated rollback of commit cff47686f6cd9b6967f95b6eb366f3a4bf6f192f #313

Open · wants to merge 1 commit into master
5 changes: 0 additions & 5 deletions tensorflow_privacy/__init__.py
@@ -61,14 +61,9 @@
from tensorflow_privacy.privacy.keras_models.dp_keras_model import make_dp_model_class

# Optimizers
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdagradOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_gaussian_query_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_generic_optimizer_class
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import make_keras_optimizer_class

from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras_vectorized import VectorizedDPKerasAdagradOptimizer
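For context, a minimal usage sketch of the restored public API (the model, layer sizes, and hyperparameter values below are illustrative assumptions, not part of this PR): the restored optimizers are constructed directly from `l2_norm_clip` and `noise_multiplier`, and the loss must keep per-example values so the optimizer can split each minibatch into microbatches.

```python
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer

# Illustrative model; any tf.keras model is handled the same way.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1),
])

# The restored constructor takes the DP hyperparameters directly.
optimizer = DPKerasSGDOptimizer(
    l2_norm_clip=1.0,        # max L2 norm of each per-microbatch gradient
    noise_multiplier=1.1,    # noise stddev = l2_norm_clip * noise_multiplier
    num_microbatches=32,     # assumed to evenly divide the batch size
    learning_rate=0.1)

# Per-example losses (reduction=NONE) so the optimizer can form microbatch losses.
loss = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)
model.compile(optimizer=optimizer, loss=loss)
```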
21 changes: 3 additions & 18 deletions tensorflow_privacy/privacy/optimizers/BUILD
@@ -18,18 +18,6 @@ py_library(
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
)

py_library(
name = "dp_optimizer_factory",
srcs = [
"dp_optimizer_keras.py",
],
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/dp_query",
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
],
)

py_library(
name = "dp_optimizer_vectorized",
srcs = [
@@ -44,10 +32,7 @@ py_library(
"dp_optimizer_keras.py",
],
srcs_version = "PY3",
deps = [
"//tensorflow_privacy/privacy/dp_query",
"//tensorflow_privacy/privacy/dp_query:gaussian_query",
],
deps = ["//tensorflow_privacy/privacy/dp_query:gaussian_query"],
)

py_library(
@@ -99,7 +84,7 @@ py_test(
python_version = "PY3",
srcs_version = "PY3",
deps = [
":dp_optimizer_keras",
":dp_optimizer_keras_vectorized",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras",
"//tensorflow_privacy/privacy/optimizers:dp_optimizer_keras_vectorized",
],
)
224 changes: 73 additions & 151 deletions tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
@@ -13,28 +13,21 @@
# limitations under the License.
# ==============================================================================
"""Differentially private version of Keras optimizer v2."""
from typing import Optional, Type
import warnings

import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import gaussian_query


def _normalize(microbatch_gradient: tf.Tensor,
num_microbatches: float) -> tf.Tensor:
"""Normalizes `microbatch_gradient` by `num_microbatches`."""
return tf.truediv(microbatch_gradient,
tf.cast(num_microbatches, microbatch_gradient.dtype))
from tensorflow_privacy.privacy.dp_query import gaussian_query


def make_keras_generic_optimizer_class(
cls: Type[tf.keras.optimizers.Optimizer]):
"""Returns a differentially private (DP) subclass of `cls`.
def make_keras_optimizer_class(cls):
"""Given a subclass of `tf.keras.optimizers.Optimizer`, returns a DP-SGD subclass of it.

Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.Optimizer`.

Returns:
A DP-SGD subclass of `cls`.
"""

class DPOptimizerClass(cls): # pylint: disable=empty-docstring
@@ -145,23 +138,24 @@ class DPOptimizerClass(cls):  # pylint: disable=empty-docstring

def __init__(
self,
dp_sum_query: dp_query.DPQuery,
num_microbatches: Optional[int] = None,
gradient_accumulation_steps: int = 1,
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
gradient_accumulation_steps=1,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initializes the DPOptimizerClass.
"""Initialize the DPOptimizerClass.

Args:
dp_sum_query: `DPQuery` object, specifying differential privacy
mechanism to use.
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch is
split. Default is `None` which means that number of microbatches is
equal to batch size (i.e. each microbatch contains exactly one
split. Default is `None` which means that number of microbatches
is equal to batch size (i.e. each microbatch contains exactly one
example). If `gradient_accumulation_steps` is greater than 1 and
`num_microbatches` is not `None` then the effective number of
microbatches is equal to `num_microbatches *
gradient_accumulation_steps`.
microbatches is equal to
`num_microbatches * gradient_accumulation_steps`.
gradient_accumulation_steps: If greater than 1 then optimizer will be
accumulating gradients for this number of optimizer steps before
applying them to update model weights. If this argument is set to 1
@@ -171,13 +165,13 @@ def __init__(
"""
super().__init__(*args, **kwargs)
self.gradient_accumulation_steps = gradient_accumulation_steps
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
self._num_microbatches = num_microbatches
self._dp_sum_query = dp_sum_query
self._was_dp_gradients_called = False
# We initialize the self.`_global_state` within the gradient functions
# (and not here) because tensors must be initialized within the graph.

self._dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
self._global_state = None
self._was_dp_gradients_called = False

def _create_slots(self, var_list):
super()._create_slots(var_list) # pytype: disable=attribute-error
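As a quick numeric illustration of the restored `__init__` above (all values are hypothetical): the Gaussian noise added to each summed gradient has standard deviation `l2_norm_clip * noise_multiplier`, matching `GaussianSumQuery(l2_norm_clip, l2_norm_clip * noise_multiplier)`, and with gradient accumulation the effective number of microbatches is `num_microbatches * gradient_accumulation_steps`.

```python
# Hypothetical hyperparameters, chosen only to make the arithmetic concrete.
l2_norm_clip = 1.5
noise_multiplier = 0.8
num_microbatches = 16
gradient_accumulation_steps = 4

# Stddev of the noise added to the sum of clipped microbatch gradients.
noise_stddev = l2_norm_clip * noise_multiplier  # 1.2

# Effective microbatches per weight update when accumulating gradients.
effective_microbatches = num_microbatches * gradient_accumulation_steps  # 64
```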
@@ -241,62 +235,66 @@ def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
"""DP-SGD version of base class method."""

self._was_dp_gradients_called = True
if self._global_state is None:
self._global_state = self._dp_sum_query.initial_global_state()

# Compute loss.
if not callable(loss) and tape is None:
raise ValueError('`tape` is required when a `Tensor` loss is passed.')

tape = tape if tape is not None else tf.GradientTape()

with tape:
if callable(loss):
if callable(loss):
with tape:
if not callable(var_list):
tape.watch(var_list)

loss = loss()
if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [num_microbatches, -1]), axis=1)

if callable(var_list):
var_list = var_list()
if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [num_microbatches, -1]), axis=1)

if callable(var_list):
var_list = var_list()
else:
with tape:
if self._num_microbatches is None:
num_microbatches = tf.shape(input=loss)[0]
else:
num_microbatches = self._num_microbatches
microbatch_losses = tf.reduce_mean(
tf.reshape(loss, [num_microbatches, -1]), axis=1)

var_list = tf.nest.flatten(var_list)

sample_params = (
self._dp_sum_query.derive_sample_params(self._global_state))

# Compute the per-microbatch losses using helpful jacobian method.
with tf.keras.backend.name_scope(self._name + '/gradients'):
jacobian_per_var = tape.jacobian(
jacobian = tape.jacobian(
microbatch_losses, var_list, unconnected_gradients='zero')

def process_microbatch(sample_state, microbatch_jacobians):
"""Process one microbatch (record) with privacy helper."""
sample_state = self._dp_sum_query.accumulate_record(
sample_params, sample_state, microbatch_jacobians)
return sample_state
# Clip gradients to given l2_norm_clip.
def clip_gradients(g):
return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]

sample_state = self._dp_sum_query.initial_sample_state(var_list)
for idx in range(num_microbatches):
microbatch_jacobians_per_var = [
jacobian[idx] for jacobian in jacobian_per_var
]
sample_state = process_microbatch(sample_state,
microbatch_jacobians_per_var)
clipped_gradients = tf.map_fn(clip_gradients, jacobian)

grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))
final_grads = tf.nest.map_structure(_normalize, grad_sums,
[num_microbatches] * len(grad_sums))
def reduce_noise_normalize_batch(g):
# Sum gradients over all microbatches.
summed_gradient = tf.reduce_sum(g, axis=0)

return list(zip(final_grads, var_list))
# Add noise to summed gradients.
noise_stddev = self._l2_norm_clip * self._noise_multiplier
noise = tf.random.normal(
tf.shape(input=summed_gradient), stddev=noise_stddev)
noised_gradient = tf.add(summed_gradient, noise)

# Normalize by number of microbatches and return.
return tf.truediv(noised_gradient,
tf.cast(num_microbatches, tf.float32))

final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
clipped_gradients)

return list(zip(final_gradients, var_list))

def get_gradients(self, loss, params):
"""DP-SGD version of base class method."""
@@ -324,13 +322,17 @@ def process_microbatch(i, sample_state):
sample_state = self._dp_sum_query.initial_sample_state(params)
for idx in range(self._num_microbatches):
sample_state = process_microbatch(idx, sample_state)

grad_sums, self._global_state, _ = (
self._dp_sum_query.get_noised_result(sample_state,
self._global_state))

final_grads = tf.nest.map_structure(
_normalize, grad_sums, [self._num_microbatches] * len(grad_sums))
def normalize(v):
try:
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
except TypeError:
return None

final_grads = tf.nest.map_structure(normalize, grad_sums)

return final_grads

@@ -366,87 +368,7 @@ def apply_gradients(self, *args, **kwargs):
return DPOptimizerClass


def make_gaussian_query_optimizer_class(cls):
"""Returns a differentially private optimizer using the `GaussianSumQuery`.

Args:
cls: `DPOptimizerClass`, the output of `make_keras_optimizer_class`.
"""

def return_gaussian_query_optimizer(
l2_norm_clip: float,
noise_multiplier: float,
num_microbatches: Optional[int] = None,
gradient_accumulation_steps: int = 1,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Returns a `DPOptimizerClass` `cls` using the `GaussianSumQuery`.

This function is a thin wrapper around
`make_keras_optimizer_class.<locals>.DPOptimizerClass` which can be used to
apply a `GaussianSumQuery` to any `DPOptimizerClass`.

When combined with stochastic gradient descent, this creates the canonical
DP-SGD algorithm of "Deep Learning with Differential Privacy"
(see https://arxiv.org/abs/1607.00133).

Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
noise_multiplier: Ratio of the standard deviation to the clipping norm.
num_microbatches: Number of microbatches into which each minibatch is
split. Default is `None` which means that number of microbatches is
equal to batch size (i.e. each microbatch contains exactly one example).
If `gradient_accumulation_steps` is greater than 1 and
`num_microbatches` is not `None` then the effective number of
microbatches is equal to `num_microbatches *
gradient_accumulation_steps`.
gradient_accumulation_steps: If greater than 1 then optimizer will be
accumulating gradients for this number of optimizer steps before
applying them to update model weights. If this argument is set to 1 then
updates will be applied on each optimizer step.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
dp_sum_query = gaussian_query.GaussianSumQuery(
l2_norm_clip, l2_norm_clip * noise_multiplier)
return cls(
dp_sum_query=dp_sum_query,
num_microbatches=num_microbatches,
gradient_accumulation_steps=gradient_accumulation_steps,
*args,
**kwargs)

return return_gaussian_query_optimizer


def make_keras_optimizer_class(cls: Type[tf.keras.optimizers.Optimizer]):
"""Returns a differentially private optimizer using the `GaussianSumQuery`.

For backwards compatibility, we create this symbol to match the previous
output of `make_keras_optimizer_class` but using the new logic.

Args:
cls: Class from which to derive a DP subclass. Should be a subclass of
`tf.keras.optimizers.Optimizer`.
"""
warnings.warn(
'`make_keras_optimizer_class` will be deprecated on 2023-02-23. '
'Please switch to `make_gaussian_query_optimizer_class` and the '
'generic optimizers (`make_keras_generic_optimizer_class`).')
return make_gaussian_query_optimizer_class(
make_keras_generic_optimizer_class(cls))


GenericDPAdagradOptimizer = make_keras_generic_optimizer_class(
DPKerasAdagradOptimizer = make_keras_optimizer_class(
tf.keras.optimizers.Adagrad)
GenericDPAdamOptimizer = make_keras_generic_optimizer_class(
tf.keras.optimizers.Adam)
GenericDPSGDOptimizer = make_keras_generic_optimizer_class(
tf.keras.optimizers.SGD)

# We keep the same names for backwards compatibility.
DPKerasAdagradOptimizer = make_gaussian_query_optimizer_class(
GenericDPAdagradOptimizer)
DPKerasAdamOptimizer = make_gaussian_query_optimizer_class(
GenericDPAdamOptimizer)
DPKerasSGDOptimizer = make_gaussian_query_optimizer_class(GenericDPSGDOptimizer)
DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam)
DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD)
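For readers comparing the two APIs touched by this rollback, the sketch below contrasts how an Adam-based DP optimizer is constructed on each side of the change. Both signatures are taken from the diff above; the hyperparameter values are illustrative only.

```python
from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasAdamOptimizer

# After this rollback (restored API): DP hyperparameters are passed directly.
opt_restored = DPKerasAdamOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=32,
    learning_rate=0.001)

# Before this rollback (removed API): the optimizer was parameterized by a
# DPQuery, and DPKerasAdamOptimizer was built via
# make_gaussian_query_optimizer_class. Shown commented out for comparison only;
# these symbols no longer exist once this change is merged.
#
# from tensorflow_privacy.privacy.dp_query import gaussian_query
# from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import GenericDPAdamOptimizer
#
# dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 1.0 * 1.1)
# opt_removed = GenericDPAdamOptimizer(
#     dp_sum_query=dp_sum_query,
#     num_microbatches=32,
#     learning_rate=0.001)
```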