diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index 79bc243e..5717e4fc 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -25,10 +25,10 @@ """ import distutils import math -import attr +from typing import Optional +import attr import tensorflow as tf - from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import tree_aggregation @@ -442,16 +442,20 @@ def _loop_body(i, h): return tree -def _get_add_noise(stddev): +def _get_add_noise(stddev, seed: int = None): """Utility function to decide which `add_noise` to use according to tf version.""" if distutils.version.LooseVersion( tf.__version__) < distutils.version.LooseVersion('2.0.0'): + # The seed should be only used for testing purpose. + if seed is not None: + tf.random.set_seed(seed) + def add_noise(v): return v + tf.random.normal( tf.shape(input=v), stddev=stddev, dtype=v.dtype) else: - random_normal = tf.random_normal_initializer(stddev=stddev) + random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed) def add_noise(v): return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype) @@ -478,17 +482,16 @@ class GlobalState(object): """Class defining global state for `CentralTreeSumQuery`. Attributes: - stddev: The stddev of the noise added to each node in the tree. - arity: The branching factor of the tree (i.e. the number of children each - internal node has). l1_bound: An upper bound on the L1 norm of the input record. This is needed to bound the sensitivity and deploy differential privacy. """ - stddev = attr.ib() - arity = attr.ib() l1_bound = attr.ib() - def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10): + def __init__(self, + stddev: float, + arity: int = 2, + l1_bound: int = 10, + seed: Optional[int] = None): """Initializes the `CentralTreeSumQuery`. Args: @@ -497,15 +500,17 @@ def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10): arity: The branching factor of the tree. l1_bound: An upper bound on the L1 norm of the input record. This is needed to bound the sensitivity and deploy differential privacy. + seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for + test purpose. """ self._stddev = stddev self._arity = arity self._l1_bound = l1_bound + self._seed = seed def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" - return CentralTreeSumQuery.GlobalState( - stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound) + return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound) def derive_sample_params(self, global_state): """Implements `tensorflow_privacy.DPQuery.derive_sample_params`.""" @@ -536,10 +541,9 @@ def get_noised_result(self, sample_state, global_state): The jth node on the ith layer of the tree can be accessed by tree[i][j] where tree is the returned value. """ - add_noise = _get_add_noise(self._stddev) - tree = _build_tree_from_leaf(sample_state, global_state.arity) - return tf.nest.map_structure( - add_noise, tree, expand_composites=True), global_state + add_noise = _get_add_noise(self._stddev, self._seed) + tree = _build_tree_from_leaf(sample_state, self._arity) + return tf.map_fn(add_noise, tree), global_state class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery): @@ -577,7 +581,11 @@ class GlobalState(object): arity = attr.ib() l1_bound = attr.ib() - def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10): + def __init__(self, + stddev: float, + arity: int = 2, + l1_bound: int = 10, + seed: Optional[int] = None): """Initializes the `DistributedTreeSumQuery`. Args: @@ -585,10 +593,13 @@ def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10): arity: The branching factor of the tree. l1_bound: An upper bound on the L1 norm of the input record. This is needed to bound the sensitivity and deploy differential privacy. + seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for + test purpose. """ self._stddev = stddev self._arity = arity self._l1_bound = l1_bound + self._seed = seed def initial_global_state(self): """Implements `tensorflow_privacy.DPQuery.initial_global_state`.""" @@ -628,9 +639,9 @@ def preprocess_record(self, params, record): use_norm=l1_norm) preprocessed_record = preprocessed_record[0] - add_noise = _get_add_noise(self._stddev) + add_noise = _get_add_noise(self._stddev, self._seed) tree = _build_tree_from_leaf(preprocessed_record, arity) - noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True) + noisy_tree = tf.map_fn(add_noise, tree) # The following codes reshape the output vector so the output shape of can # be statically inferred. This is useful when used with diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py index 34f2c9ce..cc3a89aa 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py @@ -502,21 +502,15 @@ def test_get_noised_result(self, arity, record, expected_tree): ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), ) def test_get_noised_result_with_noise(self, stddev, record, expected_tree): - query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev) + query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) preprocessed_record = query.preprocess_record(params, record) - sample_state_list = [] - for _ in range(1000): - sample_state, _ = query.get_noised_result(preprocessed_record, - global_state) - sample_state_list.append(sample_state.flat_values.numpy()) - expectation = np.mean(sample_state_list, axis=0) - variance = np.std(sample_state_list, axis=0) - - self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4) + + sample_state, _ = query.get_noised_result(preprocessed_record, global_state) + self.assertAllClose( - variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4) + sample_state.flat_values, expected_tree, atol=3 * stddev) @parameterized.named_parameters( ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32), @@ -556,8 +550,7 @@ def test_initial_global_state_type(self): def test_derive_sample_params(self): query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD) global_state = query.initial_global_state() - stddev, arity, l1_bound = query.derive_sample_params( - global_state) + stddev, arity, l1_bound = query.derive_sample_params(global_state) self.assertAllClose(stddev, NOISE_STD) self.assertAllClose(arity, 2) self.assertAllClose(l1_bound, 10) @@ -587,21 +580,14 @@ def test_preprocess_record(self, arity, record, expected_tree): ('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]), ) def test_preprocess_record_with_noise(self, stddev, record, expected_tree): - query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev) + query = tree_aggregation_query.DistributedTreeSumQuery( + stddev=stddev, seed=0) global_state = query.initial_global_state() params = query.derive_sample_params(global_state) - preprocessed_record_list = [] - for _ in range(1000): - preprocessed_record = query.preprocess_record(params, record) - preprocessed_record_list.append(preprocessed_record.numpy()) - - expectation = np.mean(preprocessed_record_list, axis=0) - variance = np.std(preprocessed_record_list, axis=0) + preprocessed_record = query.preprocess_record(params, record) - self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4) - self.assertAllClose( - variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4) + self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev) @parameterized.named_parameters( ('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),