Skip to content

Commit

Permalink
Add some Tensorflow graph traversal utility functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 517108819
  • Loading branch information
tensorflower-gardener committed Mar 16, 2023
1 parent 043e8b5 commit c7f75eb
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 2 deletions.
14 changes: 14 additions & 0 deletions tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")

package(default_visibility = ["//visibility:public"])

py_library(
name = "tensorflow_graph_utils",
srcs = ["tensorflow_graph_utils.py"],
srcs_version = "PY3",
)

py_test(
name = "tensorflow_graph_utils_test",
srcs = ["tensorflow_graph_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":tensorflow_graph_utils"],
)

py_library(
name = "gradient_clipping_utils",
srcs = ["gradient_clipping_utils.py"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

InputTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]

GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
GeneratorFunction = Callable[[Any, Tuple, Dict], Tuple[Any, Any]]


def has_internal_compute_graph(input_object: Any):
Expand Down Expand Up @@ -52,7 +52,7 @@ def _get_internal_layers(
def model_forward_pass(
input_model: tf.keras.Model,
inputs: InputTensor,
generator_fn: GeneratorFunction = None,
generator_fn: Optional[GeneratorFunction] = None,
) -> Tuple[tf.Tensor, List[Any]]:
"""Does a forward pass of a model and returns useful intermediates.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2022, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions that help in traversing Tensorflow graphs."""

from typing import Any, Callable, Dict, Iterable, Optional, Set, Text, Union

import tensorflow as tf

PackedTensor = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]

LayerFunction = Callable[[tf.keras.layers.Layer], None]


def depth_first_backward_pass(
outputs: PackedTensor, layer_function: Optional[LayerFunction] = None
):
"""Performs a depth-first traversal on a given set of model outputs.
This function is simplified version of
`tf.keras.engine.functional._build_map()` that allows additional side-effects
performed by an (optional) layer function.
Args:
outputs: A `PackedTensor` that should be generated by calling a
`tf.keras.Model` on a set of non-eager inputs.
layer_function: A callable that consumes a `tf.keras.layers.Layer`. This
callable is applied to every layer in the DAG that generates `outputs`.
"""

# Helper function that performs the traversal.
def graph_crawler(
tensor: tf.Tensor, finished_nodes: Set[Any], nodes_in_progress: Set[Any]
):
layer, node_index, _ = tensor._keras_history # pylint: disable=protected-access
node = layer._inbound_nodes[node_index] # pylint: disable=protected-access
# Avoid duplicating work on shared subgraphs.
if node in finished_nodes:
return
# Check if we encountered a cycle.
if node in nodes_in_progress:
raise ValueError(
f'Tensor {tensor} from layer "{layer.name}" is part of a cycle.'
)
# Apply side-effects and go to the next node (pre-order traversal).
if layer_function is not None:
layer_function(layer)
nodes_in_progress.add(node)
if not node.is_input:
for tensor in node.keras_inputs:
graph_crawler(tensor, finished_nodes, nodes_in_progress)
finished_nodes.add(node)
nodes_in_progress.remove(node)

# Traverse over the outputs.
finished_nodes = set()
nodes_in_progress = set()
for output in tf.nest.flatten(outputs):
graph_crawler(output, finished_nodes, nodes_in_progress)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import tensorflow as tf

from tensorflow_privacy.privacy.fast_gradient_clipping import tensorflow_graph_utils


# ==============================================================================
# Main tests.
# ==============================================================================
class DepthFirstBackwardPassTest(tf.test.TestCase, parameterized.TestCase):

@parameterized.product(
input_packing_type=[None, tuple, list, dict],
output_packing_type=[None, tuple, list, dict],
)
def test_layer_function(self, input_packing_type, output_packing_type):
num_dims = 3
num_inputs = 1 if input_packing_type is None else 2
num_outputs = 1 if output_packing_type is None else 2
sample_inputs = [tf.keras.Input((num_dims,)) for i in range(num_inputs)]
temp_sum = tf.stack(sample_inputs, axis=0)
sample_sum = [
tf.multiply(temp_sum, float(i + 1.0)) for i in range(num_outputs)
]
sample_outputs = [tf.keras.layers.Dense(3)(t) for t in sample_sum]

# Pack inputs.
if input_packing_type is None:
inputs = sample_inputs[0]
elif input_packing_type is not dict:
inputs = input_packing_type(sample_inputs)
else:
inputs = {}
keys = [str(i) for i in range(len(sample_inputs))]
for k, v in zip(keys, sample_inputs):
inputs[k] = v

# Pack outputs.
if output_packing_type is None:
outputs = sample_outputs[0]
elif output_packing_type is not dict:
outputs = output_packing_type(sample_outputs)
else:
outputs = {}
keys = [str(i) for i in range(len(sample_outputs))]
for k, v in zip(keys, sample_outputs):
outputs[k] = v

# Append the trainable layers into a list.
layer_list = []

def layer_function(layer):
if layer.trainable_variables:
layer_list.append(layer)

# Run the traversal and verify the outputs that are relevant to
# the above layer function.
tensorflow_graph_utils.depth_first_backward_pass(outputs, layer_function)
self.assertLen(layer_list, num_outputs)
for l in layer_list:
self.assertIsInstance(l, tf.keras.layers.Dense)


if __name__ == '__main__':
tf.test.main()

0 comments on commit c7f75eb

Please sign in to comment.