diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py index 23de260b..7d7ba18f 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py @@ -22,7 +22,9 @@ PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]] -GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]] +GeneratorFunction = Callable[[Any, Tuple, Dict], Tuple[Any, Any]] + +LayerFunction = Callable[[tf.keras.layers.Layer], None] def has_internal_compute_graph(input_object: Any): @@ -52,7 +54,7 @@ def _get_internal_layers( def model_forward_pass( input_model: tf.keras.Model, inputs: PackedTensors, - generator_fn: GeneratorFunction = None, + generator_fn: Optional[GeneratorFunction] = None, ) -> Tuple[PackedTensors, List[Any]]: """Does a forward pass of a model and returns useful intermediates. @@ -211,6 +213,55 @@ def add_noise(g): return tf.nest.map_structure(add_noise, clipped_grads) +def depth_first_backward_pass( + outputs: PackedTensors, layer_function: Optional[LayerFunction] = None +): + """Performs a depth-first traversal on a given set of model outputs. + + This function is simplified version of + `tf.keras.engine.functional._build_map()` that allows additional side-effects + performed by an (optional) layer function. + + NOTE: The behavior, name, and implementation details of this function may + change in future versions. Users should avoid using it outside of this module. + + Args: + outputs: A `PackedTensor` that should be generated by calling a + `tf.keras.Model` on a set of non-eager inputs. + layer_function: A callable that consumes a `tf.keras.layers.Layer`. This + callable is applied to every layer in the DAG that generates `outputs`. + """ + + # Helper function that performs the traversal. + finished_nodes = set() + nodes_in_progress = set() + + def graph_crawler(tensor: tf.Tensor): + layer, node_index, _ = tensor._keras_history # pylint: disable=protected-access + node = layer._inbound_nodes[node_index] # pylint: disable=protected-access + # Avoid duplicating work on shared subgraphs. + if node in finished_nodes: + return + # Check if we encountered a cycle. + if node in nodes_in_progress: + raise ValueError( + f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.' + ) + # Apply side-effects and go to the next node (pre-order traversal). + if layer_function is not None: + layer_function(layer) + nodes_in_progress.add(node) + if not node.is_input: + for tensor in node.keras_inputs: + graph_crawler(tensor) + finished_nodes.add(node) + nodes_in_progress.remove(node) + + # Traverse over the outputs. + for output in tf.nest.flatten(outputs): + graph_crawler(output) + + def generate_model_outputs_using_core_keras_layers( input_model: tf.keras.Model, ) -> PackedTensors: diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py index f3c84b43..c4d1930b 100644 --- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py +++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py @@ -75,5 +75,65 @@ def test_outputs_are_consistent( self.assertAllClose(computed_outputs, true_outputs) +class DepthFirstBackwardPassTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.product( + depth=[1, 2], + input_packing_type=[None, tuple, list, dict], + output_packing_type=[None, tuple, list, dict], + ) + def test_layer_function(self, depth, input_packing_type, output_packing_type): + num_dims = 3 + num_units = 5 + num_inputs = 1 if input_packing_type is None else 2 + num_outputs = 1 if output_packing_type is None else 2 + sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)] + temp_sum = tf.stack(sample_inputs, axis=0) + sample_sum = [ + tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs) + ] + sample_outputs = sample_sum + for _ in range(depth): + sample_outputs = [ + tf.keras.layers.Dense(num_units)(t) for t in sample_outputs + ] + + # Pack inputs. + if input_packing_type is None: + inputs = sample_inputs[0] + elif input_packing_type is not dict: + inputs = input_packing_type(sample_inputs) + else: + inputs = {} + keys = [str(i) for i in range(len(sample_inputs))] + for k, v in zip(keys, sample_inputs): + inputs[k] = v + + # Pack outputs. + if output_packing_type is None: + outputs = sample_outputs[0] + elif output_packing_type is not dict: + outputs = output_packing_type(sample_outputs) + else: + outputs = {} + keys = [str(i) for i in range(len(sample_outputs))] + for k, v in zip(keys, sample_outputs): + outputs[k] = v + + # Append the trainable layers into a list. + layer_list = [] + + def layer_function(layer): + if layer.trainable_variables: + layer_list.append(layer) + + # Run the traversal and verify the outputs that are relevant to + # the above layer function. + gradient_clipping_utils.depth_first_backward_pass(outputs, layer_function) + self.assertLen(layer_list, num_outputs * depth) + for l in layer_list: + self.assertIsInstance(l, tf.keras.layers.Dense) + + if __name__ == '__main__': tf.test.main()