Skip to content

Commit

Permalink
Update differential_privacy module to convert DP events to and from…
Browse files Browse the repository at this point in the history
… `NamedTuple`s.

PiperOrigin-RevId: 525517433
  • Loading branch information
michaelreneer authored and tensorflower-gardener committed Apr 19, 2023
1 parent e362f51 commit 809d1f9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 39 deletions.
13 changes: 6 additions & 7 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible.
"""

import abc
import collections
from typing import Any, Callable, Collection, Optional, Tuple, Union
from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union

import attr
import tensorflow as tf

# TODO(b/192464750): find a proper place for the helper functions, privatize
Expand Down Expand Up @@ -170,8 +170,7 @@ def next(self, state):
return self.value_fn(), state


@attr.s(eq=False, frozen=True, slots=True)
class TreeState(object):
class TreeState(NamedTuple):
"""Class defining state of the tree.
Attributes:
Expand All @@ -183,9 +182,9 @@ class TreeState(object):
for the most recent leaf node.
value_generator_state: State of a stateful `ValueGenerator` for tree node.
"""
level_buffer = attr.ib(type=tf.Tensor)
level_buffer_idx = attr.ib(type=tf.Tensor)
value_generator_state = attr.ib(type=Any)
level_buffer: tf.Tensor
level_buffer_idx: tf.Tensor
value_generator_state: Any


# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
Expand Down
64 changes: 38 additions & 26 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
"""

import attr
from typing import Any, NamedTuple

import dp_accounting
import tensorflow as tf

from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation

Expand Down Expand Up @@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP.
"""

@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for Tree sum queries.
Attributes:
Expand All @@ -94,9 +95,9 @@ class GlobalState(object):
clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
"""
tree_state = attr.ib()
clip_value = attr.ib()
samples_cumulative_sum = attr.ib()
tree_state: Any
clip_value: Any
samples_cumulative_sum: Any

def __init__(self,
record_specs,
Expand Down Expand Up @@ -182,10 +183,11 @@ def get_noised_result(self, sample_state, global_state):
global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
new_global_state = attr.evolve(
global_state,
new_global_state = TreeCumulativeSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state)
)
event = dp_accounting.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event

Expand All @@ -206,10 +208,11 @@ def reset_state(self, noised_results, global_state):
state for the next cumulative sum.
"""
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
return TreeCumulativeSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=noised_results,
tree_state=new_tree_state)
)

@classmethod
def build_l2_gaussian_query(cls,
Expand Down Expand Up @@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP.
"""

@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for Tree sum queries.
Attributes:
Expand All @@ -323,9 +325,9 @@ class GlobalState(object):
previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample.
"""
tree_state = attr.ib()
clip_value = attr.ib()
previous_tree_noise = attr.ib()
tree_state: Any
clip_value: Any
previous_tree_noise: Any

def __init__(self,
record_specs,
Expand Down Expand Up @@ -426,8 +428,11 @@ def get_noised_result(self, sample_state, global_state):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise,
global_state.previous_tree_noise)
new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
new_global_state = TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=tree_noise,
)
event = dp_accounting.UnsupportedDpEvent()
return noised_sample, new_global_state, event

Expand All @@ -448,21 +453,28 @@ def reset_state(self, noised_results, global_state):
"""
del noised_results
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
return TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=self._zero_initial_noise(),
tree_state=new_tree_state)
)

def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
noise_generator_state = global_state.tree_state.value_generator_state
assert isinstance(self._tree_aggregator.value_generator,
tree_aggregation.GaussianNoiseGenerator)
noise_generator_state = self._tree_aggregator.value_generator.make_state(
noise_generator_state.seeds, stddev)
new_tree_state = attr.evolve(
global_state.tree_state, value_generator_state=noise_generator_state)
return attr.evolve(
global_state, clip_value=clip_norm, tree_state=new_tree_state)
new_tree_state = tree_aggregation.TreeState(
level_buffer=global_state.tree_state.level_buffer,
level_buffer_idx=global_state.tree_state.level_buffer_idx,
value_generator_state=noise_generator_state,
)
return TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=clip_norm,
previous_tree_noise=global_state.previous_tree_noise,
)

@classmethod
def build_l2_gaussian_query(cls,
Expand Down
10 changes: 4 additions & 6 deletions tensorflow_privacy/privacy/dp_query/tree_range_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@

import distutils
import math
from typing import Optional
from typing import Any, NamedTuple, Optional

import attr
import dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
Expand Down Expand Up @@ -102,17 +101,16 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
Improves efficiency and reduces noise scale.
"""

@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for TreeRangeSumQuery.
Attributes:
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
inner_query_state: The global state of the inner query.
"""
arity = attr.ib()
inner_query_state = attr.ib()
arity: Any
inner_query_state: Any

def __init__(self,
inner_query: dp_query.SumAggregationDPQuery,
Expand Down

0 comments on commit 809d1f9

Please sign in to comment.