Sparsity Preserving DP-SGD in TF Privacy
Add function to merge varname_to_contribution_count_fn maps from different layers.

See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 660525767
tensorflower-gardener committed Aug 19, 2024
1 parent 38d80ca commit 26bb0a6
Showing 4 changed files with 153 additions and 7 deletions.
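
For orientation, a minimal sketch of how the new helper is intended to be wired in. `model`, `grads`, and `registry_fn_outputs_list` are hypothetical stand-ins (the docstring below states the latter comes from `gradient_clipping_utils.model_forward_backward_pass`); this is an illustration, not code from this commit.

# Hypothetical wiring; assumes grads aligns one-to-one with
# model.trainable_variables and registry_fn_outputs_list was produced by
# gradient_clipping_utils.model_forward_backward_pass.
varname_to_fns = sparse_noise_utils.extract_varname_to_contribution_counts_fns(
    registry_fn_outputs_list,
    trainable_vars=model.trainable_variables,
)
contribution_counts = sparse_noise_utils.get_contribution_counts(
    model.trainable_variables,
    grads,
    varname_to_fns,
)
# Each non-None entry is a contribution count histogram (tf.SparseTensor) that
# downstream sparsity-preserving noise addition can consume.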
9 changes: 8 additions & 1 deletion tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD
@@ -5,12 +5,19 @@ licenses(["notice"])
py_library(
    name = "sparse_noise_utils",
    srcs = ["sparse_noise_utils.py"],
    deps = [
        ":type_aliases",
        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
    ],
)

py_test(
    name = "sparse_noise_utils_test",
    srcs = ["sparse_noise_utils_test.py"],
-    deps = [":sparse_noise_utils"],
+    deps = [
+        ":sparse_noise_utils",
+        "//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
+    ],
)

py_library(
tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py
@@ -16,10 +16,13 @@
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
"""

import collections
from typing import Mapping, Optional, Sequence

from scipy import stats
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
import tensorflow_probability as tfp


@@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
)


def extract_varname_to_contribution_counts_fns(
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    trainable_vars: Optional[Sequence[tf.Variable]],
) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
  """Extracts a map of contribution count fns from generator outputs.

  Args:
    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
      instances returned by
      `gradient_clipping_utils.model_forward_backward_pass`.
    trainable_vars: A list of trainable variables; if None, functions from all
      layers are included.

  Returns:
    A `dict` from varname to contribution count functions.
  """
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_vars = set([v.ref() for v in trainable_vars])

  varname_to_contribution_counts_fns = collections.defaultdict(list)
  for registry_fn_output in registry_fn_outputs_list:
    if trainable_vars is None or any(
        w.ref() in trainable_vars
        for w in registry_fn_output.layer_trainable_weights
    ):
      if registry_fn_output.varname_to_count_contribution_fn is not None:
        duplicate_varnames = set(
            registry_fn_output.varname_to_count_contribution_fn.keys()
        ) & set(varname_to_contribution_counts_fns.keys())
        if duplicate_varnames:
          raise ValueError(
              f'Duplicate varnames: {duplicate_varnames} found in contribution'
              ' counts functions.'
          )
        varname_to_contribution_counts_fns.update(
            registry_fn_output.varname_to_count_contribution_fn
        )
  return varname_to_contribution_counts_fns


def get_contribution_counts(
-    trainable_vars: list[tf.Variable],
-    grads: list[tf.Tensor],
-    varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
-) -> list[tf.Tensor | None]:
+    trainable_vars: Sequence[tf.Variable],
+    grads: Sequence[tf.Tensor],
+    varname_to_contribution_counts_fns: Mapping[
+        str, type_aliases.ContributionCountHistogramFn
+    ],
+) -> Sequence[type_aliases.ContributionCountHistogram | None]:
  """Gets the contribution counts for each variable in the Model.

  Args:
-    trainable_vars: A list of the trainable variables in the Model.
+    trainable_vars: A list of trainable variables.
    grads: A corresponding list of gradients for each trainable variable.
    varname_to_contribution_counts_fns: A mapping from variable name to a list
      of functions to get the contribution counts for that variable.
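
The body of `get_contribution_counts` is truncated in the hunk above. A minimal sketch of the per-variable dispatch its signature and docstring imply, assuming lookup by `tf.Variable.name`; this is an illustration, not the committed implementation.

# Illustrative sketch only; the committed body is truncated in the diff above.
def _get_contribution_counts_sketch(trainable_vars, grads, varname_to_fns):
  contribution_counts = []
  for var, grad in zip(trainable_vars, grads):
    fn = varname_to_fns.get(var.name)
    # Variables without a registered contribution count fn map to None.
    contribution_counts.append(None if fn is None else fn(grad))
  return contribution_counts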
tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py
@@ -17,6 +17,7 @@
import numpy as np
from scipy import stats
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils


@@ -436,6 +437,96 @@ def test_add_sparse_noise_with_noise(self):
        np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
    )

  def test_extract_varname_to_contribution_counts_fns(self):
    def fn1(_):
      return 1.0

    def fn2(_):
      return 2.0

    var1 = tf.Variable(tf.ones((1, 2)), name='var1')
    var2 = tf.Variable(tf.ones((1, 2)), name='var2')
    var3 = tf.Variable(tf.ones((1, 2)), name='var3')

    registry_fn_outputs_list = [
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id='layer1',
            layer_vars=[var1],
            layer_sqr_norm_fn=None,
            layer_trainable_weights=[var1],
            varname_to_count_contribution_fn=None,
        ),
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id='layer2',
            layer_vars=[var2],
            layer_sqr_norm_fn=None,
            layer_trainable_weights=[var2],
            varname_to_count_contribution_fn={
                'var2:0': [fn2],
            },
        ),
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id='layer3',
            layer_vars=[var3],
            layer_sqr_norm_fn=None,
            layer_trainable_weights=[var3],
            varname_to_count_contribution_fn={
                'var3:0': [fn1],
            },
        ),
    ]
    expected_varname_to_contribution_counts_fns = {
        'var2:0': [fn2],
        'var3:0': [fn1],
    }
    varname_to_contribution_counts_fns = (
        sparse_noise_utils.extract_varname_to_contribution_counts_fns(
            registry_fn_outputs_list,
            trainable_vars=None,
        )
    )
    self.assertEqual(
        varname_to_contribution_counts_fns,
        expected_varname_to_contribution_counts_fns,
    )

  def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
    def fn1(_):
      return 1.0

    def fn2(_):
      return 2.0

    var1 = tf.Variable(tf.ones((1, 2)), name='var1')
    var2 = tf.Variable(tf.ones((1, 2)), name='var1')

    registry_fn_outputs_list = [
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id='layer1',
            layer_vars=[var1],
            layer_sqr_norm_fn=None,
            layer_trainable_weights=[var1],
            varname_to_count_contribution_fn={
                'var1:0': [fn1],
            },
        ),
        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
            layer_id='layer2',
            layer_vars=[var2],
            layer_sqr_norm_fn=None,
            layer_trainable_weights=[var2],
            varname_to_count_contribution_fn={
                'var1:0': [fn2],
            },
        ),
    ]

    with self.assertRaises(ValueError):
      sparse_noise_utils.extract_varname_to_contribution_counts_fns(
          registry_fn_outputs_list,
          trainable_vars=None,
      )


if __name__ == '__main__':
  tf.test.main()
tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py
@@ -22,7 +22,7 @@
SparseGradient = tf.IndexedSlices
ContributionCountHistogram = tf.SparseTensor
ContributionCountHistogramFn = Callable[
-    [SparseGradient], Mapping[str, ContributionCountHistogram]
+    [SparseGradient], ContributionCountHistogram
]
NumMicrobatches = int | tf.Tensor
SparsityPreservingNoiseLayerRegistryFunction = Callable[
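
The change above narrows `ContributionCountHistogramFn` to return a single `ContributionCountHistogram` (a `tf.SparseTensor`) instead of a `Mapping[str, ...]`. A hedged example of a function matching the new alias, counting how often each row index appears in a sparse gradient; the counting semantics here are illustrative, not from this commit.

import tensorflow as tf

# Hypothetical ContributionCountHistogramFn: SparseGradient -> SparseTensor.
def example_count_fn(grad: tf.IndexedSlices) -> tf.SparseTensor:
  row_ids, _, counts = tf.unique_with_counts(tf.cast(grad.indices, tf.int64))
  histogram = tf.SparseTensor(
      indices=tf.expand_dims(row_ids, axis=1),
      values=tf.cast(counts, tf.float32),
      dense_shape=tf.cast(grad.dense_shape[:1], tf.int64),
  )
  # unique_with_counts returns first-occurrence order; reorder to the
  # canonical row-major ordering SparseTensor consumers expect.
  return tf.sparse.reorder(histogram)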
