Skip to content

Commit 93c7e54

Browse files
Sparsity Preserving DP-SGD in TF Privacy
Add function to merge varname_to_contribution_count_fn maps from different layers. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 664906202
1 parent 38d80ca commit 93c7e54

File tree

4 files changed

+153
-7
lines changed

4 files changed

+153
-7
lines changed

tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@ licenses(["notice"])
55
py_library(
66
name = "sparse_noise_utils",
77
srcs = ["sparse_noise_utils.py"],
8+
deps = [
9+
":type_aliases",
10+
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
11+
],
812
)
913

1014
py_test(
1115
name = "sparse_noise_utils_test",
1216
srcs = ["sparse_noise_utils_test.py"],
13-
deps = [":sparse_noise_utils"],
17+
deps = [
18+
":sparse_noise_utils",
19+
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
20+
],
1421
)
1522

1623
py_library(

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
1717
"""
1818

19+
import collections
1920
from typing import Mapping, Optional, Sequence
2021

2122
from scipy import stats
2223
import tensorflow as tf
24+
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
25+
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
2326
import tensorflow_probability as tfp
2427

2528

@@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
288291
)
289292

290293

294+
def extract_varname_to_contribution_counts_fns(
    registry_fn_outputs_list: Sequence[
        gradient_clipping_utils.RegistryGeneratorFunctionOutput
    ],
    trainable_vars: Optional[Sequence[tf.Variable]],
) -> Mapping[str, type_aliases.ContributionCountHistogramFn]:
  """Merges the per-layer contribution count fn maps into a single map.

  Args:
    registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
      instances returned by
      `gradient_clipping_utils.model_forward_backward_pass`.
    trainable_vars: The trainable variables used to filter the outputs. An
      output is only considered if at least one entry of its
      `layer_trainable_weights` is in this list. If `None`, no filtering is
      applied and every output is considered.

  Returns:
    A `dict` (a `collections.defaultdict(list)`) from varname to contribution
    counts functions.

  Raises:
    ValueError: If the same varname appears in the
      `varname_to_count_contribution_fn` map of more than one considered
      output.
  """
  trainable_var_refs = None
  if trainable_vars is not None:
    # Create a set using `ref()` for fast set membership check. tf.Variable
    # itself is not hashable.
    trainable_var_refs = {v.ref() for v in trainable_vars}

  varname_to_contribution_counts_fns = collections.defaultdict(list)
  for registry_fn_output in registry_fn_outputs_list:
    # Skip outputs that share no weight with the requested trainable variables.
    if trainable_var_refs is not None and not any(
        w.ref() in trainable_var_refs
        for w in registry_fn_output.layer_trainable_weights
    ):
      continue

    layer_fns = registry_fn_output.varname_to_count_contribution_fn
    if layer_fns is None:
      # This output does not provide any contribution count functions.
      continue

    # A varname may be provided by at most one output; merging silently would
    # drop the functions registered earlier for that varname.
    duplicate_varnames = set(layer_fns) & set(
        varname_to_contribution_counts_fns
    )
    if duplicate_varnames:
      raise ValueError(
          f'Duplicate varnames: {sorted(duplicate_varnames)} found in'
          ' contribution counts functions.'
      )
    varname_to_contribution_counts_fns.update(layer_fns)
  return varname_to_contribution_counts_fns
335+
336+
291337
def get_contribution_counts(
292-
trainable_vars: list[tf.Variable],
293-
grads: list[tf.Tensor],
294-
varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
295-
) -> list[tf.Tensor | None]:
338+
trainable_vars: Sequence[tf.Variable],
339+
grads: Sequence[tf.Tensor],
340+
varname_to_contribution_counts_fns: Mapping[
341+
str, type_aliases.ContributionCountHistogramFn
342+
],
343+
) -> Sequence[type_aliases.ContributionCountHistogram | None]:
296344
"""Gets the contribution counts for each variable in the Model.
297345
298346
Args:
299-
trainable_vars: A list of the trainable variables in the Model.
347+
trainable_vars: A list of trainable variables.
300348
grads: A corresponding list of gradients for each trainable variable.
301349
varname_to_contribution_counts_fns: A mapping from variable name to a list
302350
of functions to get the contribution counts for that variable.

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
from scipy import stats
1919
import tensorflow as tf
20+
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
2021
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
2122

2223

@@ -436,6 +437,96 @@ def test_add_sparse_noise_with_noise(self):
436437
np.all(np.not_equal(noised_grad_valid_indices, grad.values.numpy()))
437438
)
438439

440+
def test_extract_varname_to_contribution_counts_fns(self):
  """Checks that the per-layer contribution count fn maps are merged."""

  def count_fn_a(_):
    return 1.0

  def count_fn_b(_):
    return 2.0

  first_var, second_var, third_var = [
      tf.Variable(tf.ones((1, 2)), name=var_name)
      for var_name in ('var1', 'var2', 'var3')
  ]

  # (layer id, layer variable, contribution count fn map) for each layer. The
  # first layer provides no contribution count functions at all.
  layer_specs = (
      ('layer1', first_var, None),
      ('layer2', second_var, {'var2:0': [count_fn_b]}),
      ('layer3', third_var, {'var3:0': [count_fn_a]}),
  )
  generator_outputs = [
      gradient_clipping_utils.RegistryGeneratorFunctionOutput(
          layer_id=layer_id,
          layer_vars=[layer_var],
          layer_sqr_norm_fn=None,
          layer_trainable_weights=[layer_var],
          varname_to_count_contribution_fn=layer_fns,
      )
      for layer_id, layer_var, layer_fns in layer_specs
  ]

  merged_fns = sparse_noise_utils.extract_varname_to_contribution_counts_fns(
      generator_outputs,
      trainable_vars=None,
  )

  self.assertEqual(
      merged_fns,
      {
          'var2:0': [count_fn_b],
          'var3:0': [count_fn_a],
      },
  )
492+
493+
def test_extract_varname_to_contribution_counts_fns_duplicate_varnames(self):
  """Checks that a varname provided by two layers raises a `ValueError`."""

  def count_fn_a(_):
    return 1.0

  def count_fn_b(_):
    return 2.0

  # Both variables deliberately share the name 'var1'.
  first_var = tf.Variable(tf.ones((1, 2)), name='var1')
  second_var = tf.Variable(tf.ones((1, 2)), name='var1')

  # Both layers register their functions under the same key 'var1:0'.
  layer_specs = (
      ('layer1', first_var, count_fn_a),
      ('layer2', second_var, count_fn_b),
  )
  generator_outputs = [
      gradient_clipping_utils.RegistryGeneratorFunctionOutput(
          layer_id=layer_id,
          layer_vars=[layer_var],
          layer_sqr_norm_fn=None,
          layer_trainable_weights=[layer_var],
          varname_to_count_contribution_fn={'var1:0': [count_fn]},
      )
      for layer_id, layer_var, count_fn in layer_specs
  ]

  with self.assertRaises(ValueError):
    sparse_noise_utils.extract_varname_to_contribution_counts_fns(
        generator_outputs,
        trainable_vars=None,
    )
529+
439530

440531
if __name__ == '__main__':
441532
tf.test.main()

tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
SparseGradient = tf.IndexedSlices
2323
ContributionCountHistogram = tf.SparseTensor
2424
ContributionCountHistogramFn = Callable[
25-
[SparseGradient], Mapping[str, ContributionCountHistogram]
25+
[SparseGradient], ContributionCountHistogram
2626
]
2727
NumMicrobatches = int | tf.Tensor
2828
SparsityPreservingNoiseLayerRegistryFunction = Callable[

0 commit comments

Comments
 (0)