Sparsity Preserving DP-SGD in TF Privacy [5 of 5]
Integrate sparsity preserving noise into DP Keras Model with fast gradient clipping.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 648402434
tensorflower-gardener committed Aug 19, 2024
1 parent 93c7e54 commit b27a800
Showing 5 changed files with 85 additions and 15 deletions.
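
Before the per-file diff, a quick orientation on the algorithm being wired in: instead of densifying a sparse (e.g. embedding) gradient and noising every row, sparsity preserving DP-SGD privately selects which rows to keep, by adding Gaussian noise to per-row contribution counts and thresholding, and then adds gradient noise only to the selected rows. A minimal sketch of that idea in plain TensorFlow, with hypothetical names and none of the library's privacy accounting:

import tensorflow as tf

def toy_sparsity_preserving_noise(grad, row_counts, selection_stddev,
                                  threshold, grad_stddev):
  # `grad` is a tf.IndexedSlices; `row_counts[i]` is the (float) number of
  # contributions to row `grad.indices[i]`.
  # Private partition selection: noise the counts and keep rows that clear
  # the threshold. Rows never touched in the batch are never emitted, so the
  # gradient stays sparse.
  noised_counts = row_counts + tf.random.normal(
      tf.shape(row_counts), stddev=selection_stddev)
  keep = noised_counts > threshold
  indices = tf.boolean_mask(grad.indices, keep)
  values = tf.boolean_mask(grad.values, keep)
  # Gaussian noise only on the surviving rows, never on the full dense shape.
  values += tf.random.normal(tf.shape(values), stddev=grad_stddev)
  return tf.IndexedSlices(values, indices, grad.dense_shape)

A dense DP-SGD step would instead materialize a full vocabulary-sized noise tensor; avoiding that is the point of this change.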
4 changes: 2 additions & 2 deletions tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py
@@ -21,11 +21,11 @@
# Tensorflow aliases.
Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]

-PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
+PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]

InputTensors = PackedTensors

-OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
+OutputTensors = Union[Tensor, Iterable[Tensor]]

BatchSize = Union[int, tf.Tensor]

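The alias widening above matters because sparse lookups produce tf.IndexedSlices gradients rather than dense tensors, and those now need to flow through the packed-tensor types unmodified. A standalone illustration of where such gradients come from:

import tensorflow as tf

table = tf.Variable(tf.ones([1000, 16]))  # embedding-style weight
with tf.GradientTape() as tape:
  rows = tf.gather(table, [3, 7, 7])  # sparse row lookup
  loss = tf.reduce_sum(rows)
grad = tape.gradient(loss, table)
print(type(grad).__name__)  # IndexedSlices: gradient rows only for looked-up ids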
2 changes: 2 additions & 0 deletions tensorflow_privacy/privacy/keras_models/BUILD
@@ -18,6 +18,8 @@ py_library(
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
"//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
],
)

79 changes: 76 additions & 3 deletions tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -13,16 +13,37 @@
# limitations under the License.
"""Keras Model for vectorized dpsgd with XLA acceleration."""

+import dataclasses
+
+from absl import logging
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils


_PRIVATIZED_LOSS_NAME = 'privatized_loss'


+@dataclasses.dataclass
+class SparsityPreservingDPSGDConfig:
+"""Config for adding sparsity preserving noise to the gradients."""
+
+# The ratio of how the noise is split between partition selection and gradient
+# noise.
+sparse_selection_ratio: float = 0.0
+# The threshold to use for private partition selection.
+sparse_selection_threshold: int = 100
+# A `LayerRegistry` instance containing functions that help compute
+# contribution counts for sparse layers. See
+# `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
+# more details.
+sparse_selection_layer_registry: snlr.LayerRegistry | None = None


def make_dp_model_class(cls):
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""

@@ -104,6 +125,9 @@ def __init__(
num_microbatches=None,
use_xla=True,
layer_registry=None,
+sparsity_preserving_dpsgd_config: (
+SparsityPreservingDPSGDConfig | None
+) = None,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs,
):
@@ -118,6 +142,9 @@ def __init__(
help compute gradient norms quickly. See
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
more details.
+sparsity_preserving_dpsgd_config: If provided, uses partition selection
+and sparse noise for privatizing sparse gradients for layers in
+`sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__` method.
"""
@@ -127,6 +154,8 @@ def __init__(
self._layer_registry = layer_registry
self._clipping_loss = None

+self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config

# Given that `num_microbatches` was added as an argument after the fact,
# this check helps detect unintended calls to the earlier API.
# In particular, boolean values supplied to `use_xla` in the earlier API
@@ -276,11 +305,16 @@ def train_step(self, data):
# microbatches is done here.
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)

+sparse_noise_layer_registry = None
+if self._sparsity_preserving_dpsgd_config is not None:
+sparse_noise_layer_registry = (
+self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
+)
registry_generator_fn = (
gradient_clipping_utils.get_registry_generator_fn(
tape=tape,
layer_registry=self._layer_registry,
-sparse_noise_layer_registry=None,
+sparse_noise_layer_registry=sparse_noise_layer_registry,
num_microbatches=num_microbatches,
)
)
@@ -310,14 +344,53 @@
)
)
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
-if self._noise_multiplier > 0:
+noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
+contribution_counts = None
+if self._sparsity_preserving_dpsgd_config is not None:
+logging.info('Using sparse noise.')
+
+varname_to_contribution_counts_fns = (
+sparse_noise_utils.extract_varname_to_contribution_counts_fns(
+registry_fn_outputs_list,
+self.trainable_variables,
+)
+)
+contribution_counts = sparse_noise_utils.get_contribution_counts(
+self.trainable_variables,
+clipped_grads,
+varname_to_contribution_counts_fns,
+)
+
+noise_multiplier_sparse, noise_multiplier = (
+sparse_noise_utils.split_noise_multiplier(
+noise_multiplier,
+self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
+contribution_counts,
+)
+)
+logging.info(
+'Split noise multiplier for gradient noise: %s and partition'
+' selection: %s',
+noise_multiplier,
+noise_multiplier_sparse,
+)
+
+if noise_multiplier > 0:
+sparse_noise_config = None
+if self._sparsity_preserving_dpsgd_config is not None:
+sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
+sparse_noise_multiplier=noise_multiplier_sparse,
+sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
+sparse_contribution_counts=contribution_counts,
+)
grads = noise_utils.add_aggregate_noise(
clipped_grads,
num_microbatches,
self._l2_norm_clip,
-self._noise_multiplier,
+noise_multiplier,
loss_reduction=None,
loss_model=self,
+sparse_noise_config=sparse_noise_config,
)
else:
grads = clipped_grads
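
Putting the dp_keras_model.py changes together, enabling the new path would look roughly like the sketch below. DPSequential and make_default_layer_registry are pre-existing TF Privacy APIs, but how the sparse-noise LayerRegistry is meant to be populated is not shown in this diff, so the bare LayerRegistry() here is a placeholder assumption:

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.keras_models import dp_keras_model
from tensorflow_privacy.privacy.sparsity_preserving_noise import (
    layer_registry as sparse_registry,
)

# Route 10% of the noise budget to private partition selection; the sparse
# registry tells the model which layers get sparsity preserving noise.
sparse_config = dp_keras_model.SparsityPreservingDPSGDConfig(
    sparse_selection_ratio=0.1,
    sparse_selection_threshold=100,
    sparse_selection_layer_registry=sparse_registry.LayerRegistry(),  # placeholder
)

model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    layer_registry=layer_registry.make_default_layer_registry(),
    sparsity_preserving_dpsgd_config=sparse_config,
    layers=[
        tf.keras.layers.Embedding(10_000, 16),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(1),
    ],
)
model.compile(optimizer='sgd',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

Per the train_step changes above, the total noise multiplier is then split between partition selection and gradient noise according to sparse_selection_ratio, and add_aggregate_noise applies sparse noise to registered layers and dense Gaussian noise to everything else.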
13 changes: 4 additions & 9 deletions tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py
@@ -166,7 +166,7 @@ def sample_true_positive_indices(
tf.shape(contribution_count_values),
mean=0.0,
stddev=noise_multiplier,
-dtype=tf.float32,
+dtype=contribution_count_values.dtype,
)
)
noised_contribution_counts_indices = contribution_counts.indices[
@@ -281,7 +281,7 @@ def add_sparse_gradient_noise(
"""
filtered_grad_values = tf.gather(grad, indices)
sparse_noise_values = tf.random.normal(
-filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
+tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
)
filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
return tf.IndexedSlices(
@@ -362,15 +362,10 @@ def get_contribution_counts(
if var.name not in varname_to_contribution_counts_fns:
contribution_counts_list.append(None)
continue
-contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
-if not contribution_counts_fns or not contribution_counts_fns[0]:
+contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
+if not contribution_counts_fn:
contribution_counts_list.append(None)
continue
-if len(contribution_counts_fns) > 1:
-raise NotImplementedError(
-'Sparse noise is not supported for shared weight variables.'
-)
-contribution_counts_fn = contribution_counts_fns[0]
contribution_counts = contribution_counts_fn(grad)
contribution_counts_list.append(contribution_counts)

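Two of the fixes in this file swap static for dynamic shapes: inside traced (tf.function) code, tensor.shape can be partially unknown, and tf.random.normal rejects a partially known shape, while tf.shape() yields the concrete shape at run time. The dtype fix is the same kind of hardening, keeping the selection noise in the dtype of the counts it perturbs. A standalone illustration of the shape issue:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def add_noise(values):
  # At trace time values.shape is (None, 4), so tf.random.normal(values.shape)
  # would fail; tf.shape(values) defers to the concrete run-time shape.
  return values + tf.random.normal(tf.shape(values), mean=0.0, stddev=1.0)

print(add_noise(tf.ones([3, 4])).shape)  # (3, 4)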
2 changes: 1 addition & 1 deletion tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py
@@ -19,7 +19,7 @@

InputArgs = Sequence[Any]
InputKwargs = Mapping[str, Any]
-SparseGradient = tf.IndexedSlices
+SparseGradient = tf.IndexedSlices | tf.SparseTensor
ContributionCountHistogram = tf.SparseTensor
ContributionCountHistogramFn = Callable[
[SparseGradient], ContributionCountHistogram
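The widened SparseGradient alias uses PEP 604 union syntax (Python 3.10+) so contribution-count functions can accept either sparse representation. A minimal sketch of code written against the alias; touched_rows is illustrative, not a library function:

import tensorflow as tf

SparseGradient = tf.IndexedSlices | tf.SparseTensor

def touched_rows(grad: SparseGradient) -> tf.Tensor:
  # Row ids a sparse gradient touches, for either representation.
  if isinstance(grad, tf.IndexedSlices):
    return tf.cast(grad.indices, tf.int64)
  rows, _ = tf.unique(grad.indices[:, 0])  # SparseTensor indices are [row, col]
  return rows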
