Skip to content

Commit

Permalink
Add DPEvent to the state in DifferentiallyPrivateFactory's next_fn. T…
Browse files Browse the repository at this point in the history
…his requires implementing initial_sample_state in TreeRangeSumQuery for tests to pass.

PiperOrigin-RevId: 479436434
  • Loading branch information
tensorflower-gardener committed Oct 11, 2022
1 parent 0738d6f commit 2a94663
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
9 changes: 5 additions & 4 deletions tensorflow_privacy/privacy/dp_query/dp_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,11 @@ def derive_metrics(self, global_state):

def _zeros_like(arg):
"""A `zeros_like` function that also works for `tf.TensorSpec`s."""
try:
arg = tf.convert_to_tensor(value=arg)
except TypeError:
pass
if not isinstance(arg, tf.TensorSpec):
try:
arg = tf.convert_to_tensor(value=arg)
except TypeError:
pass
return tf.zeros(arg.shape, arg.dtype)


Expand Down
8 changes: 7 additions & 1 deletion tensorflow_privacy/privacy/dp_query/tree_range_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import distutils
import math
from typing import Optional
from typing import Any, Optional

import attr
import dp_accounting
Expand Down Expand Up @@ -136,6 +136,12 @@ def initial_global_state(self):
arity=self._arity,
inner_query_state=self._inner_query.initial_global_state())

def initial_sample_state(self, template: Optional[Any] = None):
"""Implements `tensorflow_privacy.DPQuery.initial_sample_state`."""
return self.preprocess_record(
self.derive_sample_params(self.initial_global_state()),
super().initial_sample_state(template))

def derive_sample_params(self, global_state):
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
return (global_state.arity,
Expand Down

0 comments on commit 2a94663

Please sign in to comment.