Skip to content

Commit 71f6c77

Browse files
Sparsity Preserving DP-SGD in TF Privacy
Add function to merge varname_to_contribution_count_fn maps from different layers. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm. PiperOrigin-RevId: 660525767
1 parent e42b574 commit 71f6c77

File tree

10 files changed

+273
-8
lines changed

10 files changed

+273
-8
lines changed

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ py_library(
4646
":common_manip_utils",
4747
":layer_registry",
4848
":type_aliases",
49+
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
50+
"//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases",
4951
],
5052
)
5153

@@ -55,7 +57,11 @@ py_test(
5557
python_version = "PY3",
5658
shard_count = 8,
5759
srcs_version = "PY3",
58-
deps = [":gradient_clipping_utils"],
60+
deps = [
61+
":gradient_clipping_utils",
62+
":layer_registry",
63+
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
64+
],
5965
)
6066

6167
py_library(

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def compute_gradient_norms(
164164
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
165165
tape=tape,
166166
layer_registry=layer_registry,
167+
sparse_noise_layer_registry=None,
167168
num_microbatches=num_microbatches,
168169
)
169170
layer_grad_vars, generator_outputs_list = (

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _run_model_forward_backward_pass(
132132
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
133133
tape=tape,
134134
layer_registry=layer_registry.make_default_layer_registry(),
135+
sparse_noise_layer_registry=None,
135136
num_microbatches=None,
136137
)
137138
layer_grad_vars, registry_fn_outputs_list = (

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,18 @@
2222
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
2323
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
2424
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
25+
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
26+
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases
2527

2628

2729
@dataclasses.dataclass(frozen=True)
2830
class RegistryGeneratorFunctionOutput:
2931
layer_id: str
3032
layer_vars: Optional[Sequence[tf.Variable]]
3133
layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
34+
varname_to_count_contribution_fn: Optional[
35+
dict[str, sn_type_aliases.ContributionCountHistogramFn]
36+
]
3237
layer_trainable_weights: Optional[Sequence[tf.Variable]]
3338

3439

@@ -46,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
4651
def get_registry_generator_fn(
4752
tape: tf.GradientTape,
4853
layer_registry: lr.LayerRegistry,
54+
sparse_noise_layer_registry: snlr.LayerRegistry,
4955
num_microbatches: Optional[type_aliases.BatchSize] = None,
5056
) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
5157
"""Creates the generator function for `model_forward_backward_pass()`.
@@ -58,6 +64,10 @@ def get_registry_generator_fn(
5864
`output` is the pre-activator tensor, `sqr_grad_norms` is related to the
5965
squared norms of a layer's pre-activation tensor, and `vars` are relevant
6066
trainable variables.
67+
sparse_noise_layer_registry: A `LayerRegistry` instance containing functions
68+
that help compute contribution counts for sparse noise. See
69+
`tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
70+
more details.
6171
num_microbatches: An optional number or scalar `tf.Tensor` for the number of
6272
microbatches. If not None, indicates that the loss is grouped into
6373
num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -83,6 +93,16 @@ def registry_generator_fn(layer_instance, args, kwargs):
8393
'be used for efficient gradient clipping.'
8494
% layer_instance.__class__.__name__
8595
)
96+
varname_to_count_contribution_fn = None
97+
if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
98+
layer_instance
99+
):
100+
count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
101+
layer_instance
102+
)
103+
varname_to_count_contribution_fn = count_contribution_registry_fn(
104+
layer_instance, args, kwargs, num_microbatches
105+
)
86106
registry_fn = layer_registry.lookup(layer_instance)
87107
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
88108
layer_instance, args, kwargs, tape, num_microbatches
@@ -91,6 +111,7 @@ def registry_generator_fn(layer_instance, args, kwargs):
91111
layer_id=str(id(layer_instance)),
92112
layer_vars=layer_vars,
93113
layer_sqr_norm_fn=layer_sqr_norm_fn,
114+
varname_to_count_contribution_fn=varname_to_count_contribution_fn,
94115
layer_trainable_weights=layer_instance.trainable_weights,
95116
)
96117
else:

tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from absl.testing import parameterized
1818
import tensorflow as tf
1919
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
20+
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
21+
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
2022

2123

2224
# ==============================================================================
@@ -175,5 +177,92 @@ def test_new_custom_layer_spec(self):
175177
)
176178

177179

180+
class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):
181+
182+
def _get_sparse_layer_registry(self):
183+
def count_contribution_fn(_):
184+
return None
185+
186+
def registry_fn(*_):
187+
return {'var': count_contribution_fn}
188+
189+
registry = snlr.LayerRegistry()
190+
registry.insert(tf.keras.layers.Embedding, registry_fn)
191+
return registry, count_contribution_fn
192+
193+
def _get_layer_registry(self):
194+
var = tf.Variable(1.0)
195+
output = tf.ones((1, 1))
196+
197+
def sqr_norm_fn(_):
198+
return None
199+
200+
def registry_fn(*_):
201+
return [var], output, sqr_norm_fn
202+
203+
registry = lr.LayerRegistry()
204+
registry.insert(tf.keras.layers.Embedding, registry_fn)
205+
registry.insert(tf.keras.layers.Dense, registry_fn)
206+
return registry, var, output, sqr_norm_fn
207+
208+
def test_registry_generator_fn(self):
209+
inputs = tf.constant([[0, 1]])
210+
model = tf.keras.Sequential([
211+
tf.keras.layers.Embedding(10, 1),
212+
tf.keras.layers.Dense(1),
213+
])
214+
215+
sparse_layer_registry, count_contribution_fn = (
216+
self._get_sparse_layer_registry()
217+
)
218+
layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
219+
registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
220+
tape=tf.GradientTape(),
221+
layer_registry=layer_registry,
222+
sparse_noise_layer_registry=sparse_layer_registry,
223+
num_microbatches=None,
224+
)
225+
embedding_layer = model.layers[0]
226+
out, embedding_registry_generator_fn_output = registry_generator_fn(
227+
embedding_layer,
228+
[inputs],
229+
{},
230+
)
231+
expected_embedding_registry_generator_fn_output = (
232+
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
233+
layer_id=str(id(embedding_layer)),
234+
layer_vars=[var],
235+
layer_sqr_norm_fn=sqr_norm_fn,
236+
varname_to_count_contribution_fn={'var': count_contribution_fn},
237+
layer_trainable_weights=embedding_layer.trainable_weights,
238+
)
239+
)
240+
self.assertEqual(
241+
embedding_registry_generator_fn_output,
242+
expected_embedding_registry_generator_fn_output,
243+
)
244+
self.assertEqual(out, output)
245+
dense_layer = model.layers[1]
246+
out, dense_registry_generator_fn_output = registry_generator_fn(
247+
dense_layer,
248+
[inputs],
249+
{},
250+
)
251+
expected_dense_registry_generator_fn_output = (
252+
gradient_clipping_utils.RegistryGeneratorFunctionOutput(
253+
layer_id=str(id(dense_layer)),
254+
layer_vars=[var],
255+
layer_sqr_norm_fn=sqr_norm_fn,
256+
varname_to_count_contribution_fn=None,
257+
layer_trainable_weights=dense_layer.trainable_weights,
258+
)
259+
)
260+
self.assertEqual(
261+
dense_registry_generator_fn_output,
262+
expected_dense_registry_generator_fn_output,
263+
)
264+
self.assertEqual(out, output)
265+
266+
178267
if __name__ == '__main__':
179268
tf.test.main()

tensorflow_privacy/privacy/keras_models/dp_keras_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def train_step(self, data):
280280
gradient_clipping_utils.get_registry_generator_fn(
281281
tape=tape,
282282
layer_registry=self._layer_registry,
283+
sparse_noise_layer_registry=None,
283284
num_microbatches=num_microbatches,
284285
)
285286
)

tensorflow_privacy/privacy/sparsity_preserving_noise/BUILD

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@ licenses(["notice"])
55
py_library(
66
name = "sparse_noise_utils",
77
srcs = ["sparse_noise_utils.py"],
8+
deps = [
9+
":type_aliases",
10+
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
11+
],
812
)
913

1014
py_test(
1115
name = "sparse_noise_utils_test",
1216
srcs = ["sparse_noise_utils_test.py"],
13-
deps = [":sparse_noise_utils"],
17+
deps = [
18+
":sparse_noise_utils",
19+
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
20+
],
1421
)
1522

1623
py_library(

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
For more details on the algorithm, refer to https://arxiv.org/abs/2311.08357.
1717
"""
1818

19+
import collections
1920
from typing import Mapping, Optional, Sequence
2021

2122
from scipy import stats
2223
import tensorflow as tf
24+
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
25+
from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases
2326
import tensorflow_probability as tfp
2427

2528

@@ -288,15 +291,60 @@ def add_sparse_gradient_noise(
288291
)
289292

290293

294+
def extract_varname_to_contribution_counts_fns(
295+
registry_fn_outputs_list: Sequence[
296+
gradient_clipping_utils.RegistryGeneratorFunctionOutput
297+
],
298+
trainable_vars: Sequence[tf.Variable],
299+
) -> dict[str, type_aliases.ContributionCountHistogramFn]:
300+
"""Extracts a map of contribution count fns from generator outputs.
301+
302+
Args:
303+
registry_fn_outputs_list: A list of `RegistryGeneratorFunctionOutput`
304+
instances returned by
305+
`gradient_clipping_utils.model_forward_backward_pass`.
306+
trainable_vars: A list of trainable variables.
307+
308+
Returns:
309+
A `dict` from varname to contribution counts functions
310+
"""
311+
if trainable_vars is not None:
312+
# Create a set using `ref()` for fast set membership check. tf.Variable
313+
# itself is not hashable.
314+
trainable_vars = set([v.ref() for v in trainable_vars])
315+
316+
varname_to_contribution_counts_fns = collections.defaultdict(list)
317+
for registry_fn_output in registry_fn_outputs_list:
318+
if trainable_vars is None or any(
319+
w.ref() in trainable_vars
320+
for w in registry_fn_output.layer_trainable_weights
321+
):
322+
if registry_fn_output.varname_to_count_contribution_fn is not None:
323+
duplicate_varnames = set(
324+
registry_fn_output.varname_to_count_contribution_fn.keys()
325+
) & set(varname_to_contribution_counts_fns.keys())
326+
if duplicate_varnames:
327+
raise ValueError(
328+
'Duplicate varnames: {duplicate_varnames} found in contribution'
329+
' counts functions.'
330+
)
331+
varname_to_contribution_counts_fns.update(
332+
registry_fn_output.varname_to_count_contribution_fn
333+
)
334+
return varname_to_contribution_counts_fns
335+
336+
291337
def get_contribution_counts(
292-
trainable_vars: list[tf.Variable],
293-
grads: list[tf.Tensor],
294-
varname_to_contribution_counts_fns: Mapping[str, tf.SparseTensor],
295-
) -> list[tf.Tensor | None]:
338+
trainable_vars: Sequence[tf.Variable],
339+
grads: Sequence[tf.Tensor],
340+
varname_to_contribution_counts_fns: Mapping[
341+
str, type_aliases.ContributionCountHistogramFn
342+
],
343+
) -> list[type_aliases.ContributionCountHistogram | None]:
296344
"""Gets the contribution counts for each variable in the Model.
297345
298346
Args:
299-
trainable_vars: A list of the trainable variables in the Model.
347+
trainable_vars: A list of trainable variables.
300348
grads: A corresponding list of gradients for each trainable variable.
301349
varname_to_contribution_counts_fns: A mapping from variable name to a list
302350
of functions to get the contribution counts for that variable.

0 commit comments

Comments
 (0)