diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py index 0af7b36a2..436dcff3a 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation.py @@ -21,11 +21,11 @@ as helper functions for `tree_aggregation_query`. This module and helper functions are publicly accessible. """ + import abc import collections -from typing import Any, Callable, Collection, Optional, Tuple, Union +from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union -import attr import tensorflow as tf # TODO(b/192464750): find a proper place for the helper functions, privatize @@ -170,8 +170,7 @@ def next(self, state): return self.value_fn(), state -@attr.s(eq=False, frozen=True, slots=True) -class TreeState(object): +class TreeState(NamedTuple): """Class defining state of the tree. Attributes: @@ -183,9 +182,9 @@ class TreeState(object): for the most recent leaf node. value_generator_state: State of a stateful `ValueGenerator` for tree node. """ - level_buffer = attr.ib(type=tf.Tensor) - level_buffer_idx = attr.ib(type=tf.Tensor) - value_generator_state = attr.ib(type=Any) + level_buffer: tf.Tensor + level_buffer_idx: tf.Tensor + value_generator_state: Any # TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`. diff --git a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py index b73adc780..65bec4d4d 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py @@ -33,9 +33,11 @@ eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0] """ -import attr +from typing import Any, NamedTuple + import dp_accounting import tensorflow as tf + from tensorflow_privacy.privacy.dp_query import dp_query from tensorflow_privacy.privacy.dp_query import tree_aggregation @@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery): O(clip_norm*log(T)/eps) to guarantee eps-DP. """ - @attr.s(frozen=True) - class GlobalState(object): + class GlobalState(NamedTuple): """Class defining global state for Tree sum queries. Attributes: @@ -94,9 +95,9 @@ class GlobalState(object): clip_value: The clipping value to be passed to clip_fn. samples_cumulative_sum: Noiseless cumulative sum of samples over time. """ - tree_state = attr.ib() - clip_value = attr.ib() - samples_cumulative_sum = attr.ib() + tree_state: Any + clip_value: Any + samples_cumulative_sum: Any def __init__(self, record_specs, @@ -182,10 +183,11 @@ def get_noised_result(self, sample_state, global_state): global_state.tree_state) noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum, cumulative_sum_noise) - new_global_state = attr.evolve( - global_state, + new_global_state = TreeCumulativeSumQuery.GlobalState( + tree_state=new_tree_state, + clip_value=global_state.clip_value, samples_cumulative_sum=new_cumulative_sum, - tree_state=new_tree_state) + ) event = dp_accounting.UnsupportedDpEvent() return noised_cumulative_sum, new_global_state, event @@ -206,10 +208,11 @@ def reset_state(self, noised_results, global_state): state for the next cumulative sum. """ new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state) - return attr.evolve( - global_state, + return TreeCumulativeSumQuery.GlobalState( + tree_state=new_tree_state, + clip_value=global_state.clip_value, samples_cumulative_sum=noised_results, - tree_state=new_tree_state) + ) @classmethod def build_l2_gaussian_query(cls, @@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery): O(clip_norm*log(T)/eps) to guarantee eps-DP. """ - @attr.s(frozen=True) - class GlobalState(object): + class GlobalState(NamedTuple): """Class defining global state for Tree sum queries. Attributes: @@ -323,9 +325,9 @@ class GlobalState(object): previous_tree_noise: Cumulative noise by tree aggregation from the previous time the query is called on a sample. """ - tree_state = attr.ib() - clip_value = attr.ib() - previous_tree_noise = attr.ib() + tree_state: Any + clip_value: Any + previous_tree_noise: Any def __init__(self, record_specs, @@ -426,8 +428,11 @@ def get_noised_result(self, sample_state, global_state): noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c, sample_state, tree_noise, global_state.previous_tree_noise) - new_global_state = attr.evolve( - global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state) + new_global_state = TreeResidualSumQuery.GlobalState( + tree_state=new_tree_state, + clip_value=global_state.clip_value, + previous_tree_noise=tree_noise, + ) event = dp_accounting.UnsupportedDpEvent() return noised_sample, new_global_state, event @@ -448,10 +453,11 @@ def reset_state(self, noised_results, global_state): """ del noised_results new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state) - return attr.evolve( - global_state, + return TreeResidualSumQuery.GlobalState( + tree_state=new_tree_state, + clip_value=global_state.clip_value, previous_tree_noise=self._zero_initial_noise(), - tree_state=new_tree_state) + ) def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev): noise_generator_state = global_state.tree_state.value_generator_state @@ -459,10 +465,16 @@ def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev): tree_aggregation.GaussianNoiseGenerator) noise_generator_state = self._tree_aggregator.value_generator.make_state( noise_generator_state.seeds, stddev) - new_tree_state = attr.evolve( - global_state.tree_state, value_generator_state=noise_generator_state) - return attr.evolve( - global_state, clip_value=clip_norm, tree_state=new_tree_state) + new_tree_state = tree_aggregation.TreeState( + level_buffer=global_state.tree_state.level_buffer, + level_buffer_idx=global_state.tree_state.level_buffer_idx, + value_generator_state=noise_generator_state, + ) + return TreeResidualSumQuery.GlobalState( + tree_state=new_tree_state, + clip_value=clip_norm, + previous_tree_noise=global_state.previous_tree_noise, + ) @classmethod def build_l2_gaussian_query(cls, diff --git a/tensorflow_privacy/privacy/dp_query/tree_range_query.py b/tensorflow_privacy/privacy/dp_query/tree_range_query.py index d86cc89f4..9cfee022c 100644 --- a/tensorflow_privacy/privacy/dp_query/tree_range_query.py +++ b/tensorflow_privacy/privacy/dp_query/tree_range_query.py @@ -18,9 +18,8 @@ import distutils import math -from typing import Optional +from typing import Any, NamedTuple, Optional -import attr import dp_accounting import tensorflow as tf from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query @@ -102,8 +101,7 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery): Improves efficiency and reduces noise scale. """ - @attr.s(frozen=True) - class GlobalState(object): + class GlobalState(NamedTuple): """Class defining global state for TreeRangeSumQuery. Attributes: @@ -111,8 +109,8 @@ class GlobalState(object): internal node has). inner_query_state: The global state of the inner query. """ - arity = attr.ib() - inner_query_state = attr.ib() + arity: Any + inner_query_state: Any def __init__(self, inner_query: dp_query.SumAggregationDPQuery,