-`**kwargs`
+**kwargs
|
Any other option for GATv2Conv, except sender_node_feature,
diff --git a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2HomGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2HomGraphUpdate.md
index cbc66046..acc03341 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2HomGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2HomGraphUpdate.md
@@ -1,17 +1,10 @@
# gat_v2.GATv2HomGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a GraphUpdate layer with a Graph Attention Network V2 (GATv2).
@@ -40,21 +33,26 @@ objects instead, such as the GATv2MPNNGraphUpdate.
> explicitly stored in the input GraphTensor. Attention of a node to itself
> requires having an explicit loop in the edge set.
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
+
Args |
-`num_heads`
+num_heads
|
The number of attention heads.
|
-`per_head_channels`
+per_head_channels
|
The number of channels for each attention head. This
@@ -62,22 +60,22 @@ means that the final output size will be per_head_channels * num_heads.
|
-`receiver_tag`
+receiver_tag
|
-one of `tfgnn.SOURCE` or `tfgnn.TARGET`.
+one of tfgnn.SOURCE or tfgnn.TARGET .
|
-`feature_name`
+feature_name
|
The feature name of node states; defaults to
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`heads_merge_type`
+heads_merge_type
|
"concat" or "mean". Gets passed to GATv2Conv, which uses
@@ -85,14 +83,14 @@ it to combine all heads into layer's output.
|
-`name`
+name
|
Optionally, a name for the layer returned.
|
-`**kwargs`
+**kwargs
|
Any optional arguments to GATv2Conv, see there.
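For orientation, a minimal usage sketch of the layer documented above. A scalar GraphTensor `graph` with default hidden states is assumed to exist; the hyperparameter values are illustrative only:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import gat_v2

# Hypothetical homogeneous graph: node states stored under tfgnn.HIDDEN_STATE;
# 4 heads of 16 channels give a 64-dim output state per node.
layer = gat_v2.GATv2HomGraphUpdate(
    num_heads=4, per_head_channels=16, receiver_tag=tfgnn.TARGET)
# graph = layer(graph)  # returns a GraphTensor with updated node states
```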
diff --git a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2MPNNGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2MPNNGraphUpdate.md
index 0b67e958..ad5a214e 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2MPNNGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/GATv2MPNNGraphUpdate.md
@@ -1,17 +1,10 @@
# gat_v2.GATv2MPNNGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a GraphUpdate layer for message passing with GATv2 pooling.
@@ -44,37 +37,42 @@ the messages and their pooling with attention, followed by a dense layer to
compute the new node states from a concatenation of the old node state and all
pooled messages.
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
+
Args |
-`units`
+units
|
The dimension of output hidden states for each node.
|
-`message_dim`
+message_dim
|
The dimension of messages (attention values) computed on
-each edge. Must be divisible by `num_heads`.
+each edge. Must be divisible by num_heads .
|
-`num_heads`
+num_heads
|
-The number of attention heads used by GATv2. `message_dim`
+The number of attention heads used by GATv2. message_dim
must be divisible by this number.
|
-`heads_merge_type`
+heads_merge_type
|
"concat" or "mean". Gets passed to GATv2Conv, which uses
@@ -82,15 +80,15 @@ it to combine all heads into layer's output.
|
-`receiver_tag`
+receiver_tag
|
-one of `tfgnn.TARGET` or `tfgnn.SOURCE`, to select the
+one of tfgnn.TARGET or tfgnn.SOURCE , to select the
incident node of each edge that receives the message.
|
-`node_set_names`
+node_set_names
|
The names of node sets to update. If unset, updates all
@@ -98,16 +96,16 @@ that are on the receiving end of any edge set.
|
-`edge_feature`
+edge_feature
|
Can be set to a feature name of the edge set to select
-it as an input feature. By default, this set to `None`, which disables
+it as an input feature. By default, this is set to None , which disables
this input.
|
-`l2_regularization`
+l2_regularization
|
The coefficient of L2 regularization for weights and
@@ -115,7 +113,7 @@ biases.
|
-`edge_dropout_rate`
+edge_dropout_rate
|
The edge dropout rate applied during attention pooling
@@ -123,26 +121,26 @@ of edges.
|
-`state_dropout_rate`
+state_dropout_rate
|
The dropout rate applied to the resulting node states.
|
-`attention_activation`
+attention_activation
|
The nonlinearity used on the transformed inputs
before multiplying with the trained weights of the attention layer.
This can be specified as a Keras layer, a tf.keras.activations.*
-function, or a string understood by `tf.keras.layers.Activation()`.
+function, or a string understood by tf.keras.layers.Activation() .
Defaults to "leaky_relu", which in turn defaults to a negative slope
-of `alpha=0.2`.
+of alpha=0.2 .
|
-`conv_activation`
+conv_activation
|
The nonlinearity applied to the result of attention on one
@@ -150,7 +148,7 @@ edge set, specified in the same ways as attention_activation.
|
-`activation`
+activation
|
The nonlinearity applied to the new node states computed by
@@ -158,23 +156,24 @@ this graph update.
|
-`kernel_initializer`
+kernel_initializer
|
-Can be set to a `kernel_initializer` as understood
-by `tf.keras.layers.Dense` etc.
+Can be set to a kernel_initializer as understood
+by tf.keras.layers.Dense etc.
|
+
Returns |
A GraphUpdate layer for use on a scalar GraphTensor with
-`tfgnn.HIDDEN_STATE` features on the node sets.
+tfgnn.HIDDEN_STATE features on the node sets.
|
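As a quick sketch of the returned layer in use (the graph and the hyperparameter values are hypothetical; note that `message_dim` must be divisible by `num_heads`, as stated above):

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import gat_v2

# message_dim=32 is divisible by num_heads=4, as required above.
graph_update = gat_v2.GATv2MPNNGraphUpdate(
    units=64, message_dim=32, num_heads=4, receiver_tag=tfgnn.TARGET)
# graph = graph_update(graph)  # new tfgnn.HIDDEN_STATE on receiving node sets
```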
diff --git a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_from_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_from_config_dict.md
index 858724c5..ef2f3c62 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_from_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_from_config_dict.md
@@ -1,23 +1,18 @@
# gat_v2.graph_update_from_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a GATv2MPNNGraphUpdate initialized from `cfg`.
gat_v2.graph_update_from_config_dict(
- cfg: config_dict.ConfigDict
+ cfg: config_dict.ConfigDict,
+ *,
+ node_set_names: Optional[Collection[tfgnn.NodeSetName]] = None
) -> tf.keras.layers.Layer
@@ -30,15 +25,23 @@ Returns a GATv2MPNNGraphUpdate initialized from `cfg`.
-`cfg`
+cfg
|
-A `ConfigDict` with the fields defined by
-`graph_update_get_config_dict()`. All fields with non-`None` values are
+A ConfigDict with the fields defined by
+graph_update_get_config_dict() . All fields with non-None values are
used as keyword arguments for initializing and returning a
-`GATv2MPNNGraphUpdate` object. For the required arguments of
-`GATv2MPNNGraphUpdate.__init__`, users must set a value in
-`cfg` before passing it here.
+GATv2MPNNGraphUpdate object. For the required arguments of
+GATv2MPNNGraphUpdate.__init__ , users must set a value in
+cfg before passing it here.
+ |
+
+
+node_set_names
+ |
+
+Optionally, the names of NodeSets to update; forwarded to
+GATv2MPNNGraphUpdate.__init__ .
|
@@ -50,7 +53,7 @@ used as keyword arguments for initializing and returning a
| Returns |
-A new `GATv2MPNNGraphUpdate` object.
+A new GATv2MPNNGraphUpdate object.
|
@@ -64,11 +67,11 @@ A new `GATv2MPNNGraphUpdate` object.
-`TypeError`
+TypeError
|
-if `cfg` fails to supply a required argument for
-`GATv2MPNNGraphUpdate.__init__`.
+if cfg fails to supply a required argument for
+GATv2MPNNGraphUpdate.__init__ .
|
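A minimal sketch of the intended flow, assuming the config fields mirror the required `GATv2MPNNGraphUpdate.__init__` arguments as described above; the numeric values and the node set name "paper" are illustrative:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import gat_v2

cfg = gat_v2.graph_update_get_config_dict()
# Required arguments of GATv2MPNNGraphUpdate must be set in cfg before use;
# the concrete numbers here are placeholders.
cfg.units = 64
cfg.message_dim = 32
cfg.num_heads = 4
cfg.receiver_tag = tfgnn.TARGET
layer = gat_v2.graph_update_from_config_dict(
    cfg, node_set_names=["paper"])  # "paper" is a hypothetical node set
```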
diff --git a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_get_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_get_config_dict.md
index 0f48b98f..073dfac2 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_get_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/gat_v2/graph_update_get_config_dict.md
@@ -1,17 +1,10 @@
# gat_v2.graph_update_get_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Returns ConfigDict for graph_update_from_config_dict() with defaults.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/gcn.md b/tensorflow_gnn/docs/api_docs/python/models/gcn.md
index 0ecf4de6..103f9866 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/gcn.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/gcn.md
@@ -1,17 +1,10 @@
# Module: gcn
-[TOC]
-
-
+
+ View source
+on GitHub
Graph Convolutional Networks.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNConv.md b/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNConv.md
index d2890c55..ddbeeb2c 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNConv.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNConv.md
@@ -1,17 +1,10 @@
# gcn.GCNConv
-[TOC]
-
-
+
+ View source
+on GitHub
Implements the Graph Convolutional Network by Kipf&Welling (2016).
@@ -71,14 +64,18 @@ $$v_{ij} = w_{ij} / (\sqrt{\deg^{in}_i} \sqrt{\deg^{in}_j}).$$
For symmetric graphs (as in the original GCN paper), `"in_out"` and `"in_in"`
are equal, but the latter needs to compute degrees just once.
+This layer can be restored from config by `tf.keras.models.load_model()` when
+saved as part of a Keras model using `save_format="tf"`.
+
+
Init arguments |
-`units`
+units
|
Number of output units for this transformation applied to sender
@@ -86,26 +83,26 @@ node features.
|
-`receiver_tag`
+receiver_tag
|
This layer's result is obtained by pooling the per-edge
-results at this endpoint of each edge. The default is `tfgnn.TARGET`,
+results at this endpoint of each edge. The default is tfgnn.TARGET ,
but it is perfectly reasonable to do a convolution towards the
-`tfgnn.SOURCE` instead. (Source and target are conventional names for
+tfgnn.SOURCE instead. (Source and target are conventional names for
the incident nodes of a directed edge, data flow in a GNN may happen
in either direction.)
|
-`activation`
+activation
|
Keras activation to apply to the result, defaults to 'relu'.
|
-`use_bias`
+use_bias
|
Whether to add bias in the final transformation. The original
@@ -114,7 +111,7 @@ with Keras and other implementations.
|
-`add_self_loops`
+add_self_loops
|
Whether to compute the result as if a loop from each node
@@ -123,24 +120,24 @@ with an edge weight of one.
|
-`kernel_initializer`
+kernel_initializer
|
-Can be set to a `kernel_initializer` as understood
-by `tf.keras.layers.Dense` etc.
-An `Initializer` object gets cloned before use to ensure a fresh seed,
-if not set explicitly. For more, see `tfgnn.keras.clone_initializer()`.
+Can be set to a kernel_initializer as understood
+by tf.keras.layers.Dense etc.
+An Initializer object gets cloned before use to ensure a fresh seed,
+if not set explicitly. For more, see tfgnn.keras.clone_initializer() .
|
-`node_feature`
+node_feature
|
Name of the node feature to transform.
|
-`edge_weight_feature_name`
+edge_weight_feature_name
|
Can be set to the name of a feature on the edge
@@ -149,15 +146,15 @@ it as the edge's entry in the adjacency matrix, instead of the default 1.
|
-`degree_normalization`
+degree_normalization
|
-Can be set to `"none"`, `"in"`, `"out"`, `"in_out"`,
-or `"in_in"`, as explained above.
+Can be set to "none" , "in" , "out" , "in_out" ,
+or "in_in" , as explained above.
|
-`**kwargs`
+**kwargs
|
additional arguments for the Layer.
@@ -166,23 +163,24 @@ additional arguments for the Layer.
|
+
Call arguments |
-`graph`
+graph
|
The GraphTensor on which to apply the layer.
|
-`edge_set_name`
+edge_set_name
|
-Edge set of `graph` over which to apply the layer.
+Edge set of graph over which to apply the layer.
|
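For orientation, a minimal sketch of calling this layer on one edge set; the GraphTensor `graph` and the edge set name "cites" are hypothetical:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import gcn

conv = gcn.GCNConv(units=32, receiver_tag=tfgnn.TARGET)
# Pools degree-normalized neighbor states into the target nodes of "cites":
# result = conv(graph, edge_set_name="cites")  # shape [num_target_nodes, 32]
```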
diff --git a/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNHomGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNHomGraphUpdate.md
index 1d3767e2..3b4663a6 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNHomGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/gcn/GCNHomGraphUpdate.md
@@ -1,17 +1,10 @@
# gcn.GCNHomGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a graph update layer for GCN convolution.
@@ -41,32 +34,37 @@ GCNConv objects instead.
> requires having an explicit loop in the edge set, or setting
> `add_self_loops=True`.
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
+
Args |
-`units`
+units
|
The dimension of output hidden states for each node.
|
-`receiver_tag`
+receiver_tag
|
-The default is `tfgnn.TARGET`,
+The default is tfgnn.TARGET ,
but it is perfectly reasonable to do a convolution towards the
-`tfgnn.SOURCE` instead. (Source and target are conventional names for
+tfgnn.SOURCE instead. (Source and target are conventional names for
the incident nodes of a directed edge, data flow in a GNN may happen
in either direction.)
|
-`add_self_loops`
+add_self_loops
|
Whether to compute the result as if a loop from each node
@@ -74,22 +72,22 @@ to itself had been added to the edge set.
|
-`feature_name`
+feature_name
|
The feature name of node states; defaults to
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`name`
+name
|
Optionally, a name for the layer returned.
|
-`**kwargs`
+**kwargs
|
Any optional arguments to GCNConv, see there.
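A minimal usage sketch of the returned layer on a homogeneous graph; the GraphTensor `graph` is assumed, and `add_self_loops=True` is just one illustrative choice:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import gcn

layer = gcn.GCNHomGraphUpdate(units=32, add_self_loops=True)
# graph = layer(graph)  # updates tfgnn.HIDDEN_STATE of the node set
```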
diff --git a/tensorflow_gnn/docs/api_docs/python/models/graph_sage.md b/tensorflow_gnn/docs/api_docs/python/models/graph_sage.md
index e418b9a1..09192446 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/graph_sage.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/graph_sage.md
@@ -1,17 +1,10 @@
# Module: graph_sage
-[TOC]
-
-
+
+ View source
+on GitHub
GraphSAGE.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GCNGraphSAGENodeSetUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GCNGraphSAGENodeSetUpdate.md
index dc2240b4..415ce1a5 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GCNGraphSAGENodeSetUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GCNGraphSAGENodeSetUpdate.md
@@ -1,17 +1,10 @@
# graph_sage.GCNGraphSAGENodeSetUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
GCNGraphSAGENodeSetUpdate is an extension of the mean aggregator operator.
@@ -77,33 +70,37 @@ reduce operation, instead only the sender node states will be accumulated based
on the reduce_type specified. If share_weights is set to True, then a single
weight matrix will be used in place of W_E and W_self.
+This layer can be restored from config by `tf.keras.models.load_model()` when
+saved as part of a Keras model using `save_format="tf"`.
+
+
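For orientation before the argument reference below, a minimal sketch of using this layer inside a `tfgnn.keras.layers.GraphUpdate`; the node set "paper", the edge set "cites", and `units=32` are hypothetical choices:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import graph_sage

graph_update = tfgnn.keras.layers.GraphUpdate(node_sets={
    "paper": graph_sage.GCNGraphSAGENodeSetUpdate(
        edge_set_names=["cites"], receiver_tag=tfgnn.TARGET, units=32)})
# graph = graph_update(graph)
```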
Args |
-`edge_set_names`
+edge_set_names
|
A list of edge set names to broadcast sender node states.
|
-`receiver_tag`
+receiver_tag
|
-Either one of `tfgnn.SOURCE` or `tfgnn.TARGET`. The results
+Either one of tfgnn.SOURCE or tfgnn.TARGET . The results
of GraphSAGE convolution are aggregated for this graph piece. If set to
-`tfgnn.SOURCE` or `tfgnn.TARGET`, the layer will be called for each edge
+tfgnn.SOURCE or tfgnn.TARGET , the layer will be called for each edge
set and will aggregate results at the specified endpoint of the edges.
This should point at the node_set_name for each of the specified edge
set name in the edge_set_name_dict.
|
-`reduce_type`
+reduce_type
|
An aggregation operation name. Supported list of aggregation
@@ -111,24 +108,24 @@ operators are sum or mean.
|
-`self_node_feature`
+self_node_feature
|
Feature name for the self node sets to be aggregated
with the broadcasted sender node states. Default is
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`sender_node_feature`
+sender_node_feature
|
Feature name for the sender node sets. Default is
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`units`
+units
|
Number of output units for the linear transformation applied to
@@ -136,7 +133,7 @@ sender node and self node features.
|
-`dropout_rate`
+dropout_rate
|
Can be set to a dropout rate that will be applied to both
@@ -144,7 +141,7 @@ self node and the sender node states.
|
-`activation`
+activation
|
The nonlinearity applied to the update node states. This can
@@ -153,7 +150,7 @@ string understood by tf.keras.layers.Activation(). Defaults to relu.
|
-`use_bias`
+use_bias
|
If true a bias term will be added to mean aggregated feature
@@ -161,7 +158,7 @@ vectors before applying non-linear activation.
|
-`share_weights`
+share_weights
|
If left unset, separate weights are used to transform the
@@ -171,7 +168,7 @@ applied to all inputs.
|
-`add_self_loop`
+add_self_loop
|
If left at True (the default), each node state update takes
diff --git a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEAggregatorConv.md b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEAggregatorConv.md
index 3bf2a5cd..83cc06ac 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEAggregatorConv.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEAggregatorConv.md
@@ -1,17 +1,10 @@
# graph_sage.GraphSAGEAggregatorConv
-[TOC]
-
-
+
+ View source
+on GitHub
GraphSAGE: element-wise aggregation of neighbors and their linear
transformation.
@@ -45,42 +38,46 @@ besides "mean", see the reduce_type=... argument. For stateful transformation
with a hidden layer, see
graph_sage.GraphSAGEPoolingConv .
+This layer can be restored from config by `tf.keras.models.load_model()` when
+saved as part of a Keras model using `save_format="tf"`.
+
+
Args |
-`receiver_tag`
+receiver_tag
|
-Either one of `tfgnn.SOURCE` or `tfgnn.TARGET`. The results
+Either one of tfgnn.SOURCE or tfgnn.TARGET . The results
of GraphSAGE convolution are aggregated for this graph piece. If set to
-`tfgnn.SOURCE` or `tfgnn.TARGET`, the layer will be called for an edge
+tfgnn.SOURCE or tfgnn.TARGET , the layer will be called for an edge
set and will aggregate results at the specified endpoint of the edges.
|
-`reduce_type`
+reduce_type
|
An aggregation operation name. Supported list of aggregation
operators can be found at
-`tfgnn.get_registered_reduce_operation_names()`.
+tfgnn.get_registered_reduce_operation_names() .
|
-`sender_node_feature`
+sender_node_feature
|
Can be set to specify the feature name for use as the
input feature from sender nodes to GraphSAGE aggregation, defaults to
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`units`
+units
|
Number of output units for the linear transformation applied to
@@ -88,7 +85,7 @@ sender node features.
|
-`dropout_rate`
+dropout_rate
|
Can be set to a dropout rate that will be applied to sender
@@ -96,7 +93,7 @@ node features (independently on each edge).
|
-`**kwargs`
+**kwargs
|
Additional arguments for the Layer.
@@ -105,30 +102,31 @@ Additional arguments for the Layer.
|
+
Attributes |
-`takes_receiver_input`
+takes_receiver_input
|
-If `False`, all calls to convolve() will get `receiver_input=None`.
+If False , all calls to convolve() will get receiver_input=None .
|
-`takes_sender_edge_input`
+takes_sender_edge_input
|
-If `False`, all calls to convolve() will get `sender_edge_input=None`.
+If False , all calls to convolve() will get sender_edge_input=None .
|
-`takes_sender_node_input`
+takes_sender_node_input
|
-If `False`, all calls to convolve() will get `sender_node_input=None`.
+If False , all calls to convolve() will get sender_node_input=None .
|
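As a usage sketch, this convolution is typically paired with graph_sage.GraphSAGENextState inside a NodeSetUpdate; the node set "paper", the edge set "cites", and the sizes are hypothetical:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import graph_sage

graph_update = tfgnn.keras.layers.GraphUpdate(node_sets={
    "paper": tfgnn.keras.layers.NodeSetUpdate(
        {"cites": graph_sage.GraphSAGEAggregatorConv(
            receiver_tag=tfgnn.TARGET, units=32)},
        graph_sage.GraphSAGENextState(units=32))})
# graph = graph_update(graph)
```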
@@ -137,7 +135,7 @@ If `False`, all calls to convolve() will get `sender_node_input=None`.
convolve
-View
+View
source
diff --git a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEGraphUpdate.md
index 3a3fb4c3..c7156448 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEGraphUpdate.md
@@ -1,17 +1,10 @@
# graph_sage.GraphSAGEGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a GraphSAGE GraphUpdater layer for nodes in node_set_names.
@@ -43,13 +36,14 @@ applies only one step of GraphSAGE convolution over the incident nodes of the
edge_set_name_list for the specified node_set_name node.
+
Args |
-`units`
+units
|
Number of output units of the linear transformation applied to both
@@ -57,7 +51,7 @@ final aggregated sender node features as well as the self node feature.
|
-`hidden_units`
+hidden_units
|
Number of output units to be configured for GraphSAGE pooling
@@ -65,17 +59,17 @@ type convolution only.
|
-`receiver_tag`
+receiver_tag
|
-Either one of `tfgnn.SOURCE` or `tfgnn.TARGET`. The results of
-GraphSAGE are aggregated for this graph piece. When set to `tfgnn.SOURCE`
-or `tfgnn.TARGET`, the layer is called for an edge set and will aggregate
+Either one of tfgnn.SOURCE or tfgnn.TARGET . The results of
+GraphSAGE are aggregated for this graph piece. When set to tfgnn.SOURCE
+or tfgnn.TARGET , the layer is called for an edge set and will aggregate
results at the specified endpoint of the edges.
|
-`node_set_names`
+node_set_names
|
By default, this layer updates all node sets that receive
@@ -86,15 +80,15 @@ auxiliary node sets.
|
-`reduce_type`
+reduce_type
|
An aggregation operation name. Supported list of aggregation
-operators can be found at `tfgnn.get_registered_reduce_operation_names()`.
+operators can be found at tfgnn.get_registered_reduce_operation_names() .
|
-`use_pooling`
+use_pooling
|
If enabled, graph_sage.GraphSAGEPoolingConv will be used,
@@ -103,7 +97,7 @@ provided edges.
|
-`use_bias`
+use_bias
|
If true a bias term will be added to the linear transformations
@@ -111,7 +105,7 @@ for the incident node features as well as for the self node feature.
|
-`dropout_rate`
+dropout_rate
|
Can be set to a dropout rate that will be applied to both
@@ -119,7 +113,7 @@ incident node features as well as the self node feature.
|
-`l2_normalize`
+l2_normalize
|
If enabled l2 normalization will be applied to final node
@@ -127,7 +121,7 @@ states.
|
-`combine_type`
+combine_type
|
Can be set to "sum" or "concat". If it's specified as concat
@@ -136,32 +130,32 @@ node state will be added with the sender node features.
|
-`activation`
+activation
|
The nonlinearity applied to the concatenated or added node state
and aggregated sender node features. This can be specified as a Keras
layer, a tf.keras.activations.* function, or a string understood by
-`tf.keras.layers.Activation()`. Defaults to relu.
+tf.keras.layers.Activation() . Defaults to relu.
|
-`feature_name`
+feature_name
|
The feature name of node states; defaults to
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`name`
+name
|
Optionally, a name for the layer returned.
|
-`**kwargs`
+**kwargs
|
Any optional arguments to graph_sage.GraphSAGEPoolingConv ,
diff --git a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGENextState.md b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGENextState.md
index 62e06194..68d0c6c8 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGENextState.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGENextState.md
@@ -1,17 +1,10 @@
# graph_sage.GraphSAGENextState
-[TOC]
-
-
+
+ View source
+on GitHub
GraphSAGENextState: compute new node states with GraphSAGE algorithm.
@@ -60,7 +53,7 @@ equal, unless `combine_type="concat"`is set.
GraphSAGE is Algorithm 1 in Hamilton et al.:
["Inductive Representation Learning on Large Graphs"](https://arxiv.org/abs/1706.02216),
2017. It computes the new hidden state h_v for each node v from a concatenation
-of the previous hiddden state with an aggregation of the neighbor states as
+of the previous hidden state with an aggregation of the neighbor states as
$$h_v = \sigma(W \text{ concat}(h_v, h_{N(v)}))$$
@@ -87,13 +80,14 @@ Beyond the original GraphSAGE, this class supports:
* additional options to influence normalization, activation, etc.
+
Args |
-`units`
+units
|
Number of output units for the linear transformation applied to the
@@ -101,7 +95,7 @@ node feature.
|
-`use_bias`
+use_bias
|
If true a bias term will be added to the linear transformations
@@ -109,7 +103,7 @@ for the self node feature.
|
-`dropout_rate`
+dropout_rate
|
Can be set to a dropout rate that will be applied to the
@@ -117,15 +111,15 @@ node feature.
|
-`feature_name`
+feature_name
|
The feature name of node states; defaults to
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`l2_normalize`
+l2_normalize
|
If enabled l2 normalization will be applied to node state
@@ -133,7 +127,7 @@ vectors.
|
-`combine_type`
+combine_type
|
Can be set to "sum" or "concat". The default "sum" recovers
@@ -144,17 +138,17 @@ Setting this to "concat" concatenates the results of the transformations
|
-`activation`
+activation
|
The nonlinearity applied to the concatenated or added node
state and aggregated sender node features. This can be specified as a
Keras layer, a tf.keras.activations.* function, or a string understood
-by `tf.keras.layers.Activation()`. Defaults to relu.
+by tf.keras.layers.Activation() . Defaults to relu.
|
-`**kwargs`
+**kwargs
|
Forwarded to the base class tf.keras.layers.Layer.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEPoolingConv.md b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEPoolingConv.md
index 3d53c792..a23368ed 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEPoolingConv.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/graph_sage/GraphSAGEPoolingConv.md
@@ -1,17 +1,10 @@
# graph_sage.GraphSAGEPoolingConv
-[TOC]
-
-
+
+ View source
+on GitHub
GraphSAGE: pooling aggregator transform of neighbors followed by linear
transformation.
@@ -53,33 +46,37 @@ involves the aforementioned hidden layer. For element-wise aggregation (as in
`tfgnn.pool_edges_to_node()`), see
graph_sage.GraphSAGEAggregatorConv .
+This layer can be restored from config by `tf.keras.models.load_model()` when
+saved as part of a Keras model using `save_format="tf"`.
+
+
Args |
-`receiver_tag`
+receiver_tag
|
-Either one of `tfgnn.SOURCE` or `tfgnn.TARGET`. The results
+Either one of tfgnn.SOURCE or tfgnn.TARGET . The results
of GraphSAGE are aggregated for this graph piece. If set to
-`tfgnn.SOURCE` or `tfgnn.TARGET`, the layer will be called for an edge
+tfgnn.SOURCE or tfgnn.TARGET , the layer will be called for an edge
set and will aggregate results at the specified endpoint of the edges.
|
-`sender_node_feature`
+sender_node_feature
|
Can be set to specify the feature name for use as the
input feature from sender nodes to GraphSAGE aggregation, defaults to
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`units`
+units
|
Number of output units for the final dimensionality of the output
@@ -87,7 +84,7 @@ from the layer.
|
-`hidden_units`
+hidden_units
|
Number of output units for the linear transformation applied
@@ -97,16 +94,16 @@ W_pool from Eq. (3) in
|
-`reduce_type`
+reduce_type
|
An aggregation operation name. Supported list of aggregation
operators can be found at
-`tfgnn.get_registered_reduce_operation_names()`.
+tfgnn.get_registered_reduce_operation_names() .
|
-`use_bias`
+use_bias
|
If true a bias term will be added to the linear transformations
@@ -114,7 +111,7 @@ for the sender node features.
|
-`dropout_rate`
+dropout_rate
|
Can be set to a dropout rate that will be applied to sender
@@ -122,17 +119,17 @@ node features (independently on each edge).
|
-`activation`
+activation
|
The nonlinearity applied to the concatenated or added node
state and aggregated sender node features. This can be specified as a
Keras layer, a tf.keras.activations.* function, or a string understood
-by `tf.keras.layers.Activation()`. Defaults to relu.
+by tf.keras.layers.Activation() . Defaults to relu.
|
-`**kwargs`
+**kwargs
|
Additional arguments for the Layer.
@@ -141,30 +138,31 @@ Additional arguments for the Layer.
|
+
Attributes |
-`takes_receiver_input`
+takes_receiver_input
|
-If `False`, all calls to convolve() will get `receiver_input=None`.
+If False , all calls to convolve() will get receiver_input=None .
|
-`takes_sender_edge_input`
+takes_sender_edge_input
|
-If `False`, all calls to convolve() will get `sender_edge_input=None`.
+If False , all calls to convolve() will get sender_edge_input=None .
|
-`takes_sender_node_input`
+takes_sender_node_input
|
-If `False`, all calls to convolve() will get `sender_node_input=None`.
+If False , all calls to convolve() will get sender_node_input=None .
|
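As a usage sketch, the pooling variant is constructed like the aggregator variant but with the additional `hidden_units` for the per-edge hidden layer; the edge set "cites" and the sizes are hypothetical:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import graph_sage

conv = graph_sage.GraphSAGEPoolingConv(
    receiver_tag=tfgnn.TARGET, units=32, hidden_units=64)
# pooled = conv(graph, edge_set_name="cites")  # [num_receiver_nodes, 32]
```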
@@ -173,7 +171,7 @@ If `False`, all calls to convolve() will get `sender_node_input=None`.
convolve
-View
+View
source
diff --git a/tensorflow_gnn/docs/api_docs/python/models/mt_albis.md b/tensorflow_gnn/docs/api_docs/python/models/mt_albis.md
index 61451923..299ee813 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/mt_albis.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/mt_albis.md
@@ -1,17 +1,10 @@
# Module: mt_albis
-[TOC]
-
-
+
+ View source
+on GitHub
TF-GNN's Model Template "Albis".
diff --git a/tensorflow_gnn/docs/api_docs/python/models/mt_albis/MtAlbisGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/mt_albis/MtAlbisGraphUpdate.md
index 1b9c26f4..bfb58361 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/mt_albis/MtAlbisGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/mt_albis/MtAlbisGraphUpdate.md
@@ -1,17 +1,10 @@
# mt_albis.MtAlbisGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns GraphUpdate layer for message passing with Model Template "Albis".
@@ -44,6 +37,10 @@ Returns GraphUpdate layer for message passing with Model Template "Albis".
The TF-GNN Model Template "Albis" provides a small selection of field-tested GNN
architectures through the unified interface of this class.
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
@@ -52,29 +49,29 @@ architectures through the unified interface of this class.
-`units`
+units
|
The dimension of node states in the output GraphTensor.
|
-`message_dim`
+message_dim
|
The dimension of messages computed transiently on each edge.
|
-`receiver_tag`
+receiver_tag
|
-One of `tfgnn.SOURCE` or `tfgnn.TARGET`. The messages are
+One of tfgnn.SOURCE or tfgnn.TARGET . The messages are
sent to the nodes at this endpoint of edges.
|
-`node_set_names`
+node_set_names
|
Optionally, the names of NodeSets to update. By default,
@@ -82,7 +79,7 @@ all NodeSets are updated that receive from at least one EdgeSet.
|
-`edge_feature_name`
+edge_feature_name
|
Optionally, the name of an edge feature to include in
@@ -90,53 +87,54 @@ message computation on edges.
|
-`attention_type`
+attention_type
|
-`"none"`, `"multi_head"`, or `"gat_v2"`. Selects whether
+"none" , "multi_head" , or "gat_v2" . Selects whether
messages are pooled with data-dependent weights computed by a trained
attention mechanism.
|
-`attention_edge_set_names`
+attention_edge_set_names
|
If set, edge sets other than those named here
-will be treated as if `attention_type="none"` regardless.
+will be treated as if attention_type="none" regardless.
|
-`attention_num_heads`
+attention_num_heads
|
-For attention_types `"multi_head"` or `"gat_v2"`,
+For attention_types "multi_head" or "gat_v2" ,
the number of attention heads.
|
-`simple_conv_reduce_type`
+simple_conv_reduce_type
|
-For attention_type `"none"`, controls how messages
-are aggregated on an EdgeSet for each receiver node. Defaults to `"mean"`;
-other recommened values are the concatenations `"mean|sum"`, `"mean|max"`,
-and `"mean|sum|max"` (but mind the increased output dimension and the
-corresponding increase in the number of weights in the next-state layer).
-Technically, can be set to any reduce_type understood by `tfgnn.pool()`.
+For attention_type "none" , controls how messages
+are aggregated on an EdgeSet for each receiver node. Defaults to "mean" ;
+other recommended values are the concatenations "mean|sum" ,
+"mean|max" , and "mean|sum|max" (but mind the increased output
+dimension and the corresponding increase in the number of weights in the
+next-state layer). Technically, can be set to any reduce_type understood
+by tfgnn.pool() .
|
-`simple_conv_use_receiver_state`
+simple_conv_use_receiver_state
|
-For attention_type `"none"`, controls
+For attention_type "none" , controls
whether the receiver node state is used in computing each edge's message
-(in addition to the sender node state and possibly an `edge feature`).
+(in addition to the sender node state and possibly an edge feature ).
|
-`state_dropout_rate`
+state_dropout_rate
|
The dropout rate applied to the pooled and combined
@@ -147,7 +145,7 @@ is applied to messages after pooling.)
|
-`edge_dropout_rate`
+edge_dropout_rate
|
Can be set to a dropout rate for entire edges during
@@ -156,7 +154,7 @@ an edge is dropped, as if the edge were not present in the graph.
|
-`l2_regularization`
+l2_regularization
|
The coefficient of L2 regularization for trained weights.
@@ -164,54 +162,54 @@ The coefficient of L2 regularization for trained weights.
|
-`kernel_initializer`
+kernel_initializer
|
-Can be set to a `kernel_initializer` as understood
-by `tf.keras.layers.Dense` etc.
-An `Initializer` object gets cloned before use to ensure a fresh seed,
-if not set explicitly. For more, see `tfgnn.keras.clone_initializer()`.
+Can be set to a kernel_initializer as understood
+by tf.keras.layers.Dense etc.
+An Initializer object gets cloned before use to ensure a fresh seed,
+if not set explicitly. For more, see tfgnn.keras.clone_initializer() .
|
-`normalization_type`
+normalization_type
|
controls the normalization of output node states.
-By default (`"layer"`), LayerNormalization is used. Can be set to
-`"none"`, or to `"batch"` for BatchNormalization.
+By default ("layer" ), LayerNormalization is used. Can be set to
+"none" , or to "batch" for BatchNormalization.
|
-`batch_normalization_momentum`
+batch_normalization_momentum
|
-If `normalization_type="batch"`, sets the
-`BatchNormalization(momentum=...)` parameter. Ignored otherwise.
+If normalization_type="batch" , sets the
+BatchNormalization(momentum=...) parameter. Ignored otherwise.
|
-`next_state_type`
+next_state_type
|
-`"dense"` or `"residual"`. With the latter, a residual
+"dense" or "residual" . With the latter, a residual
link is added from the old to the new node state, which requires that all
-input node states already have size `units` (unless their size is 0, as
+input node states already have size units (unless their size is 0, as
for latent node sets, in which case the residual link is omitted).
|
-`edge_set_combine_type`
+edge_set_combine_type
|
-`"concat"` or `"sum"`. Controls how pooled messages
+"concat" or "sum" . Controls how pooled messages
from various edge sets are combined as inputs to the NextState layer
-that updates the node states. Defaults to `"concat"`, which gives the
+that updates the node states. Defaults to "concat" , which gives the
pooled messages from each edge set separate weights in the NextState
-layer, namely `units * message_dim * num_incident_edge_sets` per node set.
-Setting this to `"sum"` adds up the pooled messages into a single
+layer, namely units * message_dim * num_incident_edge_sets per node set.
+Setting this to "sum" adds up the pooled messages into a single
vector before passing them into the NextState layer, which requires just
-`units * message_dim` weights per node set.
+units * message_dim weights per node set.
|
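A minimal construction sketch with illustrative hyperparameters (the residual next state assumes input node states already have size `units`, as noted above):

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

layer = mt_albis.MtAlbisGraphUpdate(
    units=128, message_dim=64, receiver_tag=tfgnn.TARGET,
    attention_type="multi_head", attention_num_heads=4,
    normalization_type="layer", next_state_type="residual")
# graph = layer(graph)  # one round of message passing on all node sets
```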
diff --git a/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_from_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_from_config_dict.md
index 6843881a..5eb5a0e6 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_from_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_from_config_dict.md
@@ -1,23 +1,17 @@
# mt_albis.graph_update_from_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Constructs a MtAlbisGraphUpdate from a ConfigDict.
mt_albis.graph_update_from_config_dict(
- cfg: config_dict.ConfigDict
+ cfg: config_dict.ConfigDict,
+ node_set_names: Optional[Collection[tfgnn.NodeSetName]] = None
) -> tf.keras.layers.Layer
@@ -30,16 +24,24 @@ Constructs a MtAlbisGraphUpdate from a ConfigDict.
|
-`cfg`
+cfg
|
-A `ConfigDict` with the fields defined by
-`graph_update_get_config_dict()`. All fields with non-`None` values are
+A ConfigDict with the fields defined by
+graph_update_get_config_dict() . All fields with non-None values are
used as keyword arguments for initializing and returning a
-`MtAlbisGraphUpdate` object. For the required arguments of
-`MtAlbisGraphUpdate.__init__`, users must set a value in `cfg` before
+MtAlbisGraphUpdate object. For the required arguments of
+MtAlbisGraphUpdate.__init__ , users must set a value in cfg before
passing it here.
|
+
+
+node_set_names
+ |
+
+Optionally, the names of NodeSets to update; forwarded to
+MtAlbisGraphUpdate.__init__ .
+ |
@@ -50,7 +52,7 @@ passing it here.
| Returns |
-A new Layer object as returned by `MtAlbisGraphUpdate()`.
+A new Layer object as returned by MtAlbisGraphUpdate() .
|
@@ -64,11 +66,11 @@ A new Layer object as returned by `MtAlbisGraphUpdate()`.
-`TypeError`
+TypeError
|
-if `cfg` fails to supply a required argument for
-`MtAlbisGraphUpdate()`.
+if cfg fails to supply a required argument for
+MtAlbisGraphUpdate() .
|
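A minimal sketch of the intended flow, assuming the config exposes placeholders for the required `MtAlbisGraphUpdate()` arguments as the docstring above requires; the values and the node set name "paper" are illustrative:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

cfg = mt_albis.graph_update_get_config_dict()
cfg.units = 128         # required arguments must be set in cfg;
cfg.message_dim = 64    # the numbers here are placeholders
cfg.receiver_tag = tfgnn.TARGET
layer = mt_albis.graph_update_from_config_dict(
    cfg, node_set_names=["paper"])  # "paper" is a hypothetical node set
```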
diff --git a/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_get_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_get_config_dict.md
index a104fb31..4211a131 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_get_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/mt_albis/graph_update_get_config_dict.md
@@ -1,17 +1,10 @@
# mt_albis.graph_update_get_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Returns ConfigDict for graph_update_from_config_dict() with defaults.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md
index 47ae18eb..f4dc2f07 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention.md
@@ -1,17 +1,10 @@
# Module: multi_head_attention
-[TOC]
-
-
+
+ View source
+on GitHub
Transformer-style multi-head attention.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionConv.md b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionConv.md
index 36aab5a5..c5d90833 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionConv.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionConv.md
@@ -1,17 +1,10 @@
# multi_head_attention.MultiHeadAttentionConv
-[TOC]
-
-
+
+ View source
+on GitHub
Transformer-style (dot-product) multi-head attention on GNNs.
@@ -32,7 +25,7 @@ Transformer-style (dot-product) multi-head attention on GNNs.
kernel_initializer: Any = None,
kernel_regularizer: Any = None,
transform_keys: bool = True,
- score_scaling: Literal['none', 'rsqrt_dim', 'trainable_sigmoid'] = 'rsqrt_dim',
+ score_scaling: Literal['none', 'rsqrt_dim', 'trainable_elup1'] = 'rsqrt_dim',
transform_values_after_pooling: bool = False,
**kwargs
)
@@ -96,50 +89,54 @@ using the other arguments.
Example: Transformer-style attention on neighbors along incoming edges whose
result is concatenated with the old node state and passed through a Dense layer
-to compute the new node state. `dense = tf.keras.layers.Dense graph =
-tfgnn.keras.layers.GraphUpdate( node_sets={"paper":
-tfgnn.keras.layers.NodeSetUpdate( {"cites":
-tfgnn.keras.layers.MultiHeadAttentionConv( message_dim,
-receiver_tag=tfgnn.TARGET)},
-tfgnn.keras.layers.NextStateFromConcat(dense(node_state_dim)))} )(graph)`
+to compute the new node state.
+
+```
+dense = tf.keras.layers.Dense
+graph = tfgnn.keras.layers.GraphUpdate(
+ node_sets={"paper": tfgnn.keras.layers.NodeSetUpdate(
+ {"cites": tfgnn.keras.layers.MultiHeadAttentionConv(
+ message_dim, receiver_tag=tfgnn.TARGET)},
+ tfgnn.keras.layers.NextStateFromConcat(dense(node_state_dim)))}
+)(graph)
+```
For now, there is a variant that modifies the inputs transformation part and
could potentially be beneficial:
-```
-1. (transform_keys is False) Instead of projecting both queries and
- keys when computing attention weights, we only project the queries
- because the two linear projections can be collapsed to a single
- projection:
-
- $$ (Q_v W_Q^k)(K_u W_K^k)^T
- = Q_v (W_Q^k {W_K^k}^T) K_u^T
- = Q_v W_{QK}^k K_u^T $$
-
- where $d$ is the key width. (Following "Attention is all you need",
- this scaling is meant to achieve unit variance of the results, assuming
- that $Q_v W_{QK}^k$ has unit variance due to the initialization of
- $Q_v W_{QK}^k$.)
-
- NOTE: The single projection matrix behaves differently in
- gradient-descent training than the product of two matrices.
-```
+1. (transform_keys is False) Instead of projecting both queries and keys when
+ computing attention weights, we only project the queries because the two
+ linear projections can be collapsed to a single projection:
+
+ $$ (Q_v W_Q^k)(K_u W_K^k)^T = Q_v (W_Q^k {W_K^k}^T) K_u^T = Q_v W_{QK}^k
+ K_u^T $$
+
+ where $d$ is the key width. (Following "Attention is all you need", this
+ scaling is meant to achieve unit variance of the results, assuming that $Q_v
+ W_{QK}^k$ has unit variance due to the initialization of $Q_v W_{QK}^k$.)
+
+ NOTE: The single projection matrix behaves differently in gradient-descent
+ training than the product of two matrices.
+
+This layer can be restored from config by `tf.keras.models.load_model()` when
+saved as part of a Keras model using `save_format="tf"`.
+
Init args |
-`num_heads`
+num_heads
|
The number of attention heads.
|
-`per_head_channels`
+per_head_channels
|
The number of channels for each attention head. This
@@ -147,51 +144,51 @@ means that the final output size will be per_head_channels * num_heads.
|
-`receiver_tag`
+receiver_tag
|
-one of `tfgnn.SOURCE`, `tfgnn.TARGET` or `tfgnn.CONTEXT`.
+one of tfgnn.SOURCE , tfgnn.TARGET or tfgnn.CONTEXT .
The results of attention are aggregated for this graph piece.
-If set to `tfgnn.SOURCE` or `tfgnn.TARGET`, the layer can be called for
+If set to tfgnn.SOURCE or tfgnn.TARGET , the layer can be called for
an edge set and will aggregate results at the specified endpoint of the
edges.
-If set to `tfgnn.CONTEXT`, the layer can be called for an edge set or
+If set to tfgnn.CONTEXT , the layer can be called for an edge set or
node set.
If left unset for init, the tag must be passed at call time.
|
-`receiver_feature`
+receiver_feature
|
-Can be set to override `tfgnn.HIDDEN_STATE`
+Can be set to override tfgnn.HIDDEN_STATE
for use as the receiver's input feature to attention. (The attention key
is derived from this input.)
|
-`sender_node_feature`
+sender_node_feature
|
-Can be set to override `tfgnn.HIDDEN_STATE`
+Can be set to override tfgnn.HIDDEN_STATE
for use as the input feature from sender nodes to attention.
-IMPORTANT: Must be set to `None` for use with `receiver_tag=tfgnn.CONTEXT`
+IMPORTANT: Must be set to None for use with receiver_tag=tfgnn.CONTEXT
on an edge set, or for pooling from edges without sender node states.
|
-`sender_edge_feature`
+sender_edge_feature
|
Can be set to a feature name of the edge set to select
-it as an input feature. By default, this set to `None`, which disables
+it as an input feature. By default, this is set to None , which disables
this input.
-IMPORTANT: Must be set for use with `receiver_tag=tfgnn.CONTEXT`
+IMPORTANT: Must be set for use with receiver_tag=tfgnn.CONTEXT
on an edge set.
|
-`use_bias`
+use_bias
|
If true, bias terms are added to the transformations of query,
@@ -199,7 +196,7 @@ key and value inputs.
|
-`edge_dropout`
+edge_dropout
|
Can be set to a dropout rate for edge dropout. (When pooling
@@ -208,7 +205,7 @@ is dropped out.)
|
-`inputs_dropout`
+inputs_dropout
|
Dropout rate for random dropout on the inputs to this
@@ -216,18 +213,18 @@ convolution layer, i.e. the receiver, sender node, and sender edge inputs.
|
-`attention_activation`
+attention_activation
|
The nonlinearity used on the transformed inputs
-(query, and keys if `transform_keys` is `True`) before computing the
+(query, and keys if transform_keys is True ) before computing the
attention scores. This can be specified as a Keras layer, a
tf.keras.activations.* function, or a string understood by
-`tf.keras.layers.Activation`. Defaults to None.
+tf.keras.layers.Activation . Defaults to None.
|
-`activation`
+activation
|
The nonlinearity applied to the final result of attention,
@@ -235,25 +232,25 @@ specified in the same ways as attention_activation.
|
-`kernel_initializer`
+kernel_initializer
|
-Can be set to a `kernel_initializer` as understood
-by `tf.keras.layers.Dense` etc.
-An `Initializer` object gets cloned before use to ensure a fresh seed,
-if not set explicitly. For more, see `tfgnn.keras.clone_initializer()`.
+Can be set to a kernel_initializer as understood
+by tf.keras.layers.Dense etc.
+An Initializer object gets cloned before use to ensure a fresh seed,
+if not set explicitly. For more, see tfgnn.keras.clone_initializer() .
|
-`kernel_regularizer`
+kernel_regularizer
|
-Can be set to a `kernel_regularized` as understood
-by `tf.keras.layers.Dense` etc.
+Can be set to a kernel_regularizer as understood
+by tf.keras.layers.Dense etc.
|
-`transform_keys`
+transform_keys
|
If true, transform both queries and keys inputs. Otherwise,
@@ -263,22 +260,24 @@ independent of this arg.)
|
-`score_scaling`
+score_scaling
|
-One of either `"none"`, `"rsqrt_dim"`, or
-`"trainable_sigmoid"`. If set to `"rsqrt_dim"`, the attention scores are
+One of either "rsqrt_dim" (default), "trainable_elup1" ,
+or "none" . If set to "rsqrt_dim" , the attention scores are
divided by the square root of the dimension of keys (i.e.,
-`per_head_channels` if `transform_keys=True`, otherwise whatever the
-dimension of combined sender inputs is). If set to `"trainable_sigmoid"`,
-the scores are scaled with `sigmoid(x)`, where `x` is a trainable weight
-of the model that is initialized to `-5.0`, which initially makes all the
-attention weights equal and slowly ramps up as the other weights in the
-layer converge. Defaults to `"rsqrt_dim"`.
+per_head_channels if transform_keys=True , otherwise whatever the
+dimension of combined sender inputs is). If set to "trainable_elup1" ,
+the scores are scaled with elu(x) + 1 , where elu is the Exponential
+Linear Unit (see tf.keras.activations.elu ), and x is a per-head
+trainable weight of the model that is initialized to 0.0 . Recall that
+elu(x) + 1 == exp(x) if x<0 else x+1 , so the
+initial scaling factor is 1.0 , decreases exponentially below 1.0, and
+grows linearly above 1.0.
|
-`transform_values_after_pooling`
+transform_values_after_pooling
|
By default, each attention head applies
@@ -292,56 +291,58 @@ IMPORTANT: Toggling this option breaks checkpoint compatibility.
|
+
Args |
- `receiver_tag` | one of
-`tfgnn.SOURCE`, `tfgnn.TARGET` or `tfgnn.CONTEXT`. The results are aggregated
-for this graph piece. If set to `tfgnn.SOURCE` or `tfgnn.TARGET`, the layer can
-be called for an edge set and will aggregate results at the specified endpoint
-of the edges. If set to `tfgnn.CONTEXT`, the layer can be called for an edge set
-or a node set and will aggregate results for context (per graph component). If
-left unset for init, the tag must be passed at call time. |
-`receiver_feature` | The name of the
-feature that is read from the receiver graph piece and passed as
-convolve(receiver_input=...). |
-`sender_node_feature` | The name of the
-feature that is read from the sender nodes, if any, and passed as
-convolve(sender_node_input=...). NOTICE this must be `None` for use with
-`receiver_tag=tfgnn.CONTEXT` on an edge set, or for pooling from edges without
-sender node states. |
-`sender_edge_feature` | The name of the
-feature that is read from the sender edges, if any, and passed as
-convolve(sender_edge_input=...). NOTICE this must not be `None` for use with
-`receiver_tag=tfgnn.CONTEXT` on an edge set. |
-`extra_receiver_ops` | A str-keyed
-dictionary of Python callables that are wrapped to bind some arguments and then
-passed on to `convolve()`. Sample usage: `extra_receiver_ops={"softmax":
-tfgnn.softmax}`. The values passed in this dict must be callable as follows,
-with two positional arguments:
+ | receiver_tag | one of
+tfgnn.SOURCE , tfgnn.TARGET or
+tfgnn.CONTEXT . The results are aggregated for this graph piece. If
+set to tfgnn.SOURCE or tfgnn.TARGET , the layer can be
+called for an edge set and will aggregate results at the specified endpoint of
+the edges. If set to tfgnn.CONTEXT , the layer can be called for an
+edge set or a node set and will aggregate results for context (per graph
+component). If left unset for init, the tag must be passed at call time. |
+ receiver_feature |
+ The name of the feature that is read from the receiver graph piece and
+passed as convolve(receiver_input=...). |
+sender_node_feature | The
+name of the feature that is read from the sender nodes, if any, and passed as
+convolve(sender_node_input=...). NOTICE this must be None for use
+with receiver_tag=tfgnn.CONTEXT on an edge set, or for pooling from
+edges without sender node states. |
+sender_edge_feature | The
+name of the feature that is read from the sender edges, if any, and passed as
+convolve(sender_edge_input=...). NOTICE this must not be None for
+use with receiver_tag=tfgnn.CONTEXT on an edge set. |
+ extra_receiver_ops | A
+str-keyed dictionary of Python callables that are wrapped to bind some arguments
+and then passed on to convolve() . Sample usage:
+extra_receiver_ops={"softmax": tfgnn.softmax} . The values passed in
+this dict must be callable as follows, with two positional arguments:
```python
f(graph, receiver_tag, node_set_name=..., feature_value=..., ...)
f(graph, receiver_tag, edge_set_name=..., feature_value=..., ...)
```
-The wrapped callables seen by `convolve()` can be called like
+The wrapped callables seen by convolve() can be called like
```python
wrapped_f(feature_value, ...)
```
-The first three arguments of `f` are set to the input GraphTensor of
-the layer and the tag/name pair required by `tfgnn.broadcast()` and
-`tfgnn.pool()` to move values between the receiver and the messages that
+The first three arguments of f are set to the input GraphTensor of
+the layer and the tag/name pair required by tfgnn.broadcast() and
+tfgnn.pool() to move values between the receiver and the messages that
are computed inside the convolution. The sole positional argument of
-`wrapped_f()` is passed to `f()` as `feature_value=`, and any keyword
+wrapped_f() is passed to f() as feature_value= , and any keyword
arguments are forwarded.
|
-`**kwargs`
+**kwargs
|
Forwarded to the base class tf.keras.layers.Layer.
@@ -350,30 +351,31 @@ Forwarded to the base class tf.keras.layers.Layer.
|
+
Attributes |
-`takes_receiver_input`
+takes_receiver_input
|
-If `False`, all calls to convolve() will get `receiver_input=None`.
+If False , all calls to convolve() will get receiver_input=None .
|
-`takes_sender_edge_input`
+takes_sender_edge_input
|
-If `False`, all calls to convolve() will get `sender_edge_input=None`.
+If False , all calls to convolve() will get sender_edge_input=None .
|
-`takes_sender_node_input`
+takes_sender_node_input
|
-If `False`, all calls to convolve() will get `sender_node_input=None`.
+If False , all calls to convolve() will get sender_node_input=None .
|
@@ -382,7 +384,7 @@ If `False`, all calls to convolve() will get `sender_node_input=None`.
convolve
-View
+View
source
@@ -408,44 +410,45 @@ from nodes to context). In the end, values have to be pooled from there into a
Tensor with a leading dimension indexed by receivers, see `pool_to_receiver`.
+
Args |
-`sender_node_input`
+sender_node_input
|
-The input Tensor from the sender NodeSet, or `None`.
-If self.takes_sender_node_input is `False`, this arg will be `None`.
-(If it is `True`, that depends on how this layer gets called.)
+The input Tensor from the sender NodeSet, or None .
+If self.takes_sender_node_input is False , this arg will be None .
+(If it is True , that depends on how this layer gets called.)
See also broadcast_from_sender_node.
|
-`sender_edge_input`
+sender_edge_input
|
-The input Tensor from the sender EdgeSet, or `None`.
-If self.takes_sender_edge_input is `False`, this arg will be `None`.
-(If it is `True`, it depends on how this layer gets called.)
+The input Tensor from the sender EdgeSet, or None .
+If self.takes_sender_edge_input is False , this arg will be None .
+(If it is True , it depends on how this layer gets called.)
If present, this Tensor is already indexed by the items for which
messages are computed.
|
-`receiver_input`
+receiver_input
|
The input Tensor from the receiver NodeSet or Context,
-or None. If self.takes_receiver_input is `False`, this arg will be
-`None`. (If it is `True`, it depends on how this layer gets called.)
+or None. If self.takes_receiver_input is False , this arg will be
+None . (If it is True , it depends on how this layer gets called.)
See broadcast_from_receiver.
|
-`broadcast_from_sender_node`
+broadcast_from_sender_node
|
A function that broadcasts a Tensor indexed
@@ -454,25 +457,25 @@ messages are computed.
|
-`broadcast_from_receiver`
+broadcast_from_receiver
|
-Call this as `broadcast_from_receiver(value)`
+Call this as broadcast_from_receiver(value)
to broadcast a Tensor indexed like receiver_input to a Tensor indexed
by the items for which messages are computed.
|
-`pool_to_receiver`
+pool_to_receiver
|
-Call this as `pool_to_receiver(value, reduce_type=...)`
+Call this as pool_to_receiver(value, reduce_type=...)
to pool an item-indexed Tensor to a receiver-indexed tensor, using
a reduce_type understood by tfgnn.pool(), such as "sum".
|
-`extra_receiver_ops`
+extra_receiver_ops
|
The extra_receiver_ops passed to init, see there,
@@ -482,10 +485,10 @@ this argument, so subclass implementors not using it can omit it.
|
-`training`
+training
|
-The `training` boolean that was passed to Layer.call(). If true,
+The training boolean that was passed to Layer.call(). If true,
the result is computed for training rather than inference. For example,
calls to tf.nn.dropout() are usually conditioned on this flag.
By contrast, calling another Keras layer (like tf.keras.layers.Dropout)
@@ -495,6 +498,7 @@ does not require forwarding this arg, Keras does that automatically.
|
+
Returns |
diff --git a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionEdgePool.md b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionEdgePool.md
index d535a182..6fa90816 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionEdgePool.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionEdgePool.md
@@ -1,17 +1,10 @@
# multi_head_attention.MultiHeadAttentionEdgePool
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a layer for pooling edges with Transformer-style Multi-Head Attention.
@@ -40,21 +33,26 @@ an edge set to do the analogous pooling of edge states to context.
NOTE: This layer cannot pool node states. For that, use MultiHeadAttentionConv.
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
+
Args |
-`num_heads`
+num_heads
|
The number of attention heads.
|
-`per_head_channels`
+per_head_channels
|
The number of channels for each attention head. This
@@ -62,19 +60,19 @@ means that the final output size will be per_head_channels * num_heads.
|
-`receiver_tag`
+receiver_tag
|
The results of attention are aggregated for this graph piece.
-If set to `tfgnn.CONTEXT`, the layer can be called for an edge set or node
-set. If set to an IncidentNodeTag (e.g., `tfgnn.SOURCE` or
-`tfgnn.TARGET`), the layer can be called for an edge set and will
+If set to tfgnn.CONTEXT , the layer can be called for an edge set or node
+set. If set to an IncidentNodeTag (e.g., tfgnn.SOURCE or
+tfgnn.TARGET ), the layer can be called for an edge set and will
aggregate results at the specified endpoint of the edges. If left unset,
the tag must be passed when calling the layer.
|
-`receiver_feature`
+receiver_feature
|
By default, the default state feature of the receiver is
@@ -83,7 +81,7 @@ selected by setting this argument.
|
-`sender_feature`
+sender_feature
|
By default, the default state feature of the edge set is
@@ -92,7 +90,7 @@ selected by setting this argument.
|
-`**kwargs`
+**kwargs
|
Any other option for MultiHeadAttentionConv, except
diff --git a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionHomGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionHomGraphUpdate.md
index 27cc34b8..d532de39 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionHomGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionHomGraphUpdate.md
@@ -1,17 +1,10 @@
# multi_head_attention.MultiHeadAttentionHomGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a GraphUpdate layer with a transformer-style multihead attention.
@@ -42,21 +35,26 @@ details).
> that are explicitly stored in the input GraphTensor. Attention of a node to
> itself requires having an explicit loop in the edge set.
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
+
Args |
-`num_heads`
+num_heads
|
The number of attention heads.
|
-`per_head_channels`
+per_head_channels
|
The number of channels for each attention head. This
@@ -64,29 +62,29 @@ means that the final output size will be per_head_channels * num_heads.
|
-`receiver_tag`
+receiver_tag
|
-one of `tfgnn.SOURCE` or `tfgnn.TARGET`.
+one of tfgnn.SOURCE or tfgnn.TARGET .
|
-`feature_name`
+feature_name
|
The feature name of node states; defaults to
-`tfgnn.HIDDEN_STATE`.
+tfgnn.HIDDEN_STATE .
|
-`name`
+name
|
Optionally, a name for the layer returned.
|
-`**kwargs`
+**kwargs
|
Any optional arguments to MultiHeadAttentionConv, see there.
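
For illustration, a minimal usage sketch (the import aliases, hyperparameter values, and the `graph` tensor below are assumptions, not part of this API page):

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import multi_head_attention

# Placeholder hyperparameters; the homogeneous graph update attends over the
# single edge set of the graph and writes new tfgnn.HIDDEN_STATE node states.
layer = multi_head_attention.MultiHeadAttentionHomGraphUpdate(
    num_heads=4,
    per_head_channels=32,
    receiver_tag=tfgnn.TARGET,
)
# graph = layer(graph)  # `graph`: an assumed scalar GraphTensor with one
#                       # node set and one edge set.
```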
diff --git a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionMPNNGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionMPNNGraphUpdate.md
index d6eb153e..9c496104 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionMPNNGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/MultiHeadAttentionMPNNGraphUpdate.md
@@ -1,17 +1,10 @@
# multi_head_attention.MultiHeadAttentionMPNNGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a GraphUpdate layer for message passing with MultiHeadAttention pooling.
@@ -44,45 +37,50 @@ layer to compute the new node states from a concatenation of the old node state
and all pooled messages, analogous to TF-GNN's
`vanilla_mpnn.VanillaMPNNGraphUpdate` and `gat_v2.GATv2MPNNGraphUpdate`.
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
+
Args |
-`units`
+units
|
The dimension of output hidden states for each node.
|
-`message_dim`
+message_dim
|
The dimension of messages (attention values) computed on each
-edge. Must be divisible by `num_heads`.
+edge. Must be divisible by num_heads .
|
-`num_heads`
+num_heads
|
The number of attention heads used by MultiHeadAttention.
-`message_dim` must be divisible by this number.
+message_dim must be divisible by this number.
|
-`receiver_tag`
+receiver_tag
|
-one of `tfgnn.TARGET` or `tfgnn.SOURCE`, to select the
+one of tfgnn.TARGET or tfgnn.SOURCE , to select the
incident node of each edge that receives the message.
|
-`node_set_names`
+node_set_names
|
The names of node sets to update. If unset, updates all that
@@ -90,16 +88,16 @@ are on the receiving end of any edge set.
|
-`edge_feature`
+edge_feature
|
Can be set to a feature name of the edge set to select it as
-an input feature. By default, this set to `None`, which disables this
+an input feature. By default, this is set to None , which disables this
input.
|
-`l2_regularization`
+l2_regularization
|
The coefficient of L2 regularization for weights and
@@ -107,7 +105,7 @@ biases.
|
-`edge_dropout_rate`
+edge_dropout_rate
|
The edge dropout rate applied during attention pooling of
@@ -115,24 +113,24 @@ edges.
|
-`state_dropout_rate`
+state_dropout_rate
|
The dropout rate applied to the resulting node states.
|
-`attention_activation`
+attention_activation
|
The nonlinearity used on the transformed inputs before
multiplying with the trained weights of the attention layer. This can be
specified as a Keras layer, a tf.keras.activations.* function, or a string
-understood by `tf.keras.layers.Activation`. Defaults to None.
+understood by tf.keras.layers.Activation . Defaults to None.
|
-`conv_activation`
+conv_activation
|
The nonlinearity applied to the result of attention on one
@@ -140,7 +138,7 @@ edge set, specified in the same ways as attention_activation.
|
-`activation`
+activation
|
The nonlinearity applied to the new node states computed by this
@@ -148,25 +146,26 @@ graph update.
|
-`kernel_initializer`
+kernel_initializer
|
-Can be set to a `kernel_initializer` as understood
-by `tf.keras.layers.Dense` etc.
-An `Initializer` object gets cloned before use to ensure a fresh seed,
-if not set explicitly. For more, see `tfgnn.keras.clone_initializer()`.
+Can be set to a kernel_initializer as understood
+by tf.keras.layers.Dense etc.
+An Initializer object gets cloned before use to ensure a fresh seed,
+if not set explicitly. For more, see tfgnn.keras.clone_initializer() .
|
+
Returns |
A GraphUpdate layer for use on a scalar GraphTensor with
-`tfgnn.HIDDEN_STATE` features on the node sets.
+tfgnn.HIDDEN_STATE features on the node sets.
|
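
A minimal usage sketch, assuming the usual import aliases and placeholder hyperparameters (`message_dim` must be divisible by `num_heads`):

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import multi_head_attention

layer = multi_head_attention.MultiHeadAttentionMPNNGraphUpdate(
    units=128,       # output size of the new node states (placeholder value)
    message_dim=64,  # divisible by num_heads
    num_heads=4,
    receiver_tag=tfgnn.TARGET,
)
# graph = layer(graph)  # `graph`: an assumed scalar GraphTensor with
#                       # tfgnn.HIDDEN_STATE features on its node sets.
```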
diff --git a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_from_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_from_config_dict.md
index f6bbf78e..efac65b5 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_from_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_from_config_dict.md
@@ -1,72 +1,77 @@
# multi_head_attention.graph_update_from_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a MultiHeadAttentionMPNNGraphUpdate initialized from `cfg`.
multi_head_attention.graph_update_from_config_dict(
- cfg: config_dict.ConfigDict
+ cfg: config_dict.ConfigDict,
+ *,
+ node_set_names: Optional[Collection[tfgnn.NodeSetName]] = None
) -> tf.keras.layers.Layer
-
+
Args |
-`cfg`
+cfg
|
-A `ConfigDict` with the fields defined by
-`graph_update_get_config_dict()`. All fields with non-`None` values are
+A ConfigDict with the fields defined by
+graph_update_get_config_dict() . All fields with non-None values are
used as keyword arguments for initializing and returning a
-`MultiHeadAttentionMPNNGraphUpdate` object. For the required arguments of
-`MultiHeadAttentionMPNNGraphUpdate.__init__`, users must set a value in
-`cfg` before passing it here.
+MultiHeadAttentionMPNNGraphUpdate object. For the required arguments of
+MultiHeadAttentionMPNNGraphUpdate.__init__ , users must set a value in
+cfg before passing it here.
+ |
+
+
+node_set_names
+ |
+
+Optionally, the names of NodeSets to update; forwarded to
+MultiHeadAttentionMPNNGraphUpdate.__init__ .
|
+
Returns |
-A new `MultiHeadAttentionMPNNGraphUpdate` object.
+A new MultiHeadAttentionMPNNGraphUpdate object.
|
+
Raises |
-`TypeError`
+TypeError
|
-if `cfg` fails to supply a required argument for
-`MultiHeadAttentionMPNNGraphUpdate.__init__`.
+if cfg fails to supply a required argument for
+MultiHeadAttentionMPNNGraphUpdate.__init__ .
|
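
A sketch of the intended round trip, assuming `units`, `message_dim`, and `num_heads` are among the required fields that must be filled in before building (any further required fields are set the same way; the node set name is a placeholder):

```python
from tensorflow_gnn.models import multi_head_attention

cfg = multi_head_attention.graph_update_get_config_dict()
cfg.units = 128        # placeholder values for required fields
cfg.message_dim = 64
cfg.num_heads = 4
layer = multi_head_attention.graph_update_from_config_dict(
    cfg, node_set_names=["papers"])  # "papers" is a hypothetical node set
```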
diff --git a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_get_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_get_config_dict.md
index 017fba0a..51a8aee9 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_get_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/multi_head_attention/graph_update_get_config_dict.md
@@ -1,17 +1,10 @@
# multi_head_attention.graph_update_get_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Returns ConfigDict for graph_update_from_config_dict() with defaults.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn.md b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn.md
index 755c6346..148b9f45 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn.md
@@ -1,17 +1,10 @@
# Module: vanilla_mpnn
-[TOC]
-
-
+
+ View source
+on GitHub
TF-GNN's "Vanilla MPNN" model.
diff --git a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/VanillaMPNNGraphUpdate.md b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/VanillaMPNNGraphUpdate.md
index 00b4f949..9159630a 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/VanillaMPNNGraphUpdate.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/VanillaMPNNGraphUpdate.md
@@ -1,17 +1,10 @@
# vanilla_mpnn.VanillaMPNNGraphUpdate
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a GraphUpdate layer for a Vanilla MPNN.
@@ -55,36 +48,41 @@ and the pooled messages from all incident node sets E_1, E_2, ...:
$$h_v := \text{ReLU}(
W_{\text{state}} (h_v || m_{E_1} || m_{E_2} || \ldots)).$$
+The layer returned by this function can be restored from config by
+`tf.keras.models.load_model()` when saved as part of a Keras model using
+`save_format="tf"`.
+
+
Args |
-`units`
+units
|
The dimension of output hidden states for each node.
|
-`message_dim`
+message_dim
|
The dimension of messages computed on each edge.
|
-`receiver_tag`
+receiver_tag
|
-one of `tfgnn.TARGET` or `tfgnn.SOURCE`, to select the
+one of tfgnn.TARGET or tfgnn.SOURCE , to select the
incident node of each edge that receives the message.
|
-`node_set_names`
+node_set_names
|
The names of node sets to update. If unset, updates all
@@ -92,26 +90,26 @@ that are on the receiving end of any edge set.
|
-`edge_feature`
+edge_feature
|
Can be set to a feature name of the edge set to select
-it as an input feature. By default, this set to `None`, which disables
+it as an input feature. By default, this is set to None , which disables
this input.
|
-`reduce_type`
+reduce_type
|
How to pool the messages from edges to receiver nodes; defaults
-to `"sum"`. Can be any reduce_type understood by `tfgnn.pool()`, including
-concatenations like `"sum|max"` (but mind the increased dimension of the
+to "sum" . Can be any reduce_type understood by tfgnn.pool() , including
+concatenations like "sum|max" (but mind the increased dimension of the
result and the growing number of model weights in the next-state layer).
|
-`l2_regularization`
+l2_regularization
|
The coefficient of L2 regularization for weights and
@@ -119,7 +117,7 @@ biases.
|
-`dropout_rate`
+dropout_rate
|
The dropout rate applied to messages on each edge and to the
@@ -127,17 +125,17 @@ new node state.
|
-`kernel_initializer`
+kernel_initializer
|
-Can be set to a `kernel_initializer` as understood
-by `tf.keras.layers.Dense` etc.
-An `Initializer` object gets cloned before use to ensure a fresh seed,
-if not set explicitly. For more, see `tfgnn.keras.clone_initializer()`.
+Can be set to a kernel_initializer as understood
+by tf.keras.layers.Dense etc.
+An Initializer object gets cloned before use to ensure a fresh seed,
+if not set explicitly. For more, see tfgnn.keras.clone_initializer() .
|
-`use_layer_normalization`
+use_layer_normalization
|
Flag to determine whether to apply layer
@@ -147,13 +145,14 @@ normalization to the new node state.
|
+
Returns |
A GraphUpdate layer for use on a scalar GraphTensor with
-`tfgnn.HIDDEN_STATE` features on the node sets.
+tfgnn.HIDDEN_STATE features on the node sets.
|
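
A minimal usage sketch with placeholder hyperparameters (the import aliases and the `graph` tensor are assumptions):

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import vanilla_mpnn

layer = vanilla_mpnn.VanillaMPNNGraphUpdate(
    units=128,
    message_dim=64,
    receiver_tag=tfgnn.TARGET,
    reduce_type="sum",            # or a concatenation like "sum|max", see above
    use_layer_normalization=True,
)
# graph = layer(graph)  # `graph`: an assumed scalar GraphTensor with
#                       # tfgnn.HIDDEN_STATE features on its node sets.
```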
diff --git a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_from_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_from_config_dict.md
index 539deb0f..c08e7a78 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_from_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_from_config_dict.md
@@ -1,72 +1,77 @@
# vanilla_mpnn.graph_update_from_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Returns a VanillaMPNNGraphUpdate initialized from `cfg`.
vanilla_mpnn.graph_update_from_config_dict(
- cfg: config_dict.ConfigDict
+ cfg: config_dict.ConfigDict,
+ *,
+ node_set_names: Optional[Collection[tfgnn.NodeSetName]] = None
) -> tf.keras.layers.Layer
-
+
Args |
-`cfg`
+cfg
|
-A `ConfigDict` with the fields defined by
-`graph_update_get_config_dict()`. All fields with non-`None` values are
+A ConfigDict with the fields defined by
+graph_update_get_config_dict() . All fields with non-None values are
used as keyword arguments for initializing and returning a
-`VanillaMPNNGraphUpdate` object. For the required arguments of
-`VanillaMPNNGraphUpdate.__init__`, users must set a value in `cfg` before
+VanillaMPNNGraphUpdate object. For the required arguments of
+VanillaMPNNGraphUpdate.__init__ , users must set a value in cfg before
passing it here.
|
+
+
+node_set_names
+ |
+
+Optionally, the names of NodeSets to update; forwarded to
+VanillaMPNNGraphUpdate.__init__ .
+ |
+
Returns |
-A new `VanillaMPNNGraphUpdate` object.
+A new VanillaMPNNGraphUpdate object.
|
+
Raises |
-`TypeError`
+TypeError
|
-if `cfg` fails to supply a required argument for
-`VanillaMPNNGraphUpdate.__init__`.
+if cfg fails to supply a required argument for
+VanillaMPNNGraphUpdate.__init__ .
|
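
A sketch of the intended use, assuming `units` and `message_dim` are among the required fields of the ConfigDict (any further required fields are set the same way; the values are placeholders):

```python
from tensorflow_gnn.models import vanilla_mpnn

cfg = vanilla_mpnn.graph_update_get_config_dict()
cfg.units = 128       # placeholder values for required fields
cfg.message_dim = 64
layer = vanilla_mpnn.graph_update_from_config_dict(cfg)
```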
diff --git a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_get_config_dict.md b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_get_config_dict.md
index 82ee8053..879fba16 100644
--- a/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_get_config_dict.md
+++ b/tensorflow_gnn/docs/api_docs/python/models/vanilla_mpnn/graph_update_get_config_dict.md
@@ -1,17 +1,10 @@
# vanilla_mpnn.graph_update_get_config_dict
-[TOC]
-
-
+
+ View source
+on GitHub
Returns ConfigDict for graph_update_from_config_dict() with defaults.
diff --git a/tensorflow_gnn/docs/api_docs/python/runner.md b/tensorflow_gnn/docs/api_docs/python/runner.md
new file mode 100644
index 00000000..e8a39924
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner.md
@@ -0,0 +1,173 @@
+# Module: runner
+
+
+
+
+ View source
+on GitHub
+
+A general purpose runner for TF-GNN.
+
+## Classes
+
+[`class ContextLabelFn`](./runner/ContextLabelFn.md): Reads out a `tfgnn.Field`
+from the `GraphTensor` context.
+
+[`class DatasetProvider`](./runner/DatasetProvider.md): Helper class that
+provides a standard way to create an ABC using inheritance.
+
+[`class DotProductLinkPrediction`](./runner/DotProductLinkPrediction.md):
+Implements edge score as dot product of features of endpoint nodes.
+
+[`class FitOrSkipPadding`](./runner/FitOrSkipPadding.md): Calculates fit or skip
+`SizeConstraints` for `GraphTensor` padding.
+
+[`class GraphBinaryClassification`](./runner/GraphBinaryClassification.md):
+Graph binary (or multi-label) classification from pooled node states.
+
+[`class GraphMeanAbsoluteError`](./runner/GraphMeanAbsoluteError.md): Mean
+absolute error task.
+
+[`class GraphMeanAbsolutePercentageError`](./runner/GraphMeanAbsolutePercentageError.md):
+Mean absolute percentage error task.
+
+[`class GraphMeanSquaredError`](./runner/GraphMeanSquaredError.md): Mean squared
+error task.
+
+[`class GraphMeanSquaredLogScaledError`](./runner/GraphMeanSquaredLogScaledError.md):
+Mean squared log scaled error task.
+
+[`class GraphMeanSquaredLogarithmicError`](./runner/GraphMeanSquaredLogarithmicError.md):
+Mean squared logarithmic error task.
+
+[`class GraphMulticlassClassification`](./runner/GraphMulticlassClassification.md):
+Graph multiclass classification from pooled node states.
+
+[`class GraphTensorPadding`](./runner/GraphTensorPadding.md): Collects
+`GraphTensor` padding helpers.
+
+[`class GraphTensorProcessorFn`](./runner/GraphTensorProcessorFn.md): A class
+for `GraphTensor` processing.
+
+[`class HadamardProductLinkPrediction`](./runner/HadamardProductLinkPrediction.md):
+Implements edge score as Hadamard product of features of endpoint nodes.
+
+[`class IntegratedGradientsExporter`](./runner/IntegratedGradientsExporter.md):
+Exports a Keras model with an additional integrated gradients signature.
+
+[`class KerasModelExporter`](./runner/KerasModelExporter.md): Exports a Keras
+model (with Keras API) via `tf.keras.models.save_model`.
+
+[`class KerasTrainer`](./runner/KerasTrainer.md): Trains using the
+`tf.keras.Model.fit` training loop.
+
+[`class KerasTrainerCheckpointOptions`](./runner/KerasTrainerCheckpointOptions.md):
+Provides Keras Checkpointing related configuration options.
+
+[`class KerasTrainerOptions`](./runner/KerasTrainerOptions.md): Provides Keras
+training related options.
+
+[`class ModelExporter`](./runner/ModelExporter.md): Saves a Keras model.
+
+[`class NodeBinaryClassification`](./runner/NodeBinaryClassification.md): Node
+binary (or multi-label) classification via structured readout.
+
+[`class NodeMulticlassClassification`](./runner/NodeMulticlassClassification.md):
+Node multiclass classification via structured readout.
+
+[`class ParameterServerStrategy`](./runner/ParameterServerStrategy.md): A
+`ParameterServerStrategy` convenience wrapper.
+
+[`class PassthruDatasetProvider`](./runner/PassthruDatasetProvider.md): Builds a
+`tf.data.Dataset` from a pass thru dataset.
+
+[`class PassthruSampleDatasetsProvider`](./runner/PassthruSampleDatasetsProvider.md):
+Builds a sampled `tf.data.Dataset` from multiple pass thru datasets.
+
+[`class RootNodeBinaryClassification`](./runner/RootNodeBinaryClassification.md):
+Root node binary (or multi-label) classification.
+
+[`class RootNodeLabelFn`](./runner/RootNodeLabelFn.md): Reads out a
+`tfgnn.Field` from the `GraphTensor` root (i.e. first) node.
+
+[`class RootNodeMeanAbsoluteError`](./runner/RootNodeMeanAbsoluteError.md): Mean
+absolute error task.
+
+[`class RootNodeMeanAbsoluteLogarithmicError`](./runner/RootNodeMeanAbsoluteLogarithmicError.md):
+Root node mean absolute logarithmic error task.
+
+[`class RootNodeMeanAbsolutePercentageError`](./runner/RootNodeMeanAbsolutePercentageError.md):
+Mean absolute percentage error task.
+
+[`class RootNodeMeanSquaredError`](./runner/RootNodeMeanSquaredError.md): Mean
+squared error task.
+
+[`class RootNodeMeanSquaredLogScaledError`](./runner/RootNodeMeanSquaredLogScaledError.md):
+Mean squared log scaled error task.
+
+[`class RootNodeMeanSquaredLogarithmicError`](./runner/RootNodeMeanSquaredLogarithmicError.md):
+Mean squared logarithmic error task.
+
+[`class RootNodeMulticlassClassification`](./runner/RootNodeMulticlassClassification.md):
+Root node multiclass classification.
+
+[`class RunResult`](./runner/RunResult.md): Holds the return values of
+`run(...)`.
+
+[`class SampleTFRecordDatasetsProvider`](./runner/SampleTFRecordDatasetsProvider.md):
+Builds a sampling `tf.data.Dataset` from multiple filenames.
+
+[`class SimpleDatasetProvider`](./runner/SimpleDatasetProvider.md): Builds a
+`tf.data.Dataset` from a list of files.
+
+[`class SimpleSampleDatasetsProvider`](./runner/SimpleSampleDatasetsProvider.md):
+Builds a sampling `tf.data.Dataset` from multiple filenames.
+
+[`class SubmoduleExporter`](./runner/SubmoduleExporter.md): Exports a Keras
+submodule.
+
+[`class TFDataServiceConfig`](./runner/TFDataServiceConfig.md): Provides tf.data
+service related configuration options.
+
+[`class TFRecordDatasetProvider`](./runner/TFRecordDatasetProvider.md): Builds a
+`tf.data.Dataset` from a list of files.
+
+[`class TPUStrategy`](./runner/TPUStrategy.md): A `TPUStrategy` convenience
+wrapper.
+
+[`class Task`](./runner/Task.md): Defines a learning objective for a GNN.
+
+[`class TightPadding`](./runner/TightPadding.md): Calculates tight
+`SizeConstraints` for `GraphTensor` padding.
+
+[`class Trainer`](./runner/Trainer.md): A class for training and validation of a
+Keras model.
+
+## Functions
+
+[`export_model(...)`](./runner/export_model.md): Exports a Keras model without
+traces such that it is loadable without TF-GNN.
+
+[`incrementing_model_dir(...)`](./runner/incrementing_model_dir.md): Create,
+given some `dirname`, an incrementing model directory.
+
+[`integrated_gradients(...)`](./runner/integrated_gradients.md): Integrated
+gradients.
+
+[`one_node_per_component(...)`](./runner/one_node_per_component.md): Returns a
+`Mapping` `node_set_name: 1` for every node set in `gtspec`.
+
+[`run(...)`](./runner/run.md): Runs training (and validation) of a model on
+task(s) with the given data.
+
+## Type Aliases
+
+[`Loss`](./runner/Loss.md)
+
+[`Losses`](./runner/Losses.md)
+
+[`Metric`](./runner/Loss.md)
+
+[`Metrics`](./runner/Metrics.md)
+
+[`Predictions`](./runner/Predictions.md)
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/ContextLabelFn.md b/tensorflow_gnn/docs/api_docs/python/runner/ContextLabelFn.md
new file mode 100644
index 00000000..b5145681
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/ContextLabelFn.md
@@ -0,0 +1,17 @@
+# runner.ContextLabelFn
+
+
+
+
+ View source
+on GitHub
+
+Reads out a `tfgnn.Field` from the `GraphTensor` context.
+
+
+runner.ContextLabelFn(
+ feature_name: str, **kwargs
+)
+
+
+
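
A usage sketch, assuming graphs carry their training label as a context feature named `"label"` and a task that accepts a `label_fn` (the feature and node set names are placeholders):

```python
from tensorflow_gnn import runner

label_fn = runner.ContextLabelFn("label")
task = runner.GraphBinaryClassification("atoms", label_fn=label_fn)
```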
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/DatasetProvider.md b/tensorflow_gnn/docs/api_docs/python/runner/DatasetProvider.md
new file mode 100644
index 00000000..e7f9b971
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/DatasetProvider.md
@@ -0,0 +1,27 @@
+# runner.DatasetProvider
+
+
+
+
+ View source
+on GitHub
+
+Helper class that provides a standard way to create an ABC using inheritance.
+
+
+
+## Methods
+
+get_dataset
+
+View
+source
+
+
+@abc.abstractmethod
+get_dataset(
+ context: tf.distribute.InputContext
+) -> tf.data.Dataset
+
+
+Get a `tf.data.Dataset` by `context` per replica.
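
Concrete subclasses implement `get_dataset(...)`; a minimal sketch of a hypothetical provider that shards a file list across input pipelines:

```python
import tensorflow as tf
from tensorflow_gnn import runner


class GlobDatasetProvider(runner.DatasetProvider):
  """Hypothetical provider reading TFRecord files matching a glob pattern."""

  def __init__(self, file_pattern: str):
    self._file_pattern = file_pattern

  def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
    files = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)
    # Shard by input pipeline so each replica reads a distinct slice of files.
    files = files.shard(context.num_input_pipelines, context.input_pipeline_id)
    return files.interleave(tf.data.TFRecordDataset)
```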
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/DotProductLinkPrediction.md b/tensorflow_gnn/docs/api_docs/python/runner/DotProductLinkPrediction.md
new file mode 100644
index 00000000..dd562c93
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/DotProductLinkPrediction.md
@@ -0,0 +1,200 @@
+# runner.DotProductLinkPrediction
+
+
+
+
+ View source
+on GitHub
+
+Implements edge score as dot product of features of endpoint nodes.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.DotProductLinkPrediction(
+ *,
+ node_feature_name: tfgnn.FieldName = tfgnn.HIDDEN_STATE,
+ readout_label_feature_name: str = 'label',
+ readout_node_set_name: tfgnn.NodeSetName = '_readout'
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+node_feature_name
+ |
+
+Name of feature where node state for link-prediction
+is read from. The final link prediction score will be:
+`score(graph.node_sets[source][node_feature_name],
+ graph.node_sets[target][node_feature_name])`
+where source and target , respectively, are:
+graph.edge_sets[readout_node_set_name+"/source"].adjacency.source_name
+and
+graph.edge_sets[readout_node_set_name+"/target"].adjacency.source_name
+ |
+
+
+readout_label_feature_name
+ |
+
+The labels for edge connections,
+source nodes
+graph.edge_sets[readout_node_set_name+"/source"].adjacency.source in
+node set graph.node_sets[source] against target nodes
+graph.edge_sets[readout_node_set_name+"/target"].adjacency.source in
+node set graph.node_sets[source] , must be stored in
+graph.node_sets[readout_node_set_name][readout_label_feature_name] .
+ |
+
+
+readout_node_set_name
+ |
+
+Determines the readout node-set, which must have
+feature readout_label_feature_name , and must receive connections (at
+target endpoints) from edge-sets readout_node_set_name+"/source" and
+readout_node_set_name+"/target" .
+ |
+
+
+
+## Methods
+
+losses
+
+View
+source
+
+
+losses() -> runner.Losses
+
+
+Binary cross-entropy.
+
+metrics
+
+View
+source
+
+
+metrics() -> runner.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ graph: tfgnn.GraphTensor
+) -> runner.Predictions
+
+
+Produces prediction outputs for the learning objective.
+
+Overall model composition* makes use of the Keras Functional API
+(https://www.tensorflow.org/guide/keras/functional) to map symbolic Keras
+`GraphTensor` inputs to symbolic Keras `Field` outputs. Outputs must match the
+structure (one or mapping) of labels from `preprocess`.
+
+*) `outputs = predict(GNN(inputs))` where `inputs` are those `GraphTensor`
+returned by `preprocess(...)`, `GNN` is the base GNN, `predict` is this method
+and `outputs` are the prediction outputs for the learning objective.
+
+
+
+
+
+Args |
+
+
+
+*args
+ |
+
+The symbolic Keras GraphTensor input(s). These inputs correspond
+(in sequence) to the base GNN output of each GraphTensor returned by
+preprocess(...) .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The model's prediction output for this task.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ gt: tfgnn.GraphTensor
+) -> Tuple[tfgnn.GraphTensor, tfgnn.Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
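
A construction sketch; the values shown are simply the documented defaults, repeated here for visibility:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn import runner

task = runner.DotProductLinkPrediction(
    node_feature_name=tfgnn.HIDDEN_STATE,
    readout_label_feature_name="label",
    readout_node_set_name="_readout",
)
```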
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/FitOrSkipPadding.md b/tensorflow_gnn/docs/api_docs/python/runner/FitOrSkipPadding.md
new file mode 100644
index 00000000..3e349260
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/FitOrSkipPadding.md
@@ -0,0 +1,49 @@
+# runner.FitOrSkipPadding
+
+
+
+
+ View source
+on GitHub
+
+Calculates fit or skip `SizeConstraints` for `GraphTensor` padding.
+
+Inherits From: [`GraphTensorPadding`](../runner/GraphTensorPadding.md)
+
+
+runner.FitOrSkipPadding(
+ gtspec: tfgnn.GraphTensorSpec,
+ dataset_provider: runner.DatasetProvider ,
+ min_nodes_per_component: Optional[Mapping[str, int]] = None,
+ fit_or_skip_sample_sample_size: int = 10000,
+ fit_or_skip_success_ratio: float = 0.99
+)
+
+
+
+
+See: `tfgnn.learn_fit_or_skip_size_constraints`.
+
+## Methods
+
+get_filter_fn
+
+View
+source
+
+
+get_filter_fn(
+ size_constraints: SizeConstraints
+) -> Callable[..., bool]
+
+
+get_size_constraints
+
+View
+source
+
+
+get_size_constraints(
+ target_batch_size: int
+) -> SizeConstraints
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphBinaryClassification.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphBinaryClassification.md
new file mode 100644
index 00000000..237ce1c9
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphBinaryClassification.md
@@ -0,0 +1,222 @@
+# runner.GraphBinaryClassification
+
+
+
+
+ View source
+on GitHub
+
+Graph binary (or multi-label) classification from pooled node states.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.GraphBinaryClassification(
+ node_set_name: str,
+ units: int = 1,
+ *,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ reduce_type: str = 'mean',
+ name: str = 'classification_logits',
+ label_fn: Optional[LabelFn] = None,
+ label_feature_name: Optional[str] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+node_set_name
+ |
+
+The node set to pool.
+ |
+
+
+units
+ |
+
+The units for the classification head. (Typically 1 for binary
+classification and the number of labels for multi-label classification.)
+ |
+
+
+state_name
+ |
+
+The feature name for activations (e.g.: tfgnn.HIDDEN_STATE).
+ |
+
+
+reduce_type
+ |
+
+The context pooling reduction type.
+ |
+
+
+name
+ |
+
+The classification head's layer name. To control the naming of saved
+model outputs see the runner model exporters (e.g.,
+KerasModelExporter ).
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> Field
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+Returns arbitrary task specific losses.
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for classification.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for classification.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The classification logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
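
A construction sketch with placeholder names (the node set `"atoms"` and label feature `"label"` are assumptions):

```python
from tensorflow_gnn import runner

task = runner.GraphBinaryClassification(
    "atoms",                     # node set to pool
    units=1,                     # one logit for plain binary classification
    reduce_type="mean",
    label_feature_name="label",  # read out from the auxiliary "_readout" node set
)
```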
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanAbsoluteError.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanAbsoluteError.md
new file mode 100644
index 00000000..0894d92a
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanAbsoluteError.md
@@ -0,0 +1,195 @@
+# runner.GraphMeanAbsoluteError
+
+
+
+
+ View source
+on GitHub
+
+Mean absolute error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.GraphMeanAbsoluteError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ reduce_type: str = 'mean',
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanAbsolutePercentageError.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanAbsolutePercentageError.md
new file mode 100644
index 00000000..4794f58e
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanAbsolutePercentageError.md
@@ -0,0 +1,195 @@
+# runner.GraphMeanAbsolutePercentageError
+
+
+
+
+ View source
+on GitHub
+
+Mean absolute percentage error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.GraphMeanAbsolutePercentageError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ reduce_type: str = 'mean',
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredError.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredError.md
new file mode 100644
index 00000000..fc7f9332
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredError.md
@@ -0,0 +1,195 @@
+# runner.GraphMeanSquaredError
+
+
+
+
+ View source
+on GitHub
+
+Mean squared error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.GraphMeanSquaredError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ reduce_type: str = 'mean',
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
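
A construction sketch with placeholder names (the node set `"atoms"` and label feature `"target"` are assumptions; `label_feature_name` is assumed to be forwarded through `**kwargs` as described in the argument table above):

```python
from tensorflow_gnn import runner

task = runner.GraphMeanSquaredError(
    "atoms",                      # node set to pool
    units=1,
    reduce_type="mean",
    label_feature_name="target",  # see the argument description above
)
```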
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredLogScaledError.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredLogScaledError.md
new file mode 100644
index 00000000..00e3fe85
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredLogScaledError.md
@@ -0,0 +1,155 @@
+# runner.GraphMeanSquaredLogScaledError
+
+
+
+
+ View source
+on GitHub
+
+Mean squared log scaled error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.GraphMeanSquaredLogScaledError(
+ *args,
+ alpha_loss_param: float = 5.0,
+ epsilon_loss_param: float = 1e-08,
+ reduction: tf.keras.losses.Reduction = AUTO,
+ name: Optional[str] = None,
+ **kwargs
+)
+
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredLogarithmicError.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredLogarithmicError.md
new file mode 100644
index 00000000..3031008c
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphMeanSquaredLogarithmicError.md
@@ -0,0 +1,195 @@
+# runner.GraphMeanSquaredLogarithmicError
+
+
+
+
+ View source
+on GitHub
+
+Mean squared logarithmic error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.GraphMeanSquaredLogarithmicError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ reduce_type: str = 'mean',
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphMulticlassClassification.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphMulticlassClassification.md
new file mode 100644
index 00000000..e667763c
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphMulticlassClassification.md
@@ -0,0 +1,239 @@
+# runner.GraphMulticlassClassification
+
+
+
+
+ View source
+on GitHub
+
+Graph multiclass classification from pooled node states.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.GraphMulticlassClassification(
+ node_set_name: str,
+ *,
+ num_classes: Optional[int] = None,
+ class_names: Optional[Sequence[str]] = None,
+ per_class_statistics: bool = False,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ reduce_type: str = 'mean',
+ name: str = 'classification_logits',
+ label_fn: Optional[LabelFn] = None,
+ label_feature_name: Optional[str] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+node_set_name
+ |
+
+The node set to pool.
+ |
+
+
+num_classes
+ |
+
+The number of classes. Exactly one of num_classes or
+class_names must be specified.
+ |
+
+
+class_names
+ |
+
+The class names. Exactly one of num_classes or
+class_names must be specified.
+ |
+
+
+per_class_statistics
+ |
+
+Whether to compute statistics per class.
+ |
+
+
+state_name
+ |
+
+The feature name for activations (e.g.: tfgnn.HIDDEN_STATE).
+ |
+
+
+reduce_type
+ |
+
+The context pooling reduction type.
+ |
+
+
+name
+ |
+
+The classification head's layer name. To control the naming of saved
+model outputs see the runner model exporters (e.g.,
+KerasModelExporter ).
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> Field
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+Sparse categorical crossentropy loss.
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Sparse categorical metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for classification.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for classification.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The classification logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
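
A construction sketch with placeholder names; exactly one of `num_classes` or `class_names` is given:

```python
from tensorflow_gnn import runner

task = runner.GraphMulticlassClassification(
    "atoms",                        # node set to pool (assumed name)
    num_classes=10,
    per_class_statistics=True,
    label_feature_name="class_id",  # assumed label feature on "_readout"
)
```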
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorPadding.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorPadding.md
new file mode 100644
index 00000000..63ece9ef
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorPadding.md
@@ -0,0 +1,37 @@
+# runner.GraphTensorPadding
+
+
+
+
+ View source
+on GitHub
+
+Collects `GraphTensor` padding helpers.
+
+
+
+## Methods
+
+get_filter_fn
+
+View
+source
+
+
+@abc.abstractmethod
+get_filter_fn(
+ size_constraints: SizeConstraints
+) -> Callable[..., bool]
+
+
+get_size_constraints
+
+View
+source
+
+
+@abc.abstractmethod
+get_size_constraints(
+ target_batch_size: int
+) -> SizeConstraints
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorProcessorFn.md b/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorProcessorFn.md
new file mode 100644
index 00000000..55b2eb0e
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/GraphTensorProcessorFn.md
@@ -0,0 +1,26 @@
+# runner.GraphTensorProcessorFn
+
+
+
+
+ View source
+on GitHub
+
+A class for `GraphTensor` processing.
+
+
+
+## Methods
+
+__call__
+
+View
+source
+
+
+__call__(
+ inputs: GraphTensor
+) -> GraphTensor
+
+
+Processes a `GraphTensor`.
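
Any callable mapping a `GraphTensor` to a `GraphTensor` fits this interface; a hypothetical sketch that drops a context feature named `"id"`:

```python
import tensorflow_gnn as tfgnn


def drop_id_feature(graph: tfgnn.GraphTensor) -> tfgnn.GraphTensor:
  """Removes the hypothetical context feature "id", keeping everything else."""
  context_features = dict(graph.context.features)
  context_features.pop("id", None)
  return graph.replace_features(context=context_features)
```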
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/HadamardProductLinkPrediction.md b/tensorflow_gnn/docs/api_docs/python/runner/HadamardProductLinkPrediction.md
new file mode 100644
index 00000000..d66469ba
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/HadamardProductLinkPrediction.md
@@ -0,0 +1,203 @@
+# runner.HadamardProductLinkPrediction
+
+
+
+
+ View source
+on GitHub
+
+Implements edge score as Hadamard product of features of endpoint nodes.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.HadamardProductLinkPrediction(
+ *,
+ node_feature_name: tfgnn.FieldName = tfgnn.HIDDEN_STATE,
+ readout_label_feature_name: str = 'label',
+ readout_node_set_name: tfgnn.NodeSetName = '_readout'
+)
+
+
+
+
+The Hadamard product is followed by one layer with scalar output.
+
+
+
+
+
+Args |
+
+
+
+node_feature_name
+ |
+
+Name of feature where node state for link-prediction
+is read from. The final link prediction score will be:
+`score(graph.node_sets[source][node_feature_name],
+ graph.node_sets[target][node_feature_name])`
+where source and target , respectively, are:
+graph.edge_sets[readout_node_set_name+"/source"].adjacency.source_name
+and
+graph.edge_sets[readout_node_set_name+"/target"].adjacency.source_name
+ |
+
+
+readout_label_feature_name
+ |
+
+The labels for edge connections,
+source nodes
+graph.edge_sets[readout_node_set_name+"/source"].adjacency.source in
+node set graph.node_sets[source] against target nodes
+graph.edge_sets[readout_node_set_name+"/target"].adjacency.source in
+node set graph.node_sets[source] , must be stored in
+graph.node_sets[readout_node_set_name][readout_label_feature_name] .
+ |
+
+
+readout_node_set_name
+ |
+
+Determines the readout node-set, which must have
+feature readout_label_feature_name , and must receive connections (at
+target endpoints) from edge-sets readout_node_set_name+"/source" and
+readout_node_set_name+"/target" .
+ |
+
+
+
+## Methods
+
+losses
+
+View
+source
+
+
+losses() -> runner.Losses
+
+
+Binary cross-entropy.
+
+metrics
+
+View
+source
+
+
+metrics() -> runner.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ graph: tfgnn.GraphTensor
+) -> runner.Predictions
+
+
+Produces prediction outputs for the learning objective.
+
+Overall model composition* makes use of the Keras Functional API
+(https://www.tensorflow.org/guide/keras/functional) to map symbolic Keras
+`GraphTensor` inputs to symbolic Keras `Field` outputs. Outputs must match the
+structure (one or mapping) of labels from `preprocess`.
+
+*) `outputs = predict(GNN(inputs))` where `inputs` are those `GraphTensor`
+returned by `preprocess(...)`, `GNN` is the base GNN, `predict` is this method
+and `outputs` are the prediction outputs for the learning objective.
+
+
+
+
+
+Args |
+
+
+
+*args
+ |
+
+The symbolic Keras GraphTensor input(s). These inputs correspond
+(in sequence) to the base GNN output of each GraphTensor returned by
+preprocess(...) .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The model's prediction output for this task.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ gt: tfgnn.GraphTensor
+) -> Tuple[tfgnn.GraphTensor, tfgnn.Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/IntegratedGradientsExporter.md b/tensorflow_gnn/docs/api_docs/python/runner/IntegratedGradientsExporter.md
new file mode 100644
index 00000000..2aac8e65
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/IntegratedGradientsExporter.md
@@ -0,0 +1,127 @@
+# runner.IntegratedGradientsExporter
+
+
+
+
+ View source
+on GitHub
+
+Exports a Keras model with an additional integrated gradients signature.
+
+Inherits From: [`ModelExporter`](../runner/ModelExporter.md)
+
+
+runner.IntegratedGradientsExporter(
+ integrated_gradients_output_name: Optional[str] = None,
+ subdirectory: Optional[str] = None,
+ random_counterfactual: bool = True,
+ steps: int = 32,
+ seed: Optional[int] = None,
+ options: Optional[tf.saved_model.SaveOptions] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+integrated_gradients_output_name
+ |
+
+The name for the integrated gradients
+output tensor. If unset, the tensor will be named by Keras defaults.
+ |
+
+
+subdirectory
+ |
+
+An optional subdirectory, if set: models are exported to
+os.path.join(export_dir, subdirectory).
+ |
+
+
+random_counterfactual
+ |
+
+Whether to use a random uniform counterfactual.
+ |
+
+
+steps
+ |
+
+The number of interpolations of the Riemann sum approximation.
+ |
+
+
+seed
+ |
+
+An optional random seed.
+ |
+
+
+options
+ |
+
+Options for saving to SavedModel.
+ |
+
+
+
+## Methods
+
+save
+
+View
+source
+
+
+save(
+ run_result: runner.RunResult ,
+ export_dir: str
+)
+
+
+Exports a Keras model with an additional integrated gradients signature.
+
+Importantly: the `run_result.preprocess_model`, if provided, and
+`run_result.trained_model` are stacked before any export. Stacking involves the
+chaining of the first output of `run_result.preprocess_model` to the only input
+of `run_result.trained_model`. The result is a model with the input of
+`run_result.preprocess_model` and the output of `run_result.trained_model`.
+
+Two serving signatures are exported:
+
+1.  'serving_default': The default serving signature (i.e., the
+    `preprocess_model` input signature).
+2.  'integrated_gradients': The integrated gradients signature (i.e., the
+    `preprocess_model` input signature).
+
+
+
+
+
+Args |
+
+
+
+run_result
+ |
+
+A RunResult from training.
+ |
+
+
+export_dir
+ |
+
+A destination directory.
+ |
+
+
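
A construction sketch with placeholder values (the output name, subdirectory, and export path are assumptions):

```python
from tensorflow_gnn import runner

exporter = runner.IntegratedGradientsExporter(
    integrated_gradients_output_name="integrated_gradients",
    subdirectory="integrated_gradients",
    steps=32,
)
# exporter.save(run_result, "/tmp/export")  # `run_result` comes from training.
```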
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/KerasModelExporter.md b/tensorflow_gnn/docs/api_docs/python/runner/KerasModelExporter.md
new file mode 100644
index 00000000..4de9ca64
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/KerasModelExporter.md
@@ -0,0 +1,121 @@
+# runner.KerasModelExporter
+
+
+
+
+ View source
+on GitHub
+
+Exports a Keras model (with Keras API) via `tf.keras.models.save_model`.
+
+Inherits From: [`ModelExporter`](../runner/ModelExporter.md)
+
+
+runner.KerasModelExporter(
+ *,
+ output_names: Optional[Any] = None,
+ subdirectory: Optional[str] = None,
+ include_preprocessing: bool = True,
+ options: Optional[tf.saved_model.SaveOptions] = None,
+ use_legacy_model_save: Optional[bool] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+output_names
+ |
+
+By default, each output of the exported model uses the name
+of the final Keras layer that created it as its key in the SavedModel
+signature. This argument can be set to a single str name or a nested
+structure of str names to override the output names. Its nesting
+structure must match the exported model's output (as checked by
+tf.nest.assert_same_structure ). Any None values in output_names
+are ignored, leaving that output with its default name.
+ |
+
+
+subdirectory
+ |
+
+An optional subdirectory, if set: models are exported to
+os.path.join(export_dir, subdirectory).
+ |
+
+
+include_preprocessing
+ |
+
+Whether to include any preprocess_model.
+ |
+
+
+options
+ |
+
+Options for saving to a TensorFlow SavedModel .
+ |
+
+
+use_legacy_model_save
+ |
+
+Optional; most users can leave it unset to get a
+useful default for export to inference. See runner.export_model()
+for more.
+ |
+
+
+
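+A minimal sketch of exporting after training. The `run_result` object and the
+export path are assumptions; the arguments are those documented above.
+
+```python
+from tensorflow_gnn import runner
+
+# Assumed: `run_result` is the runner.RunResult of a completed training run.
+exporter = runner.KerasModelExporter(
+    output_names="logits",        # rename the model's single output
+    subdirectory="serving",       # export under <export_dir>/serving
+    include_preprocessing=True)
+exporter.save(run_result, "/tmp/export")
+```
+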
+## Methods
+
+save
+
+View
+source
+
+
+save(
+ run_result: runner.RunResult ,
+ export_dir: str
+)
+
+
+Exports a Keras model (with Keras API) via tf.keras.models.save_model.
+
+Importantly: the `run_result.preprocess_model`, if provided, and
+`run_result.trained_model` are stacked before any export. Stacking chains the
+first output of `run_result.preprocess_model` to the only input of
+`run_result.trained_model`. The result is a model with the input of
+`run_result.preprocess_model` and the output of `run_result.trained_model`.
+
+
+
+
+
+Args |
+
+
+
+run_result
+ |
+
+A RunResult from training.
+ |
+
+
+export_dir
+ |
+
+A destination directory.
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainer.md b/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainer.md
new file mode 100644
index 00000000..a8468368
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainer.md
@@ -0,0 +1,263 @@
+# runner.KerasTrainer
+
+
+
+
+ View source
+on GitHub
+
+Trains using the `tf.keras.Model.fit` training loop.
+
+Inherits From: [`Trainer`](../runner/Trainer.md)
+
+
+runner.KerasTrainer(
+ strategy: tf.distribute.Strategy,
+ *,
+ model_dir: str,
+ checkpoint_options: Optional[runner.KerasTrainerCheckpointOptions ] = None,
+ backup_dir: Optional[str] = None,
+ steps_per_epoch: Optional[int] = None,
+ verbose: Union[int, str] = 'auto',
+ validation_steps: Optional[int] = None,
+ validation_per_epoch: Optional[int] = None,
+ validation_freq: Optional[int] = None,
+ summarize_every_n_steps: Union[int, str] = 500,
+ checkpoint_every_n_steps: Union[int, str] = 'epoch',
+ backup_and_restore: bool = True,
+ callbacks: Optional[Sequence[tf.keras.callbacks.Callback]] = None,
+ restore_best_weights: Optional[bool] = None,
+ options: Optional[runner.KerasTrainerOptions ] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+strategy
+ |
+
+A tf.distribute.Strategy.
+ |
+
+
+model_dir
+ |
+
+A model directory for summaries.
+ |
+
+
+checkpoint_options
+ |
+
+An optional configuration for checkpointing. If
+checkpoint_options.checkpoint_dir is unset,
+os.path.join(model_dir, "ckpnt") is used.
+ |
+
+
+backup_dir
+ |
+
+An optional directory for backup; if unset,
+os.path.join(model_dir, "backup") is used.
+ |
+
+
+steps_per_epoch
+ |
+
+The number of training steps per epoch. Optional;
+if unspecified, an epoch ends when the training tf.data.Dataset is
+exhausted.
+ |
+
+
+verbose
+ |
+
+Forwarded to tf.keras.Model.fit() . Possible values are
+0 (silent), 1 (progress bar), 2 (one line per epoch), and
+"auto" (the default), which lets Keras select the verbosity.
+ |
+
+
+validation_steps
+ |
+
+The number of steps used during validation. Optional;
+if unspecified, the entire validation tf.data.Dataset is evaluated.
+ |
+
+
+validation_per_epoch
+ |
+
+The number of validations done per training epoch.
+Optional; if unspecified, one validation is performed per training epoch.
+Only one of validation_per_epoch and validation_freq can be
+specified.
+ |
+
+
+validation_freq
+ |
+
+Specifies how many training epochs to run before a new
+validation run is performed. Optional; if unspecified, validation is
+performed after every training epoch. Only one of
+validation_per_epoch and validation_freq can be specified.
+ |
+
+
+summarize_every_n_steps
+ |
+
+The frequency for writing TensorBoard summaries,
+as an integer number of steps, or "epoch" for once per epoch, or
+"never".
+ |
+
+
+checkpoint_every_n_steps
+ |
+
+The frequency for writing latest models, as an
+integer number of steps, or "epoch" for once per epoch, or "never".
+The best model will always be saved after each validation epoch except
+when this parameter is set to "never", because the validation metric is
+available only after a validation epoch.
+ |
+
+
+backup_and_restore
+ |
+
+Whether to back up and restore (using
+tf.keras.callbacks.BackupAndRestore ). The backup
+directory is determined by backup_dir .
+ |
+
+
+callbacks
+ |
+
+Optional additional tf.keras.callbacks.Callback for
+tf.keras.Model.fit.
+ |
+
+
+restore_best_weights
+ |
+
+Requires a checkpoint_every_n_steps other than
+"never." Whether to restore the best model weights as determined by
+tf.keras.callbacks.ModelCheckpoint after training. If unspecified,
+its value is determined at train(...) invocation: True if
+valid_ds_provider is not None else False .
+ |
+
+
+options
+ |
+
+A KerasTrainerOptions.
+ |
+
+
+
+
+
+
+
+Attributes |
+
+ model_dir |
+
+ | strategy |
+
+ |
+
+
+
+## Methods
+
+train
+
+View
+source
+
+
+train(
+ model_fn: Callable[[], tf.keras.Model],
+ train_ds_provider: runner.DatasetProvider ,
+ *,
+ epochs: int = 1,
+ valid_ds_provider: Optional[runner.DatasetProvider ] = None
+) -> tf.keras.Model
+
+
+Runs `tf.keras.Model.fit` with the`tf.distribute.Strategy` provided.
+
+
+
+
+
+Args |
+
+
+
+model_fn
+ |
+
+A ModelFn , to be invoked in the tf.distribute.Strategy
+scope.
+ |
+
+
+train_ds_provider
+ |
+
+A function that returns a tf.data.Dataset for
+training. The items of the tf.data.Dataset are pairs
+(graph_tensor, label) that represent one batch of per-replica training
+inputs after GraphTensor.merge_batch_to_components() has been applied.
+ |
+
+
+epochs
+ |
+
+The epochs to train: adjusted for validation_per_epoch.
+ |
+
+
+valid_ds_provider
+ |
+
+An optional function that returns a tf.data.Dataset
+for validation. The items of the tf.data.Dataset are pairs
+(graph_tensor, label) that represent one batch of per-replica validation
+inputs after GraphTensor.merge_batch_to_components() has been applied.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A trained tf.keras.Model.
+ |
+
+
+
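+For illustration, a minimal sketch of wiring the trainer into a training run.
+The `make_model` function and the dataset providers are assumptions; the
+constructor and `train(...)` arguments are those documented above.
+
+```python
+import tensorflow as tf
+from tensorflow_gnn import runner
+
+# Assumed to exist elsewhere: `make_model` (a ModelFn returning a
+# tf.keras.Model) and `train_ds_provider` / `valid_ds_provider`
+# (runner.DatasetProvider instances).
+trainer = runner.KerasTrainer(
+    tf.distribute.MirroredStrategy(),
+    model_dir="/tmp/model",
+    steps_per_epoch=1_000,
+    validation_steps=100,
+    summarize_every_n_steps=500,
+    checkpoint_every_n_steps="epoch")
+trained_model = trainer.train(
+    make_model,
+    train_ds_provider,
+    epochs=5,
+    valid_ds_provider=valid_ds_provider)
+```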
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainerCheckpointOptions.md b/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainerCheckpointOptions.md
new file mode 100644
index 00000000..b3fbff8d
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainerCheckpointOptions.md
@@ -0,0 +1,108 @@
+# runner.KerasTrainerCheckpointOptions
+
+
+
+
+ View source
+on GitHub
+
+Provides Keras Checkpointing related configuration options.
+
+
+runner.KerasTrainerCheckpointOptions(
+ checkpoint_dir: Optional[str] = None,
+ best_checkpoint: str = 'best',
+ latest_checkpoint: str = 'latest'
+)
+
+
+
+
+
+
+
+Attributes |
+
+
+
+checkpoint_dir
+ |
+
+Directory path to save checkpoint files.
+ |
+
+
+best_checkpoint
+ |
+
+Filename for the best checkpoint.
+ |
+
+
+latest_checkpoint
+ |
+
+Filename for the latest checkpoint.
+ |
+
+
+
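+A minimal sketch; the paths are assumptions, and the filepath helper methods
+presumably join `checkpoint_dir` with the corresponding filename.
+
+```python
+from tensorflow_gnn import runner
+
+options = runner.KerasTrainerCheckpointOptions(
+    checkpoint_dir="/tmp/model/checkpoints")
+# Presumably "/tmp/model/checkpoints/best" and "/tmp/model/checkpoints/latest".
+best_path = options.best_checkpoint_filepath()
+latest_path = options.latest_checkpoint_filepath()
+
+# Typically passed to runner.KerasTrainer via `checkpoint_options=options`.
+```
+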
+## Methods
+
+best_checkpoint_filepath
+
+View
+source
+
+
+best_checkpoint_filepath() -> str
+
+
+latest_checkpoint_filepath
+
+View
+source
+
+
+latest_checkpoint_filepath() -> str
+
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+Return self==value.
+
+
+
+
+
+Class Variables |
+
+
+
+best_checkpoint
+ |
+
+'best'
+ |
+
+
+checkpoint_dir
+ |
+
+None
+ |
+
+
+latest_checkpoint
+ |
+
+'latest'
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainerOptions.md b/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainerOptions.md
new file mode 100644
index 00000000..fd30bc44
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/KerasTrainerOptions.md
@@ -0,0 +1,90 @@
+# runner.KerasTrainerOptions
+
+
+
+
+ View source
+on GitHub
+
+Provides Keras training related options.
+
+
+runner.KerasTrainerOptions(
+ policy: Optional[Union[str, tf.keras.mixed_precision.Policy]] = None,
+ soft_device_placement: bool = False,
+ enable_check_numerics: bool = False
+)
+
+
+
+
+
+
+
+Attributes |
+
+
+
+policy
+ |
+
+Dataclass field
+ |
+
+
+soft_device_placement
+ |
+
+Dataclass field
+ |
+
+
+enable_check_numerics
+ |
+
+Dataclass field
+ |
+
+
+
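+A minimal sketch of enabling mixed precision and soft device placement for
+training; the values are illustrative.
+
+```python
+from tensorflow_gnn import runner
+
+options = runner.KerasTrainerOptions(
+    policy="mixed_float16",        # or a tf.keras.mixed_precision.Policy
+    soft_device_placement=True,
+    enable_check_numerics=False)
+# Typically passed to runner.KerasTrainer via `options=options`.
+```
+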
+## Methods
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+Return self==value.
+
+
+
+
+
+Class Variables |
+
+
+
+enable_check_numerics
+ |
+
+False
+ |
+
+
+policy
+ |
+
+None
+ |
+
+
+soft_device_placement
+ |
+
+False
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/Loss.md b/tensorflow_gnn/docs/api_docs/python/runner/Loss.md
new file mode 100644
index 00000000..fa45eb7e
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/Loss.md
@@ -0,0 +1,17 @@
+# runner.Loss
+
+
+
+This symbol is a **type alias**.
+
+#### Source:
+
+
+Loss = Callable[
+ tf.Tensor,
+ tf.Tensor,
+ tf.Tensor
+]
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/Losses.md b/tensorflow_gnn/docs/api_docs/python/runner/Losses.md
new file mode 100644
index 00000000..839e40c3
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/Losses.md
@@ -0,0 +1,16 @@
+# runner.Losses
+
+
+
+This symbol is a **type alias**.
+
+#### Source:
+
+
+Losses = Union[
+ runner.Loss ,
+ Mapping[str, runner.Loss ]
+]
+
+
+
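+For illustration, both accepted forms, assuming standard Keras losses and
+made-up output names:
+
+```python
+import tensorflow as tf
+
+# A single loss for a single-output task.
+losses = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+# A mapping from (hypothetical) output names to per-output losses.
+losses = {
+    "classification_logits": tf.keras.losses.BinaryCrossentropy(from_logits=True),
+    "regression_output": tf.keras.losses.MeanAbsoluteError(),
+}
+```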
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/Metrics.md b/tensorflow_gnn/docs/api_docs/python/runner/Metrics.md
new file mode 100644
index 00000000..5d281a66
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/Metrics.md
@@ -0,0 +1,17 @@
+# runner.Metrics
+
+
+
+This symbol is a **type alias**.
+
+#### Source:
+
+
+Metrics = Union[
+ runner.Loss ,
+ Sequence[runner.Loss ],
+ Mapping[str, Union[runner.Loss , Sequence[runner.Loss ]]]
+]
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/ModelExporter.md b/tensorflow_gnn/docs/api_docs/python/runner/ModelExporter.md
new file mode 100644
index 00000000..3fbc5b4c
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/ModelExporter.md
@@ -0,0 +1,53 @@
+# runner.ModelExporter
+
+
+
+
+ View source
+on GitHub
+
+Saves a Keras model.
+
+
+
+## Methods
+
+save
+
+View
+source
+
+
+@abc.abstractmethod
+save(
+ run_result: RunResult, export_dir: str
+)
+
+
+Saves a Keras model.
+
+All persistence decisions are left to the implementation: e.g., a Keras model
+with full API or a simple `tf.train.Checkpoint` may be saved.
+
+
+
+
+
+Args |
+
+
+
+run_result
+ |
+
+A RunResult from training.
+ |
+
+
+export_dir
+ |
+
+A destination directory.
+ |
+
+
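+For illustration, a minimal custom exporter that persists only a checkpoint; a
+sketch, not part of the library.
+
+```python
+import os
+import tensorflow as tf
+from tensorflow_gnn import runner
+
+class CheckpointExporter(runner.ModelExporter):
+  """Saves the trained model's weights as a plain tf.train.Checkpoint."""
+
+  def save(self, run_result: runner.RunResult, export_dir: str):
+    checkpoint = tf.train.Checkpoint(model=run_result.trained_model)
+    checkpoint.save(os.path.join(export_dir, "ckpt"))
+```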
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/NodeBinaryClassification.md b/tensorflow_gnn/docs/api_docs/python/runner/NodeBinaryClassification.md
new file mode 100644
index 00000000..2926e68c
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/NodeBinaryClassification.md
@@ -0,0 +1,237 @@
+# runner.NodeBinaryClassification
+
+
+
+
+ View source
+on GitHub
+
+Node binary (or multi-label) classification via structured readout.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.NodeBinaryClassification(
+ key: str = 'seed',
+ units: int = 1,
+ *,
+ feature_name: str = tfgnn.HIDDEN_STATE,
+ readout_node_set: tfgnn.NodeSetName = '_readout',
+ validate: bool = True,
+ name: str = 'classification_logits',
+ label_fn: Optional[LabelFn] = None,
+ label_feature_name: Optional[str] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+key
+ |
+
+A string key to select between possibly multiple named readouts.
+ |
+
+
+units
+ |
+
+The units for the classification head. (Typically 1 for binary
+classification and the number of labels for multi-label classification.)
+ |
+
+
+feature_name
+ |
+
+The name of the feature to read. If unset,
+tfgnn.HIDDEN_STATE will be read.
+ |
+
+
+readout_node_set
+ |
+
+A string, defaults to "_readout" . This is used as the
+name for the readout node set and as a name prefix for its edge sets.
+ |
+
+
+validate
+ |
+
+Setting this to false disables the validity checks for the
+auxiliary edge sets. This is strongly discouraged, unless great care is
+taken to run tfgnn.validate_graph_tensor_for_readout() earlier on
+structurally unchanged GraphTensors.
+ |
+
+
+name
+ |
+
+The classification head's layer name. To control the naming of saved
+model outputs see the runner model exporters (e.g.,
+KerasModelExporter ).
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
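+A minimal construction sketch; the readout key and label feature name are
+assumptions about the input data.
+
+```python
+from tensorflow_gnn import runner
+
+task = runner.NodeBinaryClassification(
+    key="seed",                  # which named readout to use
+    units=1,                     # binary classification
+    label_feature_name="label")  # read labels from the "_readout" node set
+```
+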
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> Field
+
+
+Gather activations from auxiliary node (and edge) sets.
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+Returns arbitrary task specific losses.
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for classification.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for classification.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The classification logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/NodeMulticlassClassification.md b/tensorflow_gnn/docs/api_docs/python/runner/NodeMulticlassClassification.md
new file mode 100644
index 00000000..308eea6c
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/NodeMulticlassClassification.md
@@ -0,0 +1,254 @@
+# runner.NodeMulticlassClassification
+
+
+
+
+ View source
+on GitHub
+
+Node multiclass classification via structured readout.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.NodeMulticlassClassification(
+ key: str = 'seed',
+ *,
+ feature_name: str = tfgnn.HIDDEN_STATE,
+ readout_node_set: tfgnn.NodeSetName = '_readout',
+ validate: bool = True,
+ num_classes: Optional[int] = None,
+ class_names: Optional[Sequence[str]] = None,
+ per_class_statistics: bool = False,
+ name: str = 'classification_logits',
+ label_fn: Optional[LabelFn] = None,
+ label_feature_name: Optional[str] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+key
+ |
+
+A string key to select between possibly multiple named readouts.
+ |
+
+
+feature_name
+ |
+
+The name of the feature to read. If unset,
+tfgnn.HIDDEN_STATE will be read.
+ |
+
+
+readout_node_set
+ |
+
+A string, defaults to "_readout" . This is used as the
+name for the readout node set and as a name prefix for its edge sets.
+ |
+
+
+validate
+ |
+
+Setting this to false disables the validity checks for the
+auxiliary edge sets. This is strongly discouraged, unless great care is
+taken to run tfgnn.validate_graph_tensor_for_readout() earlier on
+structurally unchanged GraphTensors.
+ |
+
+
+num_classes
+ |
+
+The number of classes. Exactly one of num_classes or
+class_names must be specified.
+ |
+
+
+class_names
+ |
+
+The class names. Exactly one of num_classes or
+class_names must be specified.
+ |
+
+
+per_class_statistics
+ |
+
+Whether to compute statistics per class.
+ |
+
+
+name
+ |
+
+The classification head's layer name. To control the naming of saved
+model outputs see the runner model exporters (e.g.,
+KerasModelExporter ).
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
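+A minimal construction sketch; the number of classes and the label feature
+name are assumptions about the input data.
+
+```python
+from tensorflow_gnn import runner
+
+task = runner.NodeMulticlassClassification(
+    key="seed",
+    num_classes=10,
+    per_class_statistics=True,
+    label_feature_name="class_id")
+```
+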
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> Field
+
+
+Gather activations from auxiliary node (and edge) sets.
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+Sparse categorical crossentropy loss.
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Sparse categorical metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for classification.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for classification.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The classification logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a (one or mapping of) Field to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/ParameterServerStrategy.md b/tensorflow_gnn/docs/api_docs/python/runner/ParameterServerStrategy.md
new file mode 100644
index 00000000..342da409
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/ParameterServerStrategy.md
@@ -0,0 +1,1082 @@
+# runner.ParameterServerStrategy
+
+
+
+
+ View source
+on GitHub
+
+A `ParameterServerStrategy` convenience wrapper.
+
+
+runner.ParameterServerStrategy(
+ min_shard_bytes: Optional[int] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+cluster_resolver
+ |
+
+A tf.distribute.cluster_resolver.ClusterResolver object.
+ |
+
+
+variable_partitioner
+ |
+
+A tf.distribute.experimental.partitioners.Partitioner that specifies how
+to partition variables. If None , variables will not be partitioned.
+
+* Predefined partitioners in
+  tf.distribute.experimental.partitioners can be used for this
+  argument. A commonly used partitioner is
+  MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps),
+  which allocates at least 256K per shard, and each ps gets at
+  most one shard.
+
+* variable_partitioner will be called for each variable created
+  under strategy scope to instruct how the variable should be
+  partitioned. Variables that have only one partition along the partitioning
+  axis (i.e., no need for partitioning) will be created as a normal
+  tf.Variable .
+
+* Only the first / outermost axis partitioning is supported.
+
+* Div partition strategy is used to partition variables. Assuming we assign
+  consecutive integer ids along the first axis of a variable, then ids are
+  assigned to shards in a contiguous manner, while attempting to keep each
+  shard size identical. If the ids do not evenly divide the number of shards,
+  each of the first several shards will be assigned one more id. For instance,
+  a variable whose first dimension is 13 has 13 ids, and they are split across
+  5 shards as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
+
+* Variables created under strategy.extended.colocate_vars_with will
+  not be partitioned.
+ |
+
+
+
+
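+A minimal sketch of constructing the wrapper and handing it to a trainer; the
+cluster configuration (e.g. TF_CONFIG) is assumed to be set up elsewhere.
+
+```python
+from tensorflow_gnn import runner
+
+strategy = runner.ParameterServerStrategy(min_shard_bytes=256 << 10)
+trainer = runner.KerasTrainer(strategy, model_dir="/tmp/model")
+```
+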
+
+
+
+Attributes |
+
+
+
+cluster_resolver
+ |
+
+Returns the cluster resolver associated with this strategy.
+
+In general, when using a multi-worker tf.distribute strategy such
+as tf.distribute.experimental.MultiWorkerMirroredStrategy or
+tf.distribute.TPUStrategy() , there is a
+tf.distribute.cluster_resolver.ClusterResolver associated with the
+strategy used, and such an instance is returned by this property.
+
+Strategies that intend to have an associated
+tf.distribute.cluster_resolver.ClusterResolver must set the
+relevant attribute, or override this property; otherwise, None is
+returned by default. Those strategies should also provide information regarding
+what is returned by this property.
+
+Single-worker strategies usually do not have a
+tf.distribute.cluster_resolver.ClusterResolver , and in those cases
+this property will return None .
+
+The tf.distribute.cluster_resolver.ClusterResolver may be useful
+when the user needs to access information such as the cluster spec, task type or
+task id. For example,
+
+```python
+
+os.environ['TF_CONFIG'] = json.dumps({
+ 'cluster': {
+ 'worker': ["localhost:12345", "localhost:23456"],
+ 'ps': ["localhost:34567"]
+ },
+ 'task': {'type': 'worker', 'index': 0}
+})
+
+# This implicitly uses TF_CONFIG for the cluster and current task info.
+strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+
+...
+
+if strategy.cluster_resolver.task_type == 'worker':
+  # Perform something that's only applicable on workers. Since we set this
+  # as a worker above, this block will run on this particular instance.
+  ...
+elif strategy.cluster_resolver.task_type == 'ps':
+  # Perform something that's only applicable on parameter servers. Since we
+  # set this as a worker above, this block will not run on this particular
+  # instance.
+  ...
+```
+
+For more information, please see
+tf.distribute.cluster_resolver.ClusterResolver 's API docstring.
+ |
+
+
+extended
+ |
+
+tf.distribute.StrategyExtended with additional methods.
+ |
+
+
+num_replicas_in_sync
+ |
+
+Returns number of replicas over which gradients are aggregated.
+ |
+
+
+
+## Methods
+
+distribute_datasets_from_function
+
+
+distribute_datasets_from_function(
+ dataset_fn, options=None
+)
+
+
+Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
+
+The argument `dataset_fn` that users pass in is an input function that has a
+`tf.distribute.InputContext` argument and returns a `tf.data.Dataset` instance.
+It is expected that the returned dataset from `dataset_fn` is already batched by
+per-replica batch size (i.e. global batch size divided by the number of replicas
+in sync) and sharded. `tf.distribute.Strategy.distribute_datasets_from_function`
+does not batch or shard the `tf.data.Dataset` instance returned from the input
+function. `dataset_fn` will be called on the CPU device of each of the workers
+and each generates a dataset where every replica on that worker will dequeue one
+batch of inputs (i.e. if a worker has two replicas, two batches will be dequeued
+from the `Dataset` every step).
+
+This method can be used for several purposes. First, it allows you to specify
+your own batching and sharding logic. (In contrast,
+`tf.distribute.experimental_distribute_dataset` does batching and sharding for
+you.) For example, where `experimental_distribute_dataset` is unable to shard
+the input files, this method might be used to manually shard the dataset
+(avoiding the slow fallback behavior in `experimental_distribute_dataset`). In
+cases where the dataset is infinite, this sharding can be done by creating
+dataset replicas that differ only in their random seed.
+
+The `dataset_fn` should take an `tf.distribute.InputContext` instance where
+information about batching and input replication can be accessed.
+
+You can use `element_spec` property of the `tf.distribute.DistributedDataset`
+returned by this API to query the `tf.TypeSpec` of the elements returned by the
+iterator. This can be used to set the `input_signature` property of a
+`tf.function`. Follow `tf.distribute.DistributedDataset.element_spec` to see an
+example.
+
+IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
+per-replica batch size, unlike `experimental_distribute_dataset`, which uses the
+global batch size. This may be computed using
+`input_context.get_per_replica_batch_size`.
+
+Note: If you are using TPUStrategy, the order in which the data is processed by
+the workers when using `tf.distribute.Strategy.experimental_distribute_dataset`
+or `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed.
+This is typically required if you are using `tf.distribute` to scale prediction.
+You can however insert an index for each element in the batch and order outputs
+accordingly. Refer to
+[this snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
+for an example of how to order outputs.
+
+Note: Stateful dataset transformations are currently not supported with
+`tf.distribute.experimental_distribute_dataset` or
+`tf.distribute.distribute_datasets_from_function`. Any stateful ops that the
+dataset may have are currently ignored. For example, if your dataset has a
+`map_fn` that uses `tf.random.uniform` to rotate an image, then you have a
+dataset graph that depends on state (i.e. the random seed) on the local machine
+where the python process is being executed.
+
+For a tutorial on more usage and properties of this method, refer to the
+[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
+If you are interested in last partial batch handling, read
+[this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
+
+
+
+
+
+Args |
+
+
+
+dataset_fn
+ |
+
+A function taking a tf.distribute.InputContext instance and
+returning a tf.data.Dataset .
+ |
+
+
+options
+ |
+
+tf.distribute.InputOptions used to control options on how this
+dataset is distributed.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.distribute.DistributedDataset .
+ |
+
+
+
+
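+For illustration, a minimal `dataset_fn` that performs its own batching and
+sharding as described above; `strategy` is assumed to be an already
+constructed instance, and the dataset contents are placeholders.
+
+```python
+import tensorflow as tf
+
+GLOBAL_BATCH_SIZE = 64
+
+def dataset_fn(input_context):
+  # Batch by the per-replica batch size and shard by input pipeline, since
+  # distribute_datasets_from_function does neither for us.
+  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
+  dataset = tf.data.Dataset.range(1024)
+  dataset = dataset.shard(input_context.num_input_pipelines,
+                          input_context.input_pipeline_id)
+  return dataset.batch(batch_size).prefetch(2)
+
+dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
+```
+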
+experimental_distribute_dataset
+
+
+experimental_distribute_dataset(
+ dataset, options=None
+)
+
+
+Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
+
+The returned `tf.distribute.DistributedDataset` can be iterated over similar to
+regular datasets. NOTE: The user cannot add any more transformations to a
+`tf.distribute.DistributedDataset`. You can only create an iterator or examine
+the `tf.TypeSpec` of the data generated by it. See API docs of
+`tf.distribute.DistributedDataset` to learn more.
+
+The following is an example:
+
+```
+>>> global_batch_size = 2
+>>> # Passing the devices is optional.
+... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
+>>> # Create a dataset
+... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
+>>> # Distribute that dataset
+... dist_dataset = strategy.experimental_distribute_dataset(dataset)
+>>> @tf.function
+... def replica_fn(input):
+... return input*2
+>>> result = []
+>>> # Iterate over the `tf.distribute.DistributedDataset`
+... for x in dist_dataset:
+... # process dataset elements
+... result.append(strategy.run(replica_fn, args=(x,)))
+>>> print(result)
+[PerReplica:{
+  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
+  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
+}, PerReplica:{
+  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
+  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
+}]
+```
+
+Three key actions happening under the hood of this method are batching,
+sharding, and prefetching.
+
+In the code snippet above, `dataset` is batched by `global_batch_size`, and
+calling `experimental_distribute_dataset` on it rebatches `dataset` to a new
+batch size that is equal to the global batch size divided by the number of
+replicas in sync. We iterate through it using a Pythonic for loop. `x` is a
+`tf.distribute.DistributedValues` containing data for all replicas, and each
+replica gets data of the new batch size. `tf.distribute.Strategy.run` will take
+care of feeding the right per-replica data in `x` to the right `replica_fn`
+executed on each replica.
+
+Sharding contains autosharding across multiple workers and within every worker.
+First, in multi-worker distributed training (i.e. when you use
+`tf.distribute.experimental.MultiWorkerMirroredStrategy` or
+`tf.distribute.TPUStrategy`), autosharding a dataset over a set of workers means
+that each worker is assigned a subset of the entire dataset (if the right
+`tf.data.experimental.AutoShardPolicy` is set). This is to ensure that at each
+step, a global batch size of non-overlapping dataset elements will be processed
+by each worker. Autosharding has a couple of different options that can be
+specified using `tf.data.experimental.DistributeOptions`. Then, sharding within
+each worker means the method will split the data among all the worker devices
+(if more than one is present). This will happen regardless of multi-worker
+autosharding.
+
+Note: for autosharding across multiple workers, the default mode is
+`tf.data.experimental.AutoShardPolicy.AUTO`. This mode will attempt to shard the
+input dataset by files if the dataset is being created out of reader datasets
+(e.g. `tf.data.TFRecordDataset`, `tf.data.TextLineDataset`, etc.) or otherwise
+shard the dataset by data, where each of the workers will read the entire
+dataset and only process the shard assigned to it. However, if you have less
+than one input file per worker, we suggest that you disable dataset autosharding
+across workers by setting the
+`tf.data.experimental.DistributeOptions.auto_shard_policy` to be
+`tf.data.experimental.AutoShardPolicy.OFF`.
+
+By default, this method adds a prefetch transformation at the end of the user
+provided `tf.data.Dataset` instance. The argument to the prefetch transformation
+which is `buffer_size` is equal to the number of replicas in sync.
+
+If the above batch splitting and dataset sharding logic is undesirable, please
+use `tf.distribute.Strategy.distribute_datasets_from_function` instead, which
+does not do any automatic batching or sharding for you.
+
+Note: If you are using TPUStrategy, the order in which the data is processed by
+the workers when using `tf.distribute.Strategy.experimental_distribute_dataset`
+or `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed.
+This is typically required if you are using `tf.distribute` to scale prediction.
+You can however insert an index for each element in the batch and order outputs
+accordingly. Refer to
+[this snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
+for an example of how to order outputs.
+
+Note: Stateful dataset transformations are currently not supported with
+`tf.distribute.experimental_distribute_dataset` or
+`tf.distribute.distribute_datasets_from_function`. Any stateful ops that the
+dataset may have are currently ignored. For example, if your dataset has a
+`map_fn` that uses `tf.random.uniform` to rotate an image, then you have a
+dataset graph that depends on state (i.e. the random seed) on the local machine
+where the python process is being executed.
+
+For a tutorial on more usage and properties of this method, refer to the
+[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
+If you are interested in last partial batch handling, read
+[this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
+
+
+
+
+
+Args |
+
+
+
+dataset
+ |
+
+tf.data.Dataset that will be sharded across all replicas using
+the rules stated above.
+ |
+
+
+options
+ |
+
+tf.distribute.InputOptions used to control options on how this
+dataset is distributed.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.distribute.DistributedDataset .
+ |
+
+
+
+
+experimental_distribute_values_from_function
+
+
+experimental_distribute_values_from_function(
+ value_fn
+)
+
+
+Generates `tf.distribute.DistributedValues` from `value_fn`.
+
+This function is to generate `tf.distribute.DistributedValues` to pass into
+`run`, `reduce`, or other methods that take distributed values when not using
+datasets.
+
+
+
+
+
+Args |
+
+
+
+value_fn
+ |
+
+The function to run to generate values. It is called for
+each replica with tf.distribute.ValueContext as the sole argument. It
+must return a Tensor or a type that can be converted to a Tensor.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.distribute.DistributedValues containing a value for each replica.
+ |
+
+
+
+
+#### Example usage:
+
+1. Return constant value per replica:
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> def value_fn(ctx):
+ ... return tf.constant(1.)
+ >>> distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ >>> local_result = strategy.experimental_local_results(
+ ... distributed_values)
+ >>> local_result
+    (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
+     <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
+ ```
+
+2. Distribute values in array based on replica_id:
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> array_value = np.array([3., 2., 1.])
+ >>> def value_fn(ctx):
+ ... return array_value[ctx.replica_id_in_sync_group]
+ >>> distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ >>> local_result = strategy.experimental_local_results(
+ ... distributed_values)
+ >>> local_result
+ (3.0, 2.0)
+ ```
+
+3. Specify values using num_replicas_in_sync:
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> def value_fn(ctx):
+ ... return ctx.num_replicas_in_sync
+ >>> distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ >>> local_result = strategy.experimental_local_results(
+ ... distributed_values)
+ >>> local_result
+ (2, 2)
+ ```
+
+4. Place values on devices and distribute:
+
+ ```
+ strategy = tf.distribute.TPUStrategy()
+ worker_devices = strategy.extended.worker_devices
+ multiple_values = []
+ for i in range(strategy.num_replicas_in_sync):
+ with tf.device(worker_devices[i]):
+ multiple_values.append(tf.constant(1.0))
+
+ def value_fn(ctx):
+ return multiple_values[ctx.replica_id_in_sync_group]
+
+ distributed_values = strategy.
+ experimental_distribute_values_from_function(
+ value_fn)
+ ```
+
+experimental_local_results
+
+
+experimental_local_results(
+ value
+)
+
+
+Returns the list of all local per-replica values contained in `value`.
+
+Note: This only returns values on the worker initiated by this client. When
+using a `tf.distribute.Strategy` like
+`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker will be
+its own client, and this function will only return values computed on that
+worker.
+
+
+
+
+
+Args |
+
+
+
+value
+ |
+
+A value returned by experimental_run() , run(), or a variable
+created in scope.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of values contained in value where ith element corresponds to
+ith replica. If value represents a single value, this returns
+(value,).
+ |
+
+
+
+
+gather
+
+
+gather(
+ value, axis
+)
+
+
+Gather `value` across replicas along `axis` to the current device.
+
+Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like object `value`,
+this API gathers and concatenates `value` across replicas along the `axis`-th
+dimension. The result is copied to the "current" device, which would typically
+be the CPU of the worker on which the program is running. For
+`tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client
+`tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of each worker.
+
+This API can only be called in the cross-replica context. For a counterpart in
+the replica context, see `tf.distribute.ReplicaContext.all_gather`.
+
+Note: For all strategies except `tf.distribute.TPUStrategy`, the input `value`
+on different replicas must have the same rank, and their shapes must be the same
+in all dimensions except the `axis`-th dimension. In other words, their shapes
+cannot be different in a dimension `d` where `d` does not equal to the `axis`
+argument. For example, given a `tf.distribute.DistributedValues` with component
+tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
+`gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
+`gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, all
+tensors must have exactly the same rank and same shape.
+
+Note: Given a `tf.distribute.DistributedValues` `value`, its component tensors
+must have a non-zero rank. Otherwise, consider using `tf.expand_dims` before
+gathering them.
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+>>> # A DistributedValues with component tensor of shape (2, 1) on each replica
+... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
+>>> @tf.function
+... def run():
+... return strategy.gather(distributed_values, axis=0)
+>>> run()
+<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
+array([[1],
+       [2],
+       [1],
+       [2]], dtype=int32)>
+```
+
+Consider the following example for more combinations:
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
+>>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
+>>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
+>>> @tf.function
+... def run(axis):
+... return strategy.gather(distributed_values, axis=axis)
+>>> axis=0
+>>> run(axis)
+
+>>> axis=1
+>>> run(axis)
+
+>>> axis=2
+>>> run(axis)
+
+```
+
+
+
+
+
+Args |
+
+
+
+value
+ |
+
+a tf.distribute.DistributedValues instance, e.g. returned by
+Strategy.run , to be combined into a single tensor. It can also be a
+regular tensor when used with tf.distribute.OneDeviceStrategy or the
+default strategy. The tensors that constitute the DistributedValues
+can only be dense tensors with non-zero rank, NOT a tf.IndexedSlices .
+ |
+
+
+axis
+ |
+
+0-D int32 Tensor. Dimension along which to gather. Must be in the
+range [0, rank(value)).
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A Tensor that's the concatenation of value across replicas along
+axis dimension.
+ |
+
+
+
+
+reduce
+
+
+reduce(
+ reduce_op, value, axis
+)
+
+
+Reduce `value` across replicas and return result on current device.
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+>>> def step_fn():
+... i = tf.distribute.get_replica_context().replica_id_in_sync_group
+... return tf.identity(i)
+>>>
+>>> per_replica_result = strategy.run(step_fn)
+>>> total = strategy.reduce("SUM", per_replica_result, axis=None)
+>>> total
+
+```
+
+To see how this would look with multiple replicas, consider the same example
+with MirroredStrategy with 2 GPUs:
+
+```python
+strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
+def step_fn():
+ i = tf.distribute.get_replica_context().replica_id_in_sync_group
+ return tf.identity(i)
+
+per_replica_result = strategy.run(step_fn)
+# Check devices on which per replica result is:
+strategy.experimental_local_results(per_replica_result)[0].device
+# /job:localhost/replica:0/task:0/device:GPU:0
+strategy.experimental_local_results(per_replica_result)[1].device
+# /job:localhost/replica:0/task:0/device:GPU:1
+
+total = strategy.reduce("SUM", per_replica_result, axis=None)
+# Check device on which reduced result is:
+total.device
+# /job:localhost/replica:0/task:0/device:CPU:0
+
+```
+
+This API is typically used for aggregating the results returned from different
+replicas, for reporting etc. For example, loss computed from different replicas
+can be averaged using this API before printing.
+
+Note: The result is copied to the "current" device - which would typically be
+the CPU of the worker on which the program is running. For `TPUStrategy`, it is
+the first TPU host. For multi client `MultiWorkerMirroredStrategy`, this is CPU
+of each worker.
+
+There are a number of different tf.distribute APIs for reducing values across
+replicas: * `tf.distribute.ReplicaContext.all_reduce`: This differs from
+`Strategy.reduce` in that it is for replica context and does not copy the
+results to the host device. `all_reduce` should be typically used for reductions
+inside the training step such as gradients. *
+`tf.distribute.StrategyExtended.reduce_to` and
+`tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more advanced
+versions of `Strategy.reduce` as they allow customizing the destination of the
+result. They are also called in cross replica context.
+
+*What should axis be?*
+
+Given a per-replica value returned by `run`, say a per-example loss, the batch
+will be divided across all the replicas. This function allows you to aggregate
+across replicas and optionally also across batch elements by specifying the axis
+parameter accordingly.
+
+For example, if you have a global batch size of 8 and 2 replicas, values for
+examples `[0, 1, 2, 3]` will be on replica 0 and `[4, 5, 6, 7]` will be on
+replica 1. With `axis=None`, `reduce` will aggregate only across replicas,
+returning `[0+4, 1+5, 2+6, 3+7]`. This is useful when each replica is computing
+a scalar or some other value that doesn't have a "batch" dimension (like a
+gradient or loss). `strategy.reduce("sum", per_replica_result, axis=None)`
+
+Sometimes, you will want to aggregate across both the global batch *and* all
+replicas. You can get this behavior by specifying the batch dimension as the
+`axis`, typically `axis=0`. In this case it would return a scalar
+`0+1+2+3+4+5+6+7`. `strategy.reduce("sum", per_replica_result, axis=0)`
+
+If there is a last partial batch, you will need to specify an axis so that the
+resulting shape is consistent across replicas. So if the last batch has size 6
+and it is divided into [0, 1, 2, 3] and [4, 5], you would get a shape mismatch
+unless you specify `axis=0`. If you specify `tf.distribute.ReduceOp.MEAN`, using
+`axis=0` will use the correct denominator of 6. Contrast this with computing
+`reduce_mean` to get a scalar value on each replica and this function to average
+those means, which will weigh some values `1/8` and others `1/4`.
+
+
+
+
+
+Args |
+
+
+
+reduce_op
+ |
+
+a tf.distribute.ReduceOp value specifying how values should
+be combined. Allows using string representation of the enum such as
+"SUM", "MEAN".
+ |
+
+
+value
+ |
+
+a tf.distribute.DistributedValues instance, e.g. returned by
+Strategy.run , to be combined into a single tensor. It can also be a
+regular tensor when used with OneDeviceStrategy or default strategy.
+ |
+
+
+axis
+ |
+
+specifies the dimension to reduce along within each
+replica's tensor. Should typically be set to the batch dimension, or
+None to only reduce across replicas (e.g. if the tensor has no batch
+dimension).
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A Tensor .
+ |
+
+
+
+
+run
+
+
+run(
+ fn, args=(), kwargs=None, options=None
+)
+
+
+Invokes `fn` on each replica, with the given arguments.
+
+This method is the primary way to distribute your computation with a
+tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
+have `tf.distribute.DistributedValues`, such as those produced by a
+`tf.distribute.DistributedDataset` from
+`tf.distribute.Strategy.experimental_distribute_dataset` or
+`tf.distribute.Strategy.distribute_datasets_from_function`, when `fn` is
+executed on a particular replica, it will be executed with the component of
+`tf.distribute.DistributedValues` that correspond to that replica.
+
+`fn` is invoked under a replica context. `fn` may call
+`tf.distribute.get_replica_context()` to access members such as `all_reduce`.
+Please see the module-level docstring of tf.distribute for the concept of
+replica context.
+
+All arguments in `args` or `kwargs` can be a nested structure of tensors, e.g. a
+list of tensors, in which case `args` and `kwargs` will be passed to the `fn`
+invoked on each replica. Or `args` or `kwargs` can be
+`tf.distribute.DistributedValues` containing tensors or composite tensors, i.e.
+`tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call will get
+the component of a `tf.distribute.DistributedValues` corresponding to its
+replica. Note that arbitrary Python values that are not of the types above are
+not supported.
+
+IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
+whether eager execution is enabled, `fn` may be called one or more times. If
+`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is called
+inside a `tf.function` (eager execution is disabled inside a `tf.function` by
+default), `fn` is called once per replica to generate a Tensorflow graph, which
+will then be reused for execution with new inputs. Otherwise, if eager execution
+is enabled, `fn` will be called once per replica every step just like regular
+python code.
+
+#### Example usage:
+
+1. Constant tensor input.
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> tensor_input = tf.constant(3.0)
+ >>> @tf.function
+ ... def replica_fn(input):
+ ... return input*2.0
+ >>> result = strategy.run(replica_fn, args=(tensor_input,))
+ >>> result
+ PerReplica:{
+      0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
+      1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
+ }
+ ```
+
+2. DistributedValues input.
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> @tf.function
+ ... def run():
+ ... def value_fn(value_context):
+ ... return value_context.num_replicas_in_sync
+ ... distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ ... def replica_fn2(input):
+ ... return input*2
+ ... return strategy.run(replica_fn2, args=(distributed_values,))
+ >>> result = run()
+ >>> result
+
+ ```
+
+3. Use `tf.distribute.ReplicaContext` to allreduce values.
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
+ >>> @tf.function
+ ... def run():
+ ... def value_fn(value_context):
+ ... return tf.constant(value_context.replica_id_in_sync_group)
+ ... distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ ... def replica_fn(input):
+ ... return tf.distribute.get_replica_context().all_reduce(
+ ... "sum", input)
+ ... return strategy.run(replica_fn, args=(distributed_values,))
+ >>> result = run()
+ >>> result
+ PerReplica:{
+ 0: ,
+ 1:
+ }
+ ```
+
+
+
+
+
+Args |
+
+
+
+fn
+ |
+
+The function to run on each replica.
+ |
+
+
+args
+ |
+
+Optional positional arguments to fn . Its element can be a tensor,
+a nested structure of tensors or a tf.distribute.DistributedValues .
+ |
+
+
+kwargs
+ |
+
+Optional keyword arguments to fn . Its element can be a tensor,
+a nested structure of tensors or a tf.distribute.DistributedValues .
+ |
+
+
+options
+ |
+
+An optional instance of tf.distribute.RunOptions specifying
+the options to run fn .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Merged return value of fn across replicas. The structure of the return
+value is the same as the return value from fn . Each element in the
+structure can either be tf.distribute.DistributedValues , Tensor
+objects, or Tensor s (for example, if running on a single replica).
+ |
+
+
+
+
+scope
+
+
+scope()
+
+
+Context manager to make the strategy current and distribute variables.
+
+This method returns a context manager, and is used as follows:
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+>>> # Variable created inside scope:
+>>> with strategy.scope():
+... mirrored_variable = tf.Variable(1.)
+>>> mirrored_variable
+MirroredVariable:{
+ 0: ,
+ 1:
+}
+>>> # Variable created outside scope:
+>>> regular_variable = tf.Variable(1.)
+>>> regular_variable
+
+```
+
+*What happens when Strategy.scope is entered?*
+
+* `strategy` is installed in the global context as the "current" strategy.
+ Inside this scope, `tf.distribute.get_strategy()` will now return this
+ strategy. Outside this scope, it returns the default no-op strategy.
+* Entering the scope also enters the "cross-replica context". See
+ `tf.distribute.StrategyExtended` for an explanation on cross-replica and
+ replica contexts.
+* Variable creation inside `scope` is intercepted by the strategy. Each
+ strategy defines how it wants to affect the variable creation. Sync
+ strategies like `MirroredStrategy`, `TPUStrategy` and
+  `MultiWorkerMirroredStrategy` create variables replicated on each replica,
+ whereas `ParameterServerStrategy` creates variables on the parameter
+ servers. This is done using a custom `tf.variable_creator_scope`.
+* In some strategies, a default device scope may also be entered: in
+  `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is entered
+ on each worker.
+
+Note: Entering a scope does not automatically distribute a computation, except
+in the case of high level training framework like keras `model.fit`. If you're
+not using `model.fit`, you need to use `strategy.run` API to explicitly
+distribute that computation. See an example in the
+[custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
+
+*What should be in scope and what should be outside?*
+
+There are a number of requirements on what needs to happen inside the scope.
+However, in places where we have information about which strategy is in use, we
+often enter the scope for the user, so they don't have to do it explicitly (i.e.
+calling those either inside or outside the scope is OK).
+
+* Anything that creates variables that should be distributed variables must be
+ called in a `strategy.scope`. This can be accomplished either by directly
+ calling the variable creating function within the scope context, or by
+ relying on another API like `strategy.run` or `keras.Model.fit` to
+ automatically enter it for you. Any variable that is created outside scope
+ will not be distributed and may have performance implications. Some common
+ objects that create variables in TF are Models, Optimizers, Metrics. Such
+ objects should always be initialized in the scope, and any functions that
+ may lazily create variables (e.g., `Model.__call__()`, tracing a
+ `tf.function`, etc.) should similarly be called within scope. Another source
+ of variable creation can be a checkpoint restore - when variables are
+ created lazily. Note that any variable created inside a strategy captures
+ the strategy information. So reading and writing to these variables outside
+ the `strategy.scope` can also work seamlessly, without the user having to
+ enter the scope.
+* Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which
+ require to be in a strategy's scope, enter the scope automatically, which
+ means when using those APIs you don't need to explicitly enter the scope
+ yourself.
+* When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
+ object captures the scope information. When high level training framework
+ methods such as `model.compile`, `model.fit`, etc. are then called, the
+ captured scope will be automatically entered, and the associated strategy
+ will be used to distribute the training etc. See a detailed example in
+ [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
+ WARNING: Simply calling `model(..)` does not automatically enter the
+ captured scope -- only high level training framework APIs support this
+ behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
+ and `model.save` can all be called inside or outside the scope.
+* The following can be either inside or outside the scope:
+ * Creating the input datasets
+ * Defining `tf.function`s that represent your training step
+ * Saving APIs such as `tf.saved_model.save`. Loading creates variables, so
+ that should go inside the scope if you want to train the model in a
+ distributed way.
+ * Checkpoint saving. As mentioned above - `checkpoint.restore` may
+ sometimes need to be inside scope if it creates variables.
+
+
+
+
+
+Returns |
+
+
+A context manager.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/PassthruDatasetProvider.md b/tensorflow_gnn/docs/api_docs/python/runner/PassthruDatasetProvider.md
new file mode 100644
index 00000000..3e1bdaea
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/PassthruDatasetProvider.md
@@ -0,0 +1,40 @@
+# runner.PassthruDatasetProvider
+
+
+
+
+ View source
+on GitHub
+
+Builds a `tf.data.Dataset` from a pass thru dataset.
+
+Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md)
+
+
+runner.PassthruDatasetProvider(
+ dataset: tf.data.Dataset,
+ *,
+ shuffle_datasets: bool = False,
+ examples_shuffle_size: Optional[int] = None
+)
+
+
+
+
+Passes any `dataset` through as is, omitting any sharding. For detailed
+documentation, see the filename-based dataset provider complement:
+`SimpleDatasetProvider`.
+
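+A minimal sketch; the in-memory `serialized_examples` tensor is an assumption,
+and any `tf.data.Dataset` works.
+
+```python
+import tensorflow as tf
+from tensorflow_gnn import runner
+
+# Assumed: `serialized_examples` holds serialized GraphTensor examples.
+dataset = tf.data.Dataset.from_tensor_slices(serialized_examples)
+provider = runner.PassthruDatasetProvider(
+    dataset, shuffle_datasets=True, examples_shuffle_size=1_000)
+ds = provider.get_dataset(tf.distribute.InputContext())
+```
+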
+## Methods
+
+get_dataset
+
+View
+source
+
+
+get_dataset(
+ _: tf.distribute.InputContext
+) -> tf.data.Dataset
+
+
+Gets a `tf.data.Dataset` omitting any input context.
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/PassthruSampleDatasetsProvider.md b/tensorflow_gnn/docs/api_docs/python/runner/PassthruSampleDatasetsProvider.md
new file mode 100644
index 00000000..0c841412
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/PassthruSampleDatasetsProvider.md
@@ -0,0 +1,46 @@
+# runner.PassthruSampleDatasetsProvider
+
+
+
+
+ View source
+on GitHub
+
+Builds a sampled `tf.data.Dataset` from multiple pass thru datasets.
+
+Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md)
+
+
+runner.PassthruSampleDatasetsProvider(
+ principal_dataset: tf.data.Dataset,
+ extra_datasets: Sequence[tf.data.Dataset],
+ principal_weight: Optional[float] = None,
+ extra_weights: Optional[Sequence[float]] = None,
+ *,
+ principal_cardinality: Optional[int] = None,
+ fixed_cardinality: bool = False,
+ shuffle_dataset: bool = False,
+ examples_shuffle_size: Optional[int] = None
+)
+
+
+
+
+Passes any `principal_dataset` and `extra_datasets` through as is, omitting any
+sharding. For detailed documentation, see the filename-based dataset provider
+complement: `SimpleSampleDatasetsProvider`.
+
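+A minimal sketch; the individual datasets are assumed to exist elsewhere, and
+the weight semantics follow the `SimpleSampleDatasetsProvider` documentation.
+
+```python
+import tensorflow as tf
+from tensorflow_gnn import runner
+
+# Assumed: `principal_dataset`, `extra_dataset_a` and `extra_dataset_b` are
+# tf.data.Dataset objects defined elsewhere.
+provider = runner.PassthruSampleDatasetsProvider(
+    principal_dataset,
+    extra_datasets=[extra_dataset_a, extra_dataset_b],
+    principal_weight=0.8,
+    extra_weights=[0.1, 0.1])
+ds = provider.get_dataset(tf.distribute.InputContext())
+```
+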
+## Methods
+
+get_dataset
+
+View
+source
+
+
+get_dataset(
+ _: tf.distribute.InputContext
+) -> tf.data.Dataset
+
+
+Gets a sampled `tf.data.Dataset` omitting any input context.
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/Predictions.md b/tensorflow_gnn/docs/api_docs/python/runner/Predictions.md
new file mode 100644
index 00000000..31446b86
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/Predictions.md
@@ -0,0 +1,17 @@
+# runner.Predictions
+
+
+
+This symbol is a **type alias**.
+
+#### Source:
+
+
+Predictions = Union[
+ tf.Tensor,
+ tf.RaggedTensor,
+ Mapping[str, Union[tf.Tensor, tf.RaggedTensor]]
+]
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeBinaryClassification.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeBinaryClassification.md
new file mode 100644
index 00000000..817ae0f5
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeBinaryClassification.md
@@ -0,0 +1,216 @@
+# runner.RootNodeBinaryClassification
+
+
+
+
+ View source
+on GitHub
+
+Root node binary (or multi-label) classification.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeBinaryClassification(
+ node_set_name: str,
+ units: int = 1,
+ *,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ name: str = 'classification_logits',
+ label_fn: Optional[LabelFn] = None,
+ label_feature_name: Optional[str] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+node_set_name
+ |
+
+The node set containing the root node.
+ |
+
+
+units
+ |
+
+The units for the classification head. (Typically 1 for binary
+classification and the number of labels for multi-label classification.)
+ |
+
+
+state_name
+ |
+
+The feature name for activations (e.g.: tfgnn.HIDDEN_STATE).
+ |
+
+
+name
+ |
+
+The classification head's layer name. To control the naming of saved
+model outputs see the runner model exporters (e.g.,
+KerasModelExporter ).
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
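+A minimal construction sketch; the node set name and label feature are
+assumptions about the input data.
+
+```python
+from tensorflow_gnn import runner
+
+task = runner.RootNodeBinaryClassification(
+    node_set_name="papers",
+    units=1,
+    label_feature_name="label")
+```
+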
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> Field
+
+
+Gather activations from root nodes.
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+Returns arbitrary task specific losses.
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for classification.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for classification.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The classification logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeLabelFn.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeLabelFn.md
new file mode 100644
index 00000000..80a27eee
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeLabelFn.md
@@ -0,0 +1,20 @@
+# runner.RootNodeLabelFn
+
+
+
+
+ View source
+on GitHub
+
+Reads out a `tfgnn.Field` from the `GraphTensor` root (i.e. first) node.
+
+
+runner.RootNodeLabelFn(
+ node_set_name: tfgnn.NodeSetName,
+ *,
+ feature_name: tfgnn.FieldName = tfgnn.HIDDEN_STATE,
+ **kwargs
+)
+
+
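+A minimal sketch of its intended use as a task's label function (the node set
+and feature names are placeholders):
+
+```python
+from tensorflow_gnn import runner
+
+# Reads the "label" feature from the root (first) node of "papers".
+label_fn = runner.RootNodeLabelFn("papers", feature_name="label")
+# Typically passed as `label_fn=label_fn` to a root node task, e.g.
+# runner.RootNodeBinaryClassification.
+```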
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsoluteError.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsoluteError.md
new file mode 100644
index 00000000..08daf609
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsoluteError.md
@@ -0,0 +1,194 @@
+# runner.RootNodeMeanAbsoluteError
+
+
+
+
+ View source
+on GitHub
+
+Mean absolute error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeMeanAbsoluteError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
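+A minimal sketch of configuring this task (the node set and feature names are
+placeholders):
+
+```python
+from tensorflow_gnn import runner
+
+task = runner.RootNodeMeanAbsoluteError(
+    node_set_name="users",
+    units=1,
+    label_feature_name="age")
+```
+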
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsoluteLogarithmicError.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsoluteLogarithmicError.md
new file mode 100644
index 00000000..57c56240
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsoluteLogarithmicError.md
@@ -0,0 +1,152 @@
+# runner.RootNodeMeanAbsoluteLogarithmicError
+
+
+
+
+ View source
+on GitHub
+
+Root node mean absolute logarithmic error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeMeanAbsoluteLogarithmicError(
+ reduction: tf.keras.losses.Reduction = AUTO,
+ name: Optional[str] = None,
+ **kwargs
+)
+
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a head with ReLU for nonnegative regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor used for prediction.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The nonnegative logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsolutePercentageError.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsolutePercentageError.md
new file mode 100644
index 00000000..defda629
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanAbsolutePercentageError.md
@@ -0,0 +1,194 @@
+# runner.RootNodeMeanAbsolutePercentageError
+
+
+
+
+ View source
+on GitHub
+
+Mean absolute percentage error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeMeanAbsolutePercentageError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredError.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredError.md
new file mode 100644
index 00000000..18b717d5
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredError.md
@@ -0,0 +1,194 @@
+# runner.RootNodeMeanSquaredError
+
+
+
+
+ View source
+on GitHub
+
+Mean squared error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeMeanSquaredError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredLogScaledError.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredLogScaledError.md
new file mode 100644
index 00000000..d90fc77c
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredLogScaledError.md
@@ -0,0 +1,155 @@
+# runner.RootNodeMeanSquaredLogScaledError
+
+
+
+
+ View source
+on GitHub
+
+Mean squared log scaled error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeMeanSquaredLogScaledError(
+ *args,
+ alpha_loss_param: float = 5.0,
+ epsilon_loss_param: float = 1e-08,
+ reduction: tf.keras.losses.Reduction = AUTO,
+ name: Optional[str] = None,
+ **kwargs
+)
+
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredLogarithmicError.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredLogarithmicError.md
new file mode 100644
index 00000000..bbe8e2d8
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMeanSquaredLogarithmicError.md
@@ -0,0 +1,194 @@
+# runner.RootNodeMeanSquaredLogarithmicError
+
+
+
+
+ View source
+on GitHub
+
+Mean squared logarithmic error task.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeMeanSquaredLogarithmicError(
+ node_set_name: str,
+ *,
+ units: int = 1,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ **kwargs
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+units
+ |
+
+The units for the regression head.
+ |
+
+
+name
+ |
+
+The regression head's layer name. This name typically appears in
+the exported model's SignatureDef.
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> tf.Tensor
+
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Regression metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for regression.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for regression.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The regression logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMulticlassClassification.md b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMulticlassClassification.md
new file mode 100644
index 00000000..a87e6a36
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RootNodeMulticlassClassification.md
@@ -0,0 +1,233 @@
+# runner.RootNodeMulticlassClassification
+
+
+
+
+ View source
+on GitHub
+
+Root node multiclass classification.
+
+Inherits From: [`Task`](../runner/Task.md)
+
+
+runner.RootNodeMulticlassClassification(
+ node_set_name: str,
+ *,
+ num_classes: Optional[int] = None,
+ class_names: Optional[Sequence[str]] = None,
+ per_class_statistics: bool = False,
+ state_name: str = tfgnn.HIDDEN_STATE,
+ name: str = 'classification_logits',
+ label_fn: Optional[LabelFn] = None,
+ label_feature_name: Optional[str] = None
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+node_set_name
+ |
+
+The node set containing the root node.
+ |
+
+
+num_classes
+ |
+
+The number of classes. Exactly one of num_classes or
+class_names must be specified.
+ |
+
+
+class_names
+ |
+
+The class names. Exactly one of num_classes or
+class_names must be specified.
+ |
+
+
+per_class_statistics
+ |
+
+Whether to compute statistics per class.
+ |
+
+
+state_name
+ |
+
+The feature name for activations (e.g.: tfgnn.HIDDEN_STATE).
+ |
+
+
+name
+ |
+
+The classification head's layer name. To control the naming of saved
+model outputs see the runner model exporters (e.g.,
+KerasModelExporter ).
+ |
+
+
+label_fn
+ |
+
+A label extraction function. This function mutates the input
+GraphTensor . Mutually exclusive with label_feature_name .
+ |
+
+
+label_feature_name
+ |
+
+A label feature name for readout from the auxiliary
+'_readout' node set. Readout does not mutate the input GraphTensor .
+Mutually exclusive with label_fn .
+ |
+
+
+
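+A minimal sketch of configuring this task (the node set name, class count and
+feature name are placeholders):
+
+```python
+from tensorflow_gnn import runner
+
+task = runner.RootNodeMulticlassClassification(
+    node_set_name="papers",
+    num_classes=3,
+    label_feature_name="class_id")
+```
+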
+## Methods
+
+gather_activations
+
+View
+source
+
+
+gather_activations(
+ inputs: GraphTensor
+) -> Field
+
+
+Gather activations from root nodes.
+
+losses
+
+View
+source
+
+
+losses() -> interfaces.Losses
+
+
+Sparse categorical crossentropy loss.
+
+metrics
+
+View
+source
+
+
+metrics() -> interfaces.Metrics
+
+
+Sparse categorical metrics.
+
+predict
+
+View
+source
+
+
+predict(
+ inputs: tfgnn.GraphTensor
+) -> interfaces.Predictions
+
+
+Apply a linear head for classification.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A tfgnn.GraphTensor for classification.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The classification logits.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+preprocess(
+ inputs: GraphTensor
+) -> tuple[GraphTensor, Field]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/RunResult.md b/tensorflow_gnn/docs/api_docs/python/runner/RunResult.md
new file mode 100644
index 00000000..8e8aa163
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/RunResult.md
@@ -0,0 +1,71 @@
+# runner.RunResult
+
+
+
+
+ View source
+on GitHub
+
+Holds the return values of `run(...)`.
+
+
+runner.RunResult(
+ preprocess_model: Optional[tf.keras.Model],
+ base_model: tf.keras.Model,
+ trained_model: tf.keras.Model
+)
+
+
+
+
+
+
+
+Attributes |
+
+
+
+preprocess_model
+ |
+
+Keras model containing only the computation for
+preprocessing inputs. It is not trained. The model takes serialized
+GraphTensor s as its inputs and returns preprocessed GraphTensor s.
+None when no preprocess model exists.
+ |
+
+
+base_model
+ |
+
+Keras base GNN (as returned by the user provided model_fn ).
+The model both takes and returns GraphTensor s. The model contains
+some, but not all, of the trained weights. The trained_model contains all
+base_model trained weights in addition to any prediction trained
+weights.
+ |
+
+
+trained_model
+ |
+
+Keras model for the e2e GNN. (Base GNN plus any prediction
+head(s).) The model takes preprocess_model output as its inputs and
+returns Task predictions as its output. Output matches the structure of
+the Task : an atom for single- or a mapping for multi- Task training.
+The model contains all trained weights.
+ |
+
+
+
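+As a hedged sketch of how these pieces fit together, the preprocessing model
+(when present) can be chained with the trained model into a single end-to-end
+inference model; the helper below is illustrative, not part of the runner API:
+
+```python
+import tensorflow as tf
+
+def make_end_to_end_model(run_result) -> tf.keras.Model:
+  """Chains preprocess_model (if any) with trained_model for serving."""
+  if run_result.preprocess_model is None:
+    return run_result.trained_model
+  outputs = run_result.preprocess_model(run_result.preprocess_model.inputs)
+  # The first output of the preprocess model is the preprocessed GraphTensor;
+  # any further outputs (e.g. labels) are dropped for inference.
+  graph = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
+  return tf.keras.Model(
+      run_result.preprocess_model.inputs, run_result.trained_model(graph))
+```
+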
+## Methods
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+Return self==value.
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/SampleTFRecordDatasetsProvider.md b/tensorflow_gnn/docs/api_docs/python/runner/SampleTFRecordDatasetsProvider.md
new file mode 100644
index 00000000..d3950aec
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/SampleTFRecordDatasetsProvider.md
@@ -0,0 +1,229 @@
+# runner.SampleTFRecordDatasetsProvider
+
+
+
+
+ View source
+on GitHub
+
+Builds a sampling `tf.data.Dataset` from multiple filenames.
+
+Inherits From:
+[`SimpleSampleDatasetsProvider`](../runner/SimpleSampleDatasetsProvider.md),
+[`DatasetProvider`](../runner/DatasetProvider.md)
+
+
+runner.SampleTFRecordDatasetsProvider(
+ *args, **kwargs
+)
+
+
+
+
+For complete explanations regarding sampling see `_process_sampled_dataset()`.
+
+This `SimpleSampleDatasetsProvider` builds a `tf.data.Dataset` as follows:
+
+- The object is initialized with a list of filenames specified by
+ `principal_filenames` and `extra_filenames` arguments. For convenience, the
+ corresponding file pattern `principal_file_pattern` and
+ `extra_file_patterns` can be specified instead, which will be expanded to a
+ sorted list.
+- The filenames are sharded between replicas according to the `InputContext`
+ (order matters).
+- Filenames are shuffled per replica (if requested).
+- Examples from all file patterns are sampled according to `principal_weight`
+ and `extra_weights.`
+- The files in each shard are interleaved after being read by the
+ `interleave_fn`.
+- Examples are shuffled (if requested), auto-prefetched, and returned for use
+ in one replica of the trainer.
+
+
+
+
+
+Args |
+
+
+
+principal_file_pattern
+ |
+
+A principal file pattern for sampling, to be
+expanded by tf.io.gfile.glob and sorted into the list of
+principal_filenames .
+ |
+
+
+extra_file_patterns
+ |
+
+File patterns, to be expanded by tf.io.gfile.glob
+and sorted into the list of extra_filenames .
+ |
+
+
+principal_weight
+ |
+
+An optional weight for the dataset corresponding to
+principal_file_pattern. Required iff extra_weights are also
+provided.
+ |
+
+
+extra_weights
+ |
+
+Optional weights corresponding to file_patterns for
+sampling. Required iff principal_weight is also provided.
+ |
+
+
+principal_filenames
+ |
+
+A list of principal filenames, specified explicitly.
+This argument is mutually exclusive with principal_file_pattern .
+ |
+
+
+extra_filenames
+ |
+
+A list of extra filenames, specified explicitly.
+This argument is mutually exclusive with extra_file_patterns .
+ |
+
+
+principal_cardinality
+ |
+
+Iff fixed_cardinality =True, the size of the
+returned dataset is computed as principal_cardinality /
+principal_weight (with a default of uniform weights).
+ |
+
+
+fixed_cardinality
+ |
+
+Whether to take a fixed number of elements.
+ |
+
+
+shuffle_filenames
+ |
+
+If enabled, filenames will be shuffled after sharding
+ between replicas, before any file reads. Through interleaving, some
+files may be read in parallel: the details are auto-tuned for throughput.
+ |
+
+
+interleave_fn
+ |
+
+A fn applied with tf.data.Dataset.interleave.
+ |
+
+
+examples_shuffle_size
+ |
+
+An optional buffer size for example shuffling. If
+specified, the size is adjusted to shuffle_size //
+(len(file_patterns) + 1).
+ |
+
+
+
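+A minimal sketch (the file patterns and weights are placeholders):
+
+```python
+from tensorflow_gnn import runner
+
+train_ds_provider = runner.SampleTFRecordDatasetsProvider(
+    principal_file_pattern="/data/positives/*.tfrecord",
+    extra_file_patterns=["/data/negatives/*.tfrecord"],
+    principal_weight=0.8,
+    extra_weights=[0.2],
+    shuffle_filenames=True,
+    examples_shuffle_size=1000)
+```
+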
+## Methods
+
+get_dataset
+
+View
+source
+
+
+get_dataset(
+ context: tf.distribute.InputContext
+) -> tf.data.Dataset
+
+
+Creates a `tf.data.Dataset` by sampling.
+
+The contents of the resulting `tf.data.Dataset` are sampled from several
+sources, each stored as a sharded dataset:
+
+- one principal input, whose size determines the size of the resulting
+  `tf.data.Dataset`;
+- zero or more side inputs, which are repeated if necessary to preserve the
+  requested sampling weights.
+
+Each input dataset is sharded before interleaving. The result of interleaving
+is only shuffled if an `examples_shuffle_size` is provided.
+
+Datasets are sampled with `tf.data.Dataset.sample_from_datasets`. For
+sampling details, please refer to the TensorFlow documentation at:
+https://www.tensorflow.org/api_docs/python/tf/data/Dataset#sample_from_datasets.
+
+Two methods are supported to determine the end of the resulting
+`tf.data.Dataset`:
+
+fixed_cardinality=True) Returns a dataset with a fixed cardinality, set at
+`principal_cardinality` // `principal_weight.` `principal_dataset` and
+`principal_cardinality` are required for this method. `principal_weight` is
+required iff `extra_weights` are also provided.
+
+fixed_cardinality=False) Returns a dataset that ends after the principal input
+has been exhausted, subject to the random selection of samples.
+`principal_dataset` is required for this method. `principal_weight` is required
+iff `extra_weights` are also provided.
+
+The choice of `principal_dataset` is important and should, in most cases, be
+the largest underlying dataset as compared to `extra_datasets`. Consider
+`positives` and `negatives` with `len(negatives)` >> `len(positives)` and with
+`positives` used as the `principal_dataset`: the desired behavior of epochs
+determined by the exhaustion of `positives` and the continued mixing of unique
+elements from `negatives` may not occur. On sampled dataset reiteration,
+`positives` will again be exhausted, but elements from `negatives` may be the
+same ones seen in the previous epoch (as they occur at the beginning of the
+same, reiterated underlying `negatives` dataset). In this case, the
+recommendations are to:
+
+1) Reformulate the sampling in terms of the larger dataset (`negatives`): use
+`fixed_cardinality=False` if the exhaustion of `negatives` is desired, or
+`fixed_cardinality=True` with `principal_cardinality` set to the desired number
+of elements from `negatives`.
+
+2) Ensure that the underlying dataset of `negatives` is well-sharded. In this
+way, the nondeterminism of interleaving will randomly access elements of
+`negatives` on reiteration.
+
+
+
+
+
+Args |
+
+
+
+context
+ |
+
+An tf.distribute.InputContext for sharding.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.data.Dataset.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/SimpleDatasetProvider.md b/tensorflow_gnn/docs/api_docs/python/runner/SimpleDatasetProvider.md
new file mode 100644
index 00000000..bb730e58
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/SimpleDatasetProvider.md
@@ -0,0 +1,97 @@
+# runner.SimpleDatasetProvider
+
+
+
+
+ View source
+on GitHub
+
+Builds a `tf.data.Dataset` from a list of files.
+
+Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md)
+
+
+runner.SimpleDatasetProvider(
+ file_pattern: Optional[str] = None,
+ *,
+ filenames: Optional[Sequence[str]] = None,
+ shuffle_filenames: bool = False,
+ interleave_fn: Callable[..., tf.data.Dataset],
+ examples_shuffle_size: Optional[int] = None
+)
+
+
+
+
+This `SimpleDatasetProvider` builds a `tf.data.Dataset` as follows:
+
+- The object is initialized with a list of filenames. For convenience, a file
+  pattern can be specified instead, which will be expanded to a sorted list.
+- The filenames are sharded between replicas according to the `InputContext`
+  (order matters).
+- Filenames are shuffled per replica (if requested).
+- The files in each shard are interleaved after being read by the
+  `interleave_fn`.
+- Examples are shuffled (if requested), auto-prefetched, and returned for use in
+  one replica of the trainer.
+
+
+
+
+
+Args |
+
+
+
+file_pattern
+ |
+
+A file pattern, to be expanded by tf.io.gfile.glob
+and sorted into the list of all filenames .
+ |
+
+
+filenames
+ |
+
+A list of all filenames, specified explicitly.
+This argument is mutually exclusive with file_pattern .
+ |
+
+
+shuffle_filenames
+ |
+
+If enabled, filenames will be shuffled after sharding
+between replicas, before any file reads. Through interleaving, some
+files may be read in parallel: the details are auto-tuned for
+throughput.
+ |
+
+
+interleave_fn
+ |
+
+A callback that receives a single filename and returns
+a tf.data.Dataset with the tf.Example values from that file.
+ |
+
+
+examples_shuffle_size
+ |
+
+An optional buffer size for example shuffling.
+ |
+
+
+
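+A minimal sketch reading TFRecord shards (the file pattern is a placeholder):
+
+```python
+import tensorflow as tf
+from tensorflow_gnn import runner
+
+train_ds_provider = runner.SimpleDatasetProvider(
+    file_pattern="/data/train/*.tfrecord",
+    shuffle_filenames=True,
+    interleave_fn=tf.data.TFRecordDataset,
+    examples_shuffle_size=1000)
+```
+
+For TFRecord files specifically, `TFRecordDatasetProvider` offers the same
+behavior with the `interleave_fn` preset.
+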
+## Methods
+
+get_dataset
+
+View
+source
+
+
+get_dataset(
+ context: tf.distribute.InputContext
+) -> tf.data.Dataset
+
+
+Gets a `tf.data.Dataset` by `context` per replica.
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/SimpleSampleDatasetsProvider.md b/tensorflow_gnn/docs/api_docs/python/runner/SimpleSampleDatasetsProvider.md
new file mode 100644
index 00000000..87e5c5e4
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/SimpleSampleDatasetsProvider.md
@@ -0,0 +1,238 @@
+# runner.SimpleSampleDatasetsProvider
+
+
+
+
+ View source
+on GitHub
+
+Builds a sampling `tf.data.Dataset` from multiple filenames.
+
+Inherits From: [`DatasetProvider`](../runner/DatasetProvider.md)
+
+
+runner.SimpleSampleDatasetsProvider(
+ principal_file_pattern: Optional[str] = None,
+ extra_file_patterns: Optional[Sequence[str]] = None,
+ principal_weight: Optional[float] = None,
+ extra_weights: Optional[Sequence[float]] = None,
+ *,
+ principal_filenames: Optional[Sequence[str]] = None,
+ extra_filenames: Optional[Sequence[Sequence[str]]] = None,
+ principal_cardinality: Optional[int] = None,
+ fixed_cardinality: bool = False,
+ shuffle_filenames: bool = False,
+ interleave_fn: Callable[..., tf.data.Dataset],
+ examples_shuffle_size: Optional[int] = None
+)
+
+
+
+
+For complete explanations regarding sampling see `_process_sampled_dataset()`.
+
+This `SimpleSampleDatasetsProvider` builds a `tf.data.Dataset` as follows:
+
+- The object is initialized with a list of filenames specified by
+ `principal_filenames` and `extra_filenames` arguments. For convenience, the
+ corresponding file pattern `principal_file_pattern` and
+ `extra_file_patterns` can be specified instead, which will be expanded to a
+ sorted list.
+- The filenames are sharded between replicas according to the `InputContext`
+ (order matters).
+- Filenames are shuffled per replica (if requested).
+- Examples from all file patterns are sampled according to `principal_weight`
+ and `extra_weights.`
+- The files in each shard are interleaved after being read by the
+ `interleave_fn`.
+- Examples are shuffled (if requested), auto-prefetched, and returned for use
+ in one replica of the trainer.
+
+
+
+
+
+Args |
+
+
+
+principal_file_pattern
+ |
+
+A principal file pattern for sampling, to be
+expanded by tf.io.gfile.glob and sorted into the list of
+principal_filenames .
+ |
+
+
+extra_file_patterns
+ |
+
+File patterns, to be expanded by tf.io.gfile.glob
+and sorted into the list of extra_filenames .
+ |
+
+
+principal_weight
+ |
+
+An optional weight for the dataset corresponding to
+principal_file_pattern. Required iff extra_weights are also
+provided.
+ |
+
+
+extra_weights
+ |
+
+Optional weights corresponding to file_patterns for
+sampling. Required iff principal_weight is also provided.
+ |
+
+
+principal_filenames
+ |
+
+A list of principal filenames, specified explicitly.
+This argument is mutually exclusive with principal_file_pattern .
+ |
+
+
+extra_filenames
+ |
+
+A list of extra filenames, specified explicitly.
+This argument is mutually exclusive with extra_file_patterns .
+ |
+
+
+principal_cardinality
+ |
+
+Iff fixed_cardinality =True, the size of the
+returned dataset is computed as principal_cardinality /
+principal_weight (with a default of uniform weights).
+ |
+
+
+fixed_cardinality
+ |
+
+Whether to take a fixed number of elements.
+ |
+
+
+shuffle_filenames
+ |
+
+If enabled, filenames will be shuffled after sharding
+ between replicas, before any file reads. Through interleaving, some
+files may be read in parallel: the details are auto-tuned for throughput.
+ |
+
+
+interleave_fn
+ |
+
+A fn applied with tf.data.Dataset.interleave.
+ |
+
+
+examples_shuffle_size
+ |
+
+An optional buffer size for example shuffling. If
+specified, the size is adjusted to shuffle_size //
+(len(file_patterns) + 1).
+ |
+
+
+
+## Methods
+
+get_dataset
+
+View
+source
+
+
+get_dataset(
+ context: tf.distribute.InputContext
+) -> tf.data.Dataset
+
+
+Creates a `tf.data.Dataset` by sampling.
+
+The contents of the resulting `tf.data.Dataset` are sampled from several
+sources, each stored as a sharded dataset:
+
+- one principal input, whose size determines the size of the resulting
+  `tf.data.Dataset`;
+- zero or more side inputs, which are repeated if necessary to preserve the
+  requested sampling weights.
+
+Each input dataset is sharded before interleaving. The result of interleaving
+is only shuffled if an `examples_shuffle_size` is provided.
+
+Datasets are sampled with `tf.data.Dataset.sample_from_datasets`. For
+sampling details, please refer to the TensorFlow documentation at:
+https://www.tensorflow.org/api_docs/python/tf/data/Dataset#sample_from_datasets.
+
+Two methods are supported to determine the end of the resulting
+`tf.data.Dataset`:
+
+fixed_cardinality=True) Returns a dataset with a fixed cardinality, set at
+`principal_cardinality` // `principal_weight.` `principal_dataset` and
+`principal_cardinality` are required for this method. `principal_weight` is
+required iff `extra_weights` are also provided.
+
+fixed_cardinality=False) Returns a dataset that ends after the principal input
+has been exhausted, subject to the random selection of samples.
+`principal_dataset` is required for this method. `principal_weight` is required
+iff `extra_weights` are also provided.
+
+The choice of `principal_dataset` is important and should, in most cases, be
+the largest underlying dataset as compared to `extra_datasets`. Consider
+`positives` and `negatives` with `len(negatives)` >> `len(positives)` and with
+`positives` used as the `principal_dataset`: the desired behavior of epochs
+determined by the exhaustion of `positives` and the continued mixing of unique
+elements from `negatives` may not occur. On sampled dataset reiteration,
+`positives` will again be exhausted, but elements from `negatives` may be the
+same ones seen in the previous epoch (as they occur at the beginning of the
+same, reiterated underlying `negatives` dataset). In this case, the
+recommendations are to:
+
+1) Reformulate the sampling in terms of the larger dataset (`negatives`): use
+`fixed_cardinality=False` if the exhaustion of `negatives` is desired, or
+`fixed_cardinality=True` with `principal_cardinality` set to the desired number
+of elements from `negatives`.
+
+2) Ensure that the underlying dataset of `negatives` is well-sharded. In this
+way, the nondeterminism of interleaving will randomly access elements of
+`negatives` on reiteration.
+
+
+
+
+
+Args |
+
+
+
+context
+ |
+
+An tf.distribute.InputContext for sharding.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.data.Dataset.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/SubmoduleExporter.md b/tensorflow_gnn/docs/api_docs/python/runner/SubmoduleExporter.md
new file mode 100644
index 00000000..bb2dc054
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/SubmoduleExporter.md
@@ -0,0 +1,125 @@
+# runner.SubmoduleExporter
+
+
+
+
+ View source
+on GitHub
+
+Exports a Keras submodule.
+
+Inherits From: [`ModelExporter`](../runner/ModelExporter.md)
+
+
+runner.SubmoduleExporter(
+ sublayer_name: str,
+ *,
+ output_names: Optional[Any] = None,
+ subdirectory: Optional[str] = None,
+ include_preprocessing: bool = False,
+ options: Optional[tf.saved_model.SaveOptions] = None
+)
+
+
+
+
+Given a `RunResult`, this exporter creates and exports a submodule with inputs
+identical to the trained model and outputs from some intermediate layer (named
+`sublayer_name`). For example, with pseudocode:
+
+`trained_model = tf.keras.Sequential([layer1, layer2, layer3, layer4])` and
+`SubmoduleExporter(sublayer_name='layer2')`
+
+The exported submodule is:
+
+`submodule = tf.keras.Sequential([layer1, layer2])`
+
+
+
+
+
+Args |
+
+
+
+sublayer_name
+ |
+
+The name of the submodule's final layer.
+ |
+
+
+output_names
+ |
+
+The names for output Tensor(s), see: KerasModelExporter .
+ |
+
+
+subdirectory
+ |
+
+An optional subdirectory, if set: submodules are exported
+to os.path.join(export_dir, subdirectory) .
+ |
+
+
+include_preprocessing
+ |
+
+Whether to include any preprocess_model .
+ |
+
+
+options
+ |
+
+Options for saving to a TensorFlow SavedModel .
+ |
+
+
+
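+A minimal sketch (the layer and output names are placeholders):
+
+```python
+from tensorflow_gnn import runner
+
+exporter = runner.SubmoduleExporter(
+    "node_embeddings",
+    output_names="embeddings",
+    subdirectory="embeddings")
+# After training, the orchestration (or user code) calls
+# exporter.save(run_result, export_dir).
+```
+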
+## Methods
+
+save
+
+View
+source
+
+
+save(
+ run_result: runner.RunResult ,
+ export_dir: str
+)
+
+
+Saves a Keras model submodule.
+
+Importantly: the `run_result.preprocess_model`, if provided, and
+`run_result.trained_model` are stacked before any export. Stacking involves the
+chaining of the first output of `run_result.preprocess_model` to the only input
+of `run_result.trained_model.` The result is a model with the input of
+`run_result.preprocess_model` and the output of `run_result.trained_model.`
+
+
+
+
+
+Args |
+
+
+
+run_result
+ |
+
+A RunResult from training.
+ |
+
+
+export_dir
+ |
+
+A destination directory.
+ |
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/TFDataServiceConfig.md b/tensorflow_gnn/docs/api_docs/python/runner/TFDataServiceConfig.md
new file mode 100644
index 00000000..f94a5164
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/TFDataServiceConfig.md
@@ -0,0 +1,66 @@
+# runner.TFDataServiceConfig
+
+
+
+
+ View source
+on GitHub
+
+Provides tf.data service related configuration options.
+
+
+runner.TFDataServiceConfig(
+ tf_data_service_address: str,
+ tf_data_service_job_name: str,
+ tf_data_service_mode: Union[str, tf.data.experimental.service.ShardingPolicy]
+)
+
+
+
+
+tf.data service offers flexible data visitation guarantees; its impact on your
+training pipelines should be determined empirically. Check out the tf.data
+service internals and operation details at
+https://www.tensorflow.org/api_docs/python/tf/data/experimental/service.
+
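+A minimal sketch (the dispatcher address, job name, and mode are placeholders):
+
+```python
+from tensorflow_gnn import runner
+
+tf_data_service_config = runner.TFDataServiceConfig(
+    tf_data_service_address="grpc://dispatcher.example.com:5000",
+    tf_data_service_job_name="gnn_training",
+    tf_data_service_mode="distributed_epoch")
+```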
+
+
+
+
+Attributes |
+
+
+
+tf_data_service_address
+ |
+
+Dataclass field
+ |
+
+
+tf_data_service_job_name
+ |
+
+Dataclass field
+ |
+
+
+tf_data_service_mode
+ |
+
+Dataclass field
+ |
+
+
+
+## Methods
+
+__eq__
+
+
+__eq__(
+ other
+)
+
+
+Return self==value.
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/TFRecordDatasetProvider.md b/tensorflow_gnn/docs/api_docs/python/runner/TFRecordDatasetProvider.md
new file mode 100644
index 00000000..d3b381de
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/TFRecordDatasetProvider.md
@@ -0,0 +1,93 @@
+# runner.TFRecordDatasetProvider
+
+
+
+
+ View source
+on GitHub
+
+Builds a `tf.data.Dataset` from a list of files.
+
+Inherits From: [`SimpleDatasetProvider`](../runner/SimpleDatasetProvider.md),
+[`DatasetProvider`](../runner/DatasetProvider.md)
+
+
+runner.TFRecordDatasetProvider(
+ *args, **kwargs
+)
+
+
+
+
+This `SimpleDatasetProvider` builds a `tf.data.Dataset` as follows:
+
+- The object is initialized with a list of filenames. For convenience, a file
+  pattern can be specified instead, which will be expanded to a sorted list.
+- The filenames are sharded between replicas according to the `InputContext`
+  (order matters).
+- Filenames are shuffled per replica (if requested).
+- The files in each shard are interleaved after being read by the
+  `interleave_fn`.
+- Examples are shuffled (if requested), auto-prefetched, and returned for use in
+  one replica of the trainer.
+
+
+
+
+
+Args |
+
+
+
+file_pattern
+ |
+
+A file pattern, to be expanded by tf.io.gfile.glob
+and sorted into the list of all filenames .
+ |
+
+
+filenames
+ |
+
+A list of all filenames, specified explicitly.
+This argument is mutually exclusive with file_pattern .
+ |
+
+
+shuffle_filenames
+ |
+
+If enabled, filenames will be shuffled after sharding
+between replicas, before any file reads. Through interleaving, some
+files may be read in parallel: the details are auto-tuned for
+throughput.
+ |
+
+
+interleave_fn
+ |
+
+A callback that receives a single filename and returns
+a tf.data.Dataset with the tf.Example values from that file.
+ |
+
+
+examples_shuffle_size
+ |
+
+An optional buffer size for example shuffling.
+ |
+
+
+
+## Methods
+
+get_dataset
+
+View
+source
+
+
+get_dataset(
+ context: tf.distribute.InputContext
+) -> tf.data.Dataset
+
+
+Gets a `tf.data.Dataset` by `context` per replica.
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/TPUStrategy.md b/tensorflow_gnn/docs/api_docs/python/runner/TPUStrategy.md
new file mode 100644
index 00000000..d80ce7b5
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/TPUStrategy.md
@@ -0,0 +1,1228 @@
+# runner.TPUStrategy
+
+
+
+
+ View source
+on GitHub
+
+A `TPUStrategy` convenience wrapper.
+
+
+runner.TPUStrategy(
+ tpu: str = ''
+)
+
+
+
+
+
+
+
+Args |
+
+
+
+tpu_cluster_resolver
+ |
+
+A
+tf.distribute.cluster_resolver.TPUClusterResolver instance, which
+provides information about the TPU cluster. If None, it will assume
+running on a local TPU worker.
+ |
+
+
+experimental_device_assignment
+ |
+
+Optional
+tf.tpu.experimental.DeviceAssignment to specify the placement of
+replicas on the TPU cluster.
+ |
+
+
+experimental_spmd_xla_partitioning
+ |
+
+If True, enable the SPMD (Single
+Program Multiple Data) mode in XLA compiler. This flag only affects the
+performance of XLA compilation and the HBM requirement of the compiled
+TPU program. Caveat: if this flag is True, calling
+tf.distribute.TPUStrategy.experimental_assign_to_logical_device will
+result in a ValueError.
+ |
+
+
+
+
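+A minimal sketch (the empty `tpu` string assumes a local TPU worker):
+
+```python
+from tensorflow_gnn import runner
+
+strategy = runner.TPUStrategy(tpu="")
+with strategy.scope():
+  pass  # Build and compile the Keras model here.
+```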
+
+
+
+Attributes |
+
+ cluster_resolver |
+Returns the cluster resolver associated with this strategy.
+
+tf.distribute.TPUStrategy provides the associated
+tf.distribute.cluster_resolver.ClusterResolver . If the user provides one
+in `__init__`, that instance is returned; if the user does not, a default
+tf.distribute.cluster_resolver.TPUClusterResolver is provided.
+ |
+
+
+extended
+ |
+
+tf.distribute.StrategyExtended with additional methods.
+ |
+
+
+num_replicas_in_sync
+ |
+
+Returns number of replicas over which gradients are aggregated.
+ |
+
+
+
+## Methods
+
+distribute_datasets_from_function
+
+
+distribute_datasets_from_function(
+ dataset_fn, options=None
+)
+
+
+Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
+
+The argument `dataset_fn` that users pass in is an input function that has a
+`tf.distribute.InputContext` argument and returns a `tf.data.Dataset` instance.
+It is expected that the returned dataset from `dataset_fn` is already batched by
+per-replica batch size (i.e. global batch size divided by the number of replicas
+in sync) and sharded. `tf.distribute.Strategy.distribute_datasets_from_function`
+does not batch or shard the `tf.data.Dataset` instance returned from the input
+function. `dataset_fn` will be called on the CPU device of each of the workers
+and each generates a dataset where every replica on that worker will dequeue one
+batch of inputs (i.e. if a worker has two replicas, two batches will be dequeued
+from the `Dataset` every step).
+
+This method can be used for several purposes. First, it allows you to specify
+your own batching and sharding logic. (In contrast,
+`tf.distribute.experimental_distribute_dataset` does batching and sharding for
+you.) For example, where `experimental_distribute_dataset` is unable to shard
+the input files, this method might be used to manually shard the dataset
+(avoiding the slow fallback behavior in `experimental_distribute_dataset`). In
+cases where the dataset is infinite, this sharding can be done by creating
+dataset replicas that differ only in their random seed.
+
+The `dataset_fn` should take an `tf.distribute.InputContext` instance where
+information about batching and input replication can be accessed.
+
+You can use `element_spec` property of the `tf.distribute.DistributedDataset`
+returned by this API to query the `tf.TypeSpec` of the elements returned by the
+iterator. This can be used to set the `input_signature` property of a
+`tf.function`. Follow `tf.distribute.DistributedDataset.element_spec` to see an
+example.
+
+IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
+per-replica batch size, unlike `experimental_distribute_dataset`, which uses the
+global batch size. This may be computed using
+`input_context.get_per_replica_batch_size`.
+
+Note: If you are using TPUStrategy, the order in which the data is processed by
+the workers when using `tf.distribute.Strategy.experimental_distribute_dataset`
+or `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed.
+This is typically required if you are using `tf.distribute` to scale prediction.
+You can however insert an index for each element in the batch and order outputs
+accordingly. Refer to
+[this snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
+for an example of how to order outputs.
+
+Note: Stateful dataset transformations are currently not supported with
+`tf.distribute.experimental_distribute_dataset` or
+`tf.distribute.distribute_datasets_from_function`. Any stateful ops that the
+dataset may have are currently ignored. For example, if your dataset has a
+`map_fn` that uses `tf.random.uniform` to rotate an image, then you have a
+dataset graph that depends on state (i.e the random seed) on the local machine
+where the python process is being executed.
+
+For a tutorial on more usage and properties of this method, refer to the
+[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
+If you are interested in last partial batch handling, read
+[this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
+
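+As a hedged illustration of such a `dataset_fn` (the data and batch size are
+placeholders):
+
+```python
+import tensorflow as tf
+
+GLOBAL_BATCH_SIZE = 64
+
+def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
+  per_replica_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
+  ds = tf.data.Dataset.range(1024)
+  # Shard by input pipeline so each worker reads a disjoint slice.
+  ds = ds.shard(input_context.num_input_pipelines,
+                input_context.input_pipeline_id)
+  return ds.batch(per_replica_batch).prefetch(tf.data.AUTOTUNE)
+
+# dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
+```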
+
+
+
+
+Args |
+
+
+
+dataset_fn
+ |
+
+A function taking a tf.distribute.InputContext instance and
+returning a tf.data.Dataset .
+ |
+
+
+options
+ |
+
+tf.distribute.InputOptions used to control options on how this
+dataset is distributed.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.distribute.DistributedDataset .
+ |
+
+
+
+
+experimental_assign_to_logical_device
+
+
+experimental_assign_to_logical_device(
+ tensor, logical_device_id
+)
+
+
+Adds annotation that `tensor` will be assigned to a logical device.
+
+This adds an annotation to `tensor` specifying that operations on `tensor` will
+be invoked on logical core device id `logical_device_id`. When model parallelism
+is used, the default behavior is that all ops are placed on zero-th logical
+device.
+
+```python
+
+# Initializing TPU system with 2 logical devices and 4 replicas.
+resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+tf.config.experimental_connect_to_cluster(resolver)
+topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+ topology,
+ computation_shape=[1, 1, 1, 2],
+ num_replicas=4)
+strategy = tf.distribute.TPUStrategy(
+ resolver, experimental_device_assignment=device_assignment)
+iterator = iter(inputs)
+
+@tf.function()
+def step_fn(inputs):
+ output = tf.add(inputs, inputs)
+
+ # Add operation will be executed on logical device 0.
+ output = strategy.experimental_assign_to_logical_device(output, 0)
+ return output
+
+strategy.run(step_fn, args=(next(iterator),))
+```
+
+
+
+
+
+Args |
+
+
+
+tensor
+ |
+
+Input tensor to annotate.
+ |
+
+
+logical_device_id
+ |
+
+Id of the logical core to which the tensor will be
+assigned.
+ |
+
+
+
+
+
+
+
+Raises |
+
+
+
+ValueError
+ |
+
+The logical device id presented is not consistent with total
+number of partitions specified by the device assignment or the TPUStrategy
+is constructed with experimental_spmd_xla_partitioning=True .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Annotated tensor with identical value as tensor .
+ |
+
+
+
+
+experimental_distribute_dataset
+
+
+experimental_distribute_dataset(
+ dataset, options=None
+)
+
+
+Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
+
+The returned `tf.distribute.DistributedDataset` can be iterated over similar to
+regular datasets. NOTE: The user cannot add any more transformations to a
+`tf.distribute.DistributedDataset`. You can only create an iterator or examine
+the `tf.TypeSpec` of the data generated by it. See API docs of
+`tf.distribute.DistributedDataset` to learn more.
+
+The following is an example:
+
+```
+>>> global_batch_size = 2
+>>> # Passing the devices is optional.
+... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
+>>> # Create a dataset
+... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
+>>> # Distribute that dataset
+... dist_dataset = strategy.experimental_distribute_dataset(dataset)
+>>> @tf.function
+... def replica_fn(input):
+... return input*2
+>>> result = []
+>>> # Iterate over the `tf.distribute.DistributedDataset`
+... for x in dist_dataset:
+... # process dataset elements
+... result.append(strategy.run(replica_fn, args=(x,)))
+>>> print(result)
+[PerReplica:{
+  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
+  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
+}, PerReplica:{
+  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
+  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
+}]
+```
+
+Three key actions happening under the hood of this method are batching,
+sharding, and prefetching.
+
+In the code snippet above, `dataset` is batched by `global_batch_size`, and
+calling `experimental_distribute_dataset` on it rebatches `dataset` to a new
+batch size that is equal to the global batch size divided by the number of
+replicas in sync. We iterate through it using a Pythonic for loop. `x` is a
+`tf.distribute.DistributedValues` containing data for all replicas, and each
+replica gets data of the new batch size. `tf.distribute.Strategy.run` will take
+care of feeding the right per-replica data in `x` to the right `replica_fn`
+executed on each replica.
+
+Sharding contains autosharding across multiple workers and within every worker.
+First, in multi-worker distributed training (i.e. when you use
+`tf.distribute.experimental.MultiWorkerMirroredStrategy` or
+`tf.distribute.TPUStrategy`), autosharding a dataset over a set of workers means
+that each worker is assigned a subset of the entire dataset (if the right
+`tf.data.experimental.AutoShardPolicy` is set). This is to ensure that at each
+step, a global batch size of non-overlapping dataset elements will be processed
+by each worker. Autosharding has a couple of different options that can be
+specified using `tf.data.experimental.DistributeOptions`. Then, sharding within
+each worker means the method will split the data among all the worker devices
+(if more than one is present). This will happen regardless of multi-worker
+autosharding.
+
+Note: for autosharding across multiple workers, the default mode is
+`tf.data.experimental.AutoShardPolicy.AUTO`. This mode will attempt to shard the
+input dataset by files if the dataset is being created out of reader datasets
+(e.g. `tf.data.TFRecordDataset`, `tf.data.TextLineDataset`, etc.) or otherwise
+shard the dataset by data, where each of the workers will read the entire
+dataset and only process the shard assigned to it. However, if you have less
+than one input file per worker, we suggest that you disable dataset autosharding
+across workers by setting the
+`tf.data.experimental.DistributeOptions.auto_shard_policy` to be
+`tf.data.experimental.AutoShardPolicy.OFF`.
+
+By default, this method adds a prefetch transformation at the end of the user
+provided `tf.data.Dataset` instance. The argument to the prefetch transformation
+which is `buffer_size` is equal to the number of replicas in sync.
+
+If the above batch splitting and dataset sharding logic is undesirable, please
+use `tf.distribute.Strategy.distribute_datasets_from_function` instead, which
+does not do any automatic batching or sharding for you.
+
+Note: If you are using TPUStrategy, the order in which the data is processed by
+the workers when using `tf.distribute.Strategy.experimental_distribute_dataset`
+or `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed.
+This is typically required if you are using `tf.distribute` to scale prediction.
+You can however insert an index for each element in the batch and order outputs
+accordingly. Refer to
+[this snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
+for an example of how to order outputs.
+
+Note: Stateful dataset transformations are currently not supported with
+`tf.distribute.experimental_distribute_dataset` or
+`tf.distribute.distribute_datasets_from_function`. Any stateful ops that the
+dataset may have are currently ignored. For example, if your dataset has a
+`map_fn` that uses `tf.random.uniform` to rotate an image, then you have a
+dataset graph that depends on state (i.e the random seed) on the local machine
+where the python process is being executed.
+
+For a tutorial on more usage and properties of this method, refer to the
+[tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
+If you are interested in last partial batch handling, read
+[this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
+
+
+
+
+
+Args |
+
+
+
+dataset
+ |
+
+tf.data.Dataset that will be sharded across all replicas using
+the rules stated above.
+ |
+
+
+options
+ |
+
+tf.distribute.InputOptions used to control options on how this
+dataset is distributed.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.distribute.DistributedDataset .
+ |
+
+
+
+
+experimental_distribute_values_from_function
+
+
+experimental_distribute_values_from_function(
+ value_fn
+)
+
+
+Generates `tf.distribute.DistributedValues` from `value_fn`.
+
+This function is to generate `tf.distribute.DistributedValues` to pass into
+`run`, `reduce`, or other methods that take distributed values when not using
+datasets.
+
+
+
+
+
+Args |
+
+
+
+value_fn
+ |
+
+The function to run to generate values. It is called for
+each replica with tf.distribute.ValueContext as the sole argument. It
+must return a Tensor or a type that can be converted to a Tensor.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.distribute.DistributedValues containing a value for each replica.
+ |
+
+
+
+
+#### Example usage:
+
+1. Return constant value per replica:
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> def value_fn(ctx):
+ ... return tf.constant(1.)
+ >>> distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ >>> local_result = strategy.experimental_local_results(
+ ... distributed_values)
+ >>> local_result
+    (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
+     <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
+ ```
+
+2. Distribute values in array based on replica_id: {: value=2}
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> array_value = np.array([3., 2., 1.])
+ >>> def value_fn(ctx):
+ ... return array_value[ctx.replica_id_in_sync_group]
+ >>> distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ >>> local_result = strategy.experimental_local_results(
+ ... distributed_values)
+ >>> local_result
+ (3.0, 2.0)
+ ```
+
+3. Specify values using num_replicas_in_sync: {: value=3}
+
+ ```
+ >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+ >>> def value_fn(ctx):
+ ... return ctx.num_replicas_in_sync
+ >>> distributed_values = (
+ ... strategy.experimental_distribute_values_from_function(
+ ... value_fn))
+ >>> local_result = strategy.experimental_local_results(
+ ... distributed_values)
+ >>> local_result
+ (2, 2)
+ ```
+
+4. Place values on devices and distribute: {: value=4}
+
+ ```
+ strategy = tf.distribute.TPUStrategy()
+ worker_devices = strategy.extended.worker_devices
+ multiple_values = []
+ for i in range(strategy.num_replicas_in_sync):
+ with tf.device(worker_devices[i]):
+ multiple_values.append(tf.constant(1.0))
+
+ def value_fn(ctx):
+ return multiple_values[ctx.replica_id_in_sync_group]
+
+ distributed_values = strategy.
+ experimental_distribute_values_from_function(
+ value_fn)
+ ```
+
+experimental_local_results
+
+
+experimental_local_results(
+ value
+)
+
+
+Returns the list of all local per-replica values contained in `value`.
+
+Note: This only returns values on the worker initiated by this client. When
+using a `tf.distribute.Strategy` like
+`tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker will be
+its own client, and this function will only return values computed on that
+worker.
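+
+A minimal sketch (assuming a two-replica `MirroredStrategy` is available):
+
+```python
+strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+per_replica = strategy.run(
+    lambda: tf.distribute.get_replica_context().replica_id_in_sync_group)
+# A tuple with one value per local replica, roughly (0, 1) here.
+strategy.experimental_local_results(per_replica)
+```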
+
+
+
+
+
+Args |
+
+
+
+value
+ |
+
+A value returned by experimental_run() , run() , or a variable
+created in scope .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of values contained in value where ith element corresponds to
+ith replica. If value represents a single value, this returns
+(value,).
+ |
+
+
+
+
+experimental_replicate_to_logical_devices
+
+
+experimental_replicate_to_logical_devices(
+ tensor
+)
+
+
+Adds annotation that `tensor` will be replicated to all logical devices.
+
+This adds an annotation to tensor `tensor` specifying that operations on
+`tensor` will be invoked on all logical devices.
+
+```python
+# Initializing TPU system with 2 logical devices and 4 replicas.
+resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+tf.config.experimental_connect_to_cluster(resolver)
+topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+ topology,
+ computation_shape=[1, 1, 1, 2],
+ num_replicas=4)
+strategy = tf.distribute.TPUStrategy(
+ resolver, experimental_device_assignment=device_assignment)
+
+iterator = iter(inputs)
+
+@tf.function()
+def step_fn(inputs):
+ images, labels = inputs
+  images = strategy.experimental_split_to_logical_devices(
+    images, [1, 2, 4, 1])
+
+  # model() function will be executed on the logical devices with `images`
+  # split 2 * 4 ways.
+  output = model(images)
+
+ # For loss calculation, all logical devices share the same logits
+ # and labels.
+ labels = strategy.experimental_replicate_to_logical_devices(labels)
+ output = strategy.experimental_replicate_to_logical_devices(output)
+ loss = loss_fn(labels, output)
+
+ return loss
+
+strategy.run(step_fn, args=(next(iterator),))
+```
+
+Args:
+
+*   `tensor`: Input tensor to annotate.
+
+
+
+
+
+Returns |
+
+
+Annotated tensor with identical value as tensor .
+ |
+
+
+
+
+experimental_split_to_logical_devices
+
+
+experimental_split_to_logical_devices(
+ tensor, partition_dimensions
+)
+
+
+Adds annotation that `tensor` will be split across logical devices.
+
+This adds an annotation to tensor `tensor` specifying that operations on
+`tensor` will be split among multiple logical devices. Tensor `tensor` will be
+split across dimensions specified by `partition_dimensions`. The dimensions of
+`tensor` must be divisible by corresponding value in `partition_dimensions`.
+
+For example, for a system with 8 logical devices, if `tensor` is an image tensor
+with shape (batch_size, width, height, channel) and `partition_dimensions` is
+[1, 2, 4, 1], then `tensor` will be split 2 ways along the width dimension and 4
+ways along the height dimension, and the resulting partitions will be fed into 8
+logical devices.
+
+```python
+# Initializing TPU system with 8 logical devices and 1 replica.
+resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+tf.config.experimental_connect_to_cluster(resolver)
+topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+ topology,
+ computation_shape=[1, 2, 2, 2],
+ num_replicas=1)
+# Construct the TPUStrategy. Since we are going to split the image across
+# logical devices, here we set `experimental_spmd_xla_partitioning=True`
+# so that the partitioning can be compiled in SPMD mode, which usually
+# results in faster compilation and smaller HBM requirement if the size of
+# input and activation tensors are much bigger than that of the model
+# parameters. Note that this flag is suggested but not a hard requirement
+# for `experimental_split_to_logical_devices`.
+strategy = tf.distribute.TPUStrategy(
+ resolver, experimental_device_assignment=device_assignment,
+ experimental_spmd_xla_partitioning=True)
+
+iterator = iter(inputs)
+
+@tf.function()
+def step_fn(inputs):
+ inputs = strategy.experimental_split_to_logical_devices(
+ inputs, [1, 2, 4, 1])
+
+ # model() function will be executed on 8 logical devices with `inputs`
+ # split 2 * 4 ways.
+ output = model(inputs)
+ return output
+
+strategy.run(step_fn, args=(next(iterator),))
+```
+
+Args:
+
+*   `tensor`: Input tensor to annotate.
+*   `partition_dimensions`: An unnested list of integers with size equal to the
+    rank of `tensor`, specifying how `tensor` will be partitioned. The product
+    of all elements in `partition_dimensions` must be equal to the total number
+    of logical devices per replica.
+
+
+
+
+
+Raises |
+
+
+
+ValueError
+ |
+
+1) If the size of partition_dimensions does not equal to rank
+of tensor or 2) if product of elements of partition_dimensions does
+not match the number of logical devices per replica defined by the
+implementing DistributionStrategy's device specification or
+3) if a known size of tensor is not divisible by corresponding
+value in partition_dimensions .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Annotated tensor with identical value as tensor .
+ |
+
+
+
+
+gather
+
+
+gather(
+ value, axis
+)
+
+
+Gather `value` across replicas along `axis` to the current device.
+
+Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like object `value`,
+this API gathers and concatenates `value` across replicas along the `axis`-th
+dimension. The result is copied to the "current" device, which would typically
+be the CPU of the worker on which the program is running. For
+`tf.distribute.TPUStrategy`, it is the first TPU host. For multi-client
+`tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of each worker.
+
+This API can only be called in the cross-replica context. For a counterpart in
+the replica context, see `tf.distribute.ReplicaContext.all_gather`.
+
+Note: For all strategies except `tf.distribute.TPUStrategy`, the input `value`
+on different replicas must have the same rank, and their shapes must be the same
+in all dimensions except the `axis`-th dimension. In other words, their shapes
+cannot be different in a dimension `d` where `d` does not equal to the `axis`
+argument. For example, given a `tf.distribute.DistributedValues` with component
+tensors of shape `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
+`gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
+`gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, all
+tensors must have exactly the same rank and same shape.
+
+Note: Given a `tf.distribute.DistributedValues` `value`, its component tensors
+must have a non-zero rank. Otherwise, consider using `tf.expand_dims` before
+gathering them.
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+>>> # A DistributedValues with component tensor of shape (2, 1) on each replica
+... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
+>>> @tf.function
+... def run():
+... return strategy.gather(distributed_values, axis=0)
+>>> run()
+<tf.Tensor: shape=(4, 1), dtype=int32, numpy=
+array([[1],
+       [2],
+       [1],
+       [2]], dtype=int32)>
+```
+
+Consider the following example for more combinations:
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
+>>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
+>>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
+>>> @tf.function
+... def run(axis):
+... return strategy.gather(distributed_values, axis=axis)
+>>> axis=0
+>>> run(axis)
+
+>>> axis=1
+>>> run(axis)
+
+>>> axis=2
+>>> run(axis)
+
+```
+
+
+
+
+
+Args |
+
+
+
+value
+ |
+
+a tf.distribute.DistributedValues instance, e.g. returned by
+Strategy.run , to be combined into a single tensor. It can also be a
+regular tensor when used with tf.distribute.OneDeviceStrategy or the
+default strategy. The tensors that constitute the DistributedValues
+can only be dense tensors with non-zero rank, NOT a tf.IndexedSlices .
+ |
+
+
+axis
+ |
+
+0-D int32 Tensor. Dimension along which to gather. Must be in the
+range [0, rank(value)).
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A Tensor that's the concatenation of value across replicas along
+axis dimension.
+ |
+
+
+
+
+reduce
+
+
+reduce(
+ reduce_op, value, axis
+)
+
+
+Reduce `value` across replicas and return result on current device.
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+>>> def step_fn():
+... i = tf.distribute.get_replica_context().replica_id_in_sync_group
+... return tf.identity(i)
+>>>
+>>> per_replica_result = strategy.run(step_fn)
+>>> total = strategy.reduce("SUM", per_replica_result, axis=None)
+>>> total
+<tf.Tensor: shape=(), dtype=int32, numpy=1>
+```
+
+To see how this would look with multiple replicas, consider the same example
+with MirroredStrategy with 2 GPUs:
+
+```python
+strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
+def step_fn():
+ i = tf.distribute.get_replica_context().replica_id_in_sync_group
+ return tf.identity(i)
+
+per_replica_result = strategy.run(step_fn)
+# Check devices on which per replica result is:
+strategy.experimental_local_results(per_replica_result)[0].device
+# /job:localhost/replica:0/task:0/device:GPU:0
+strategy.experimental_local_results(per_replica_result)[1].device
+# /job:localhost/replica:0/task:0/device:GPU:1
+
+total = strategy.reduce("SUM", per_replica_result, axis=None)
+# Check device on which reduced result is:
+total.device
+# /job:localhost/replica:0/task:0/device:CPU:0
+
+```
+
+This API is typically used for aggregating the results returned from different
+replicas, for reporting etc. For example, loss computed from different replicas
+can be averaged using this API before printing.
+
+Note: The result is copied to the "current" device - which would typically be
+the CPU of the worker on which the program is running. For `TPUStrategy`, it is
+the first TPU host. For multi-client `MultiWorkerMirroredStrategy`, this is the
+CPU of each worker.
+
+There are a number of different tf.distribute APIs for reducing values across
+replicas:
+
+*   `tf.distribute.ReplicaContext.all_reduce`: This differs from
+    `Strategy.reduce` in that it is for replica context and does not copy the
+    results to the host device. `all_reduce` should typically be used for
+    reductions inside the training step, such as gradients.
+*   `tf.distribute.StrategyExtended.reduce_to` and
+    `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more
+    advanced versions of `Strategy.reduce` as they allow customizing the
+    destination of the result. They are also called in cross-replica context.
+
+*What should axis be?*
+
+Given a per-replica value returned by `run`, say a per-example loss, the batch
+will be divided across all the replicas. This function allows you to aggregate
+across replicas and optionally also across batch elements by specifying the axis
+parameter accordingly.
+
+For example, if you have a global batch size of 8 and 2 replicas, values for
+examples `[0, 1, 2, 3]` will be on replica 0 and `[4, 5, 6, 7]` will be on
+replica 1. With `axis=None`, `reduce` will aggregate only across replicas,
+returning `[0+4, 1+5, 2+6, 3+7]`. This is useful when each replica is computing
+a scalar or some other value that doesn't have a "batch" dimension (like a
+gradient or loss). `strategy.reduce("sum", per_replica_result, axis=None)`
+
+Sometimes, you will want to aggregate across both the global batch *and* all
+replicas. You can get this behavior by specifying the batch dimension as the
+`axis`, typically `axis=0`. In this case it would return a scalar
+`0+1+2+3+4+5+6+7`. `strategy.reduce("sum", per_replica_result, axis=0)`
+
+If there is a last partial batch, you will need to specify an axis so that the
+resulting shape is consistent across replicas. So if the last batch has size 6
+and it is divided into [0, 1, 2, 3] and [4, 5], you would get a shape mismatch
+unless you specify `axis=0`. If you specify `tf.distribute.ReduceOp.MEAN`, using
+`axis=0` will use the correct denominator of 6. Contrast this with computing
+`reduce_mean` on each replica to get a scalar value and then using this function
+to average those means, which would weigh some values `1/8` and others `1/4`.
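+
+A short sketch of the axis semantics described above (`strategy` is assumed to
+exist; `loss_step_fn` is an illustrative per-replica function returning four
+per-example losses):
+
+```python
+# Say loss_step_fn returns [0., 1., 2., 3.] on replica 0 and
+# [4., 5., 6., 7.] on replica 1 (global batch size 8, 2 replicas).
+per_replica_losses = strategy.run(loss_step_fn)
+
+# Aggregate across replicas only: element-wise sum -> [4., 6., 8., 10.]
+summed_over_replicas = strategy.reduce("SUM", per_replica_losses, axis=None)
+
+# Aggregate across replicas and the batch dimension -> the scalar 28.
+summed_over_batch = strategy.reduce("SUM", per_replica_losses, axis=0)
+```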
+
+
+
+
+
+Args |
+
+
+
+reduce_op
+ |
+
+a tf.distribute.ReduceOp value specifying how values should
+be combined. Allows using string representation of the enum such as
+"SUM", "MEAN".
+ |
+
+
+value
+ |
+
+a tf.distribute.DistributedValues instance, e.g. returned by
+Strategy.run , to be combined into a single tensor. It can also be a
+regular tensor when used with OneDeviceStrategy or default strategy.
+ |
+
+
+axis
+ |
+
+specifies the dimension to reduce along within each
+replica's tensor. Should typically be set to the batch dimension, or
+None to only reduce across replicas (e.g. if the tensor has no batch
+dimension).
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A Tensor .
+ |
+
+
+
+
+run
+
+
+run(
+ fn, args=(), kwargs=None, options=None
+)
+
+
+Run the computation defined by `fn` on each TPU replica.
+
+Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
+`tf.distribute.DistributedValues`, such as those produced by a
+`tf.distribute.DistributedDataset` from
+`tf.distribute.Strategy.experimental_distribute_dataset` or
+`tf.distribute.Strategy.distribute_datasets_from_function`, when `fn` is
+executed on a particular replica, it will be executed with the component of
+`tf.distribute.DistributedValues` that correspond to that replica.
+
+`fn` may call `tf.distribute.get_replica_context()` to access members such as
+`all_reduce`.
+
+All arguments in `args` or `kwargs` should be either a nest of tensors or
+`tf.distribute.DistributedValues` containing tensors or composite tensors.
+
+#### Example usage:
+
+```
+>>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
+>>> tf.config.experimental_connect_to_cluster(resolver)
+>>> tf.tpu.experimental.initialize_tpu_system(resolver)
+>>> strategy = tf.distribute.TPUStrategy(resolver)
+>>> @tf.function
+... def run():
+... def value_fn(value_context):
+... return value_context.num_replicas_in_sync
+... distributed_values = (
+... strategy.experimental_distribute_values_from_function(value_fn))
+... def replica_fn(input):
+... return input * 2
+... return strategy.run(replica_fn, args=(distributed_values,))
+>>> result = run()
+```
+
+
+
+
+
+Args |
+
+
+
+fn
+ |
+
+The function to run. The output must be a tf.nest of Tensor s.
+ |
+
+
+args
+ |
+
+(Optional) Positional arguments to fn .
+ |
+
+
+kwargs
+ |
+
+(Optional) Keyword arguments to fn .
+ |
+
+
+options
+ |
+
+(Optional) An instance of tf.distribute.RunOptions specifying
+the options to run fn .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+Merged return value of fn across replicas. The structure of the return
+value is the same as the return value from fn . Each element in the
+structure can either be tf.distribute.DistributedValues , Tensor
+objects, or Tensor s (for example, if running on a single replica).
+ |
+
+
+
+
+scope
+
+
+scope()
+
+
+Context manager to make the strategy current and distribute variables.
+
+This method returns a context manager, and is used as follows:
+
+```
+>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+>>> # Variable created inside scope:
+>>> with strategy.scope():
+... mirrored_variable = tf.Variable(1.)
+>>> mirrored_variable
+MirroredVariable:{
+  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
+  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
+}
+>>> # Variable created outside scope:
+>>> regular_variable = tf.Variable(1.)
+>>> regular_variable
+<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
+```
+
+*What happens when Strategy.scope is entered?*
+
+* `strategy` is installed in the global context as the "current" strategy.
+ Inside this scope, `tf.distribute.get_strategy()` will now return this
+ strategy. Outside this scope, it returns the default no-op strategy.
+* Entering the scope also enters the "cross-replica context". See
+ `tf.distribute.StrategyExtended` for an explanation on cross-replica and
+ replica contexts.
+* Variable creation inside `scope` is intercepted by the strategy. Each
+ strategy defines how it wants to affect the variable creation. Sync
+ strategies like `MirroredStrategy`, `TPUStrategy` and
+  `MultiWorkerMirroredStrategy` create variables replicated on each replica,
+ whereas `ParameterServerStrategy` creates variables on the parameter
+ servers. This is done using a custom `tf.variable_creator_scope`.
+* In some strategies, a default device scope may also be entered: in
+  `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is entered
+ on each worker.
+
+Note: Entering a scope does not automatically distribute a computation, except
+in the case of a high-level training framework like Keras `model.fit`. If you're
+not using `model.fit`, you need to use `strategy.run` API to explicitly
+distribute that computation. See an example in the
+[custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
+
+*What should be in scope and what should be outside?*
+
+There are a number of requirements on what needs to happen inside the scope.
+However, in places where we have information about which strategy is in use, we
+often enter the scope for the user, so they don't have to do it explicitly (i.e.
+calling those either inside or outside the scope is OK).
+
+* Anything that creates variables that should be distributed variables must be
+ called in a `strategy.scope`. This can be accomplished either by directly
+ calling the variable creating function within the scope context, or by
+ relying on another API like `strategy.run` or `keras.Model.fit` to
+ automatically enter it for you. Any variable that is created outside scope
+ will not be distributed and may have performance implications. Some common
+ objects that create variables in TF are Models, Optimizers, Metrics. Such
+ objects should always be initialized in the scope, and any functions that
+ may lazily create variables (e.g., `Model.__call__()`, tracing a
+ `tf.function`, etc.) should similarly be called within scope. Another source
+ of variable creation can be a checkpoint restore - when variables are
+ created lazily. Note that any variable created inside a strategy captures
+ the strategy information. So reading and writing to these variables outside
+ the `strategy.scope` can also work seamlessly, without the user having to
+ enter the scope.
+* Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which
+ require to be in a strategy's scope, enter the scope automatically, which
+ means when using those APIs you don't need to explicitly enter the scope
+ yourself.
+* When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
+ object captures the scope information. When high level training framework
+ methods such as `model.compile`, `model.fit`, etc. are then called, the
+ captured scope will be automatically entered, and the associated strategy
+ will be used to distribute the training etc. See a detailed example in
+ [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
+ WARNING: Simply calling `model(..)` does not automatically enter the
+ captured scope -- only high level training framework APIs support this
+ behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
+ and `model.save` can all be called inside or outside the scope.
+* The following can be either inside or outside the scope:
+ * Creating the input datasets
+ * Defining `tf.function`s that represent your training step
+ * Saving APIs such as `tf.saved_model.save`. Loading creates variables, so
+ that should go inside the scope if you want to train the model in a
+ distributed way.
+ * Checkpoint saving. As mentioned above - `checkpoint.restore` may
+ sometimes need to be inside scope if it creates variables.
+
+
+
+
+
+Returns |
+
+
+A context manager.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/Task.md b/tensorflow_gnn/docs/api_docs/python/runner/Task.md
new file mode 100644
index 00000000..770bcb48
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/Task.md
@@ -0,0 +1,179 @@
+# runner.Task
+
+
+
+
+ View source
+on GitHub
+
+Defines a learning objective for a GNN.
+
+
+
+A `Task` represents a learning objective for a GNN model and defines all the
+non-GNN pieces around the base GNN. Specifically:
+
+1.  `preprocess` is expected to return a `GraphTensor` (or `GraphTensor`s) and
+    a `Field` where (a) the base GNN's output for each `GraphTensor` is passed
+    to `predict` and (b) the `Field` is used as the training label (for
+    supervised tasks);
+2.  `predict` is expected to (a) take the base GNN's output for each
+    `GraphTensor` returned by `preprocess` and (b) return a tensor with the
+    model's prediction for this task;
+3.  `losses` is expected to return callables (`tf.Tensor`, `tf.Tensor`) ->
+    `tf.Tensor` that accept (`y_true`, `y_pred`) where `y_true` is produced by
+    some dataset and `y_pred` is the model's prediction from (2);
+4.  `metrics` is expected to return callables (`tf.Tensor`, `tf.Tensor`) ->
+    `tf.Tensor` that accept (`y_true`, `y_pred`) where `y_true` is produced by
+    some dataset and `y_pred` is the model's prediction from (2).
+
+`Task` can emit multiple outputs in `predict`: in that case we require that (a)
+it is a mapping, (b) outputs of `losses` and `metrics` are also mappings with
+matching keys, and (c) there is exactly one loss per key (there may be a
+sequence of metrics per key). This is done to prevent accidental dropping of
+losses (see b/291874188).
+
+No constraints are made on the `predict` method; e.g., it may append a head with
+learnable weights or it may perform tensor computations only. (The entire `Task`
+coordinates what that means with respect to the dataset via `preprocess`,
+modeling via `predict`, and optimization via `losses`.)
+
+`Task`s are applied in the scope of a training invocation: they are subject to
+the executing context of the `Trainer` and should, when needed, override it
+(e.g., a global policy, like `tf.keras.mixed_precision.global_policy()` and its
+implications over logit and activation layers).
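+
+A hypothetical minimal sketch of a graph-level binary classification `Task`
+(the class name, node set name and feature names are illustrative assumptions,
+not part of the runner API):
+
+```python
+import tensorflow as tf
+import tensorflow_gnn as tfgnn
+from tensorflow_gnn import runner
+
+
+class ExampleGraphBinaryClassification(runner.Task):
+  """Sketch: label stored on context, logistic head on pooled node states."""
+
+  def preprocess(self, inputs: tfgnn.GraphTensor):
+    # Split the training label out of the graph. (Ideally the label would also
+    # be removed from the graph or stored on an auxiliary "_readout" node set.)
+    labels = inputs.context["label"]
+    return inputs, labels
+
+  def predict(self, gnn_output: tfgnn.GraphTensor):
+    # Pool the final node states per graph component, then add a logistic head.
+    pooled = tfgnn.keras.layers.Pool(
+        tfgnn.CONTEXT, "mean", node_set_name="nodes")(gnn_output)
+    return tf.keras.layers.Dense(1)(pooled)
+
+  def losses(self):
+    return tf.keras.losses.BinaryCrossentropy(from_logits=True)
+
+  def metrics(self):
+    return (tf.keras.metrics.BinaryAccuracy(),)
+```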
+
+## Methods
+
+losses
+
+View
+source
+
+
+@abc.abstractmethod
+losses() -> Losses
+
+
+Returns arbitrary task specific losses.
+
+metrics
+
+View
+source
+
+
+@abc.abstractmethod
+metrics() -> Metrics
+
+
+Returns arbitrary task specific metrics.
+
+predict
+
+View
+source
+
+
+@abc.abstractmethod
+predict(
+ *args
+) -> Predictions
+
+
+Produces prediction outputs for the learning objective.
+
+Overall model composition* makes use of the Keras Functional API
+(https://www.tensorflow.org/guide/keras/functional) to map symbolic Keras
+`GraphTensor` inputs to symbolic Keras `Field` outputs. Outputs must match the
+structure (a single output or a mapping) of labels from `preprocess`.
+
+*) `outputs = predict(GNN(inputs))` where `inputs` are those `GraphTensor`
+returned by `preprocess(...)`, `GNN` is the base GNN, `predict` is this method
+and `outputs` are the prediction outputs for the learning objective.
+
+
+
+
+
+Args |
+
+
+
+*args
+ |
+
+The symbolic Keras GraphTensor input(s). These inputs correspond
+(in sequence) to the base GNN output of each GraphTensor returned by
+preprocess(...) .
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+The model's prediction output for this task.
+ |
+
+
+
+
+preprocess
+
+View
+source
+
+
+@abc.abstractmethod
+preprocess(
+ inputs: GraphTensor
+) -> tuple[OneOrSequenceOf[GraphTensor], OneOrMappingOf[Field]]
+
+
+Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.
+
+This function uses the Keras functional API to define non-trainable
+transformations of the symbolic input `GraphTensor`, which get executed during
+dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
+responsibilities:
+
+1. Splitting the training label out of the input for training. It must be
+ returned as a separate tensor or mapping of tensors.
+2. Optionally, transforming input features. Some advanced modeling techniques
+ require running the same base GNN on multiple different transformations, so
+ this function may return a single `GraphTensor` or a non-empty sequence of
+ `GraphTensors`. The corresponding base GNN output for each `GraphTensor` is
+ provided to the `predict(...)` method.
+
+
+
+
+
+Args |
+
+
+
+inputs
+ |
+
+A symbolic Keras GraphTensor for processing.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tuple of processed GraphTensor (s) and a Field (or mapping of Field s) to
+be used as labels.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/TightPadding.md b/tensorflow_gnn/docs/api_docs/python/runner/TightPadding.md
new file mode 100644
index 00000000..d859478b
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/TightPadding.md
@@ -0,0 +1,47 @@
+# runner.TightPadding
+
+
+
+
+ View source
+on GitHub
+
+Calculates tight `SizeConstraints` for `GraphTensor` padding.
+
+Inherits From: [`GraphTensorPadding`](../runner/GraphTensorPadding.md)
+
+
+runner.TightPadding(
+ gtspec: tfgnn.GraphTensorSpec,
+ dataset_provider: runner.DatasetProvider ,
+ min_nodes_per_component: Optional[Mapping[str, int]] = None
+)
+
+
+
+
+See: `tfgnn.find_tight_size_constraints`.
+
+## Methods
+
+get_filter_fn
+
+View
+source
+
+
+get_filter_fn(
+ size_constraints: SizeConstraints
+) -> Callable[..., bool]
+
+
+get_size_constraints
+
+View
+source
+
+
+get_size_constraints(
+ target_batch_size: int
+) -> SizeConstraints
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/Trainer.md b/tensorflow_gnn/docs/api_docs/python/runner/Trainer.md
new file mode 100644
index 00000000..1b1db5f8
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/Trainer.md
@@ -0,0 +1,100 @@
+# runner.Trainer
+
+
+
+
+ View source
+on GitHub
+
+A class for training and validation of a Keras model.
+
+
+
+
+
+
+Attributes |
+
+ model_dir |
+
+ | strategy |
+
+ |
+
+
+
+## Methods
+
+train
+
+View
+source
+
+
+@abc.abstractmethod
+train(
+ model_fn: Callable[[], tf.keras.Model],
+ train_ds_provider: DatasetProvider,
+ *,
+ epochs: int = 1,
+ valid_ds_provider: Optional[DatasetProvider] = None
+) -> tf.keras.Model
+
+
+Trains a `tf.keras.Model` with optional validation.
+
+
+
+
+
+Args |
+
+
+
+model_fn
+ |
+
+Returns a tf.keras.Model for use in training and validation.
+ |
+
+
+train_ds_provider
+ |
+
+A DatasetProvider for training. The items of the
+tf.data.Dataset are pairs (graph_tensor, label) that represent one
+batch of per-replica training inputs after
+GraphTensor.merge_batch_to_components() has been applied.
+ |
+
+
+epochs
+ |
+
+The number of epochs to train.
+ |
+
+
+valid_ds_provider
+ |
+
+A DatasetProvider for validation. The items of the
+tf.data.Dataset are pairs (graph_tensor, label) that represent one
+batch of per-replica validation inputs after
+GraphTensor.merge_batch_to_components() has been applied.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A trained tf.keras.Model .
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/all_symbols.md b/tensorflow_gnn/docs/api_docs/python/runner/all_symbols.md
new file mode 100644
index 00000000..b8281392
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/all_symbols.md
@@ -0,0 +1,62 @@
+# All symbols in TensorFlow GNN Runner
+
+
+
+## Primary symbols
+
+* runner
+* runner.ContextLabelFn
+* runner.DatasetProvider
+* runner.DotProductLinkPrediction
+* runner.FitOrSkipPadding
+* runner.GraphBinaryClassification
+* runner.GraphMeanAbsoluteError
+* runner.GraphMeanAbsolutePercentageError
+* runner.GraphMeanSquaredError
+* runner.GraphMeanSquaredLogScaledError
+* runner.GraphMeanSquaredLogarithmicError
+* runner.GraphMulticlassClassification
+* runner.GraphTensorPadding
+* runner.GraphTensorProcessorFn
+* runner.HadamardProductLinkPrediction
+* runner.IntegratedGradientsExporter
+* runner.KerasModelExporter
+* runner.KerasTrainer
+* runner.KerasTrainerCheckpointOptions
+* runner.KerasTrainerOptions
+* runner.Loss
+* runner.Losses
+* runner.Metric
+* runner.Metrics
+* runner.ModelExporter
+* runner.NodeBinaryClassification
+* runner.NodeMulticlassClassification
+* runner.ParameterServerStrategy
+* runner.PassthruDatasetProvider
+* runner.PassthruSampleDatasetsProvider
+* runner.Predictions
+* runner.RootNodeBinaryClassification
+* runner.RootNodeLabelFn
+* runner.RootNodeMeanAbsoluteError
+* runner.RootNodeMeanAbsoluteLogarithmicError
+* runner.RootNodeMeanAbsolutePercentageError
+* runner.RootNodeMeanSquaredError
+* runner.RootNodeMeanSquaredLogScaledError
+* runner.RootNodeMeanSquaredLogarithmicError
+* runner.RootNodeMulticlassClassification
+* runner.RunResult
+* runner.SampleTFRecordDatasetsProvider
+* runner.SimpleDatasetProvider
+* runner.SimpleSampleDatasetsProvider
+* runner.SubmoduleExporter
+* runner.TFDataServiceConfig
+* runner.TFRecordDatasetProvider
+* runner.TPUStrategy
+* runner.Task
+* runner.TightPadding
+* runner.Trainer
+* runner.export_model
+* runner.incrementing_model_dir
+* runner.integrated_gradients
+* runner.one_node_per_component
+* runner.run
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/export_model.md b/tensorflow_gnn/docs/api_docs/python/runner/export_model.md
new file mode 100644
index 00000000..22c27634
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/export_model.md
@@ -0,0 +1,77 @@
+# runner.export_model
+
+
+
+
+ View source
+on GitHub
+
+Exports a Keras model without traces such that it is loadable without TF-GNN.
+
+
+runner.export_model(
+ model: tf.keras.Model,
+ export_dir: str,
+ *,
+ output_names: Optional[Any] = None,
+ options: Optional[tf.saved_model.SaveOptions] = None,
+ use_legacy_model_save: Optional[bool] = None
+) -> None
+
+
+
+
+
+
+
+Args |
+
+
+
+model
+ |
+
+Keras model instance to be saved.
+ |
+
+
+export_dir
+ |
+
+Path where to save the model.
+ |
+
+
+output_names
+ |
+
+Optionally, a nest of str values or None with the same
+structure as the outputs of model . A non-None value is used as that
+output's key in the SavedModel signature. By default, an output gets
+the name of the final Keras layer creating it as its key (matching the
+behavior of legacy Model.save(save_format="tf") ).
+ |
+
+
+options
+ |
+
+An optional tf.saved_model.SaveOptions argument.
+ |
+
+
+use_legacy_model_save
+ |
+
+Optional; most users can leave it unset to get a
+useful default for export to inference. If set to True , forces the use
+of Model.save() , which exports a SavedModel suitable for inference and
+potentially also for reloading as a Keras model (depending on its Layers).
+If set to False , forces the use of tf.keras.export.ExportArchive ,
+which is usable as of TensorFlow 2.13 and is advertised as the more
+streamlined way of exporting to SavedModel for inference only. Currently,
+None behaves like True , but the long-term plan is to migrate towards
+False .
+ |
+
+
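+A hypothetical usage sketch (the model variable and export path are
+illustrative):
+
+```python
+runner.export_model(
+    trained_model,                 # a tf.keras.Model produced by training
+    "/tmp/gnn_export",
+    use_legacy_model_save=False)   # export via tf.keras.export.ExportArchive
+```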
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/incrementing_model_dir.md b/tensorflow_gnn/docs/api_docs/python/runner/incrementing_model_dir.md
new file mode 100644
index 00000000..701754de
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/incrementing_model_dir.md
@@ -0,0 +1,52 @@
+# runner.incrementing_model_dir
+
+
+
+
+ View source
+on GitHub
+
+Create, given some `dirname`, an incrementing model directory.
+
+
+runner.incrementing_model_dir(
+ dirname: str, start: int = 0
+) -> str
+
+
+
+
+
+
+
+Args |
+
+
+
+dirname
+ |
+
+The base directory name.
+ |
+
+
+start
+ |
+
+The starting integer.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A model directory dirname/n where 'n' is the maximum integer in dirname .
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/integrated_gradients.md b/tensorflow_gnn/docs/api_docs/python/runner/integrated_gradients.md
new file mode 100644
index 00000000..7b2c69ef
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/integrated_gradients.md
@@ -0,0 +1,95 @@
+# runner.integrated_gradients
+
+
+
+
+ View source
+on GitHub
+
+Integrated gradients.
+
+
+runner.integrated_gradients(
+ preprocess_model: tf.keras.Model,
+ model: tf.keras.Model,
+ *,
+ output_name: Optional[str] = None,
+ random_counterfactual: bool,
+ steps: int,
+ seed: Optional[int] = None
+) -> tf.types.experimental.ConcreteFunction
+
+
+
+
+This `tf.function` computes integrated gradients over a `tfgnn.GraphTensor`. The
+`tf.function` will be persisted in the final SavedModel for subsequent
+attribution.
+
+
+
+
+
+Args |
+
+
+
+preprocess_model
+ |
+
+A tf.keras.Model for preprocessing. This model is
+expected to return a tuple (GraphTensor , Tensor ) where the
+GraphTensor is used to invoke the model below and the tensor is
+used for any loss computation (via model.compiled_loss ).
+ |
+
+
+model
+ |
+
+A tf.keras.Model for integrated gradients.
+ |
+
+
+output_name
+ |
+
+The output Tensor name. If unset, the tensor will be named
+by Keras defaults.
+ |
+
+
+random_counterfactual
+ |
+
+Whether to use a random uniform counterfactual.
+ |
+
+
+steps
+ |
+
+The number of interpolations of the Riemann sum approximation.
+ |
+
+
+seed
+ |
+
+An optional random seed.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A tf.function with the integrated gradients as output.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/one_node_per_component.md b/tensorflow_gnn/docs/api_docs/python/runner/one_node_per_component.md
new file mode 100644
index 00000000..a260d8b0
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/one_node_per_component.md
@@ -0,0 +1,17 @@
+# runner.one_node_per_component
+
+
+
+
+ View source
+on GitHub
+
+Returns a `Mapping` `node_set_name: 1` for every node set in `gtspec`.
+
+
+runner.one_node_per_component(
+ gtspec: tfgnn.GraphTensorSpec
+) -> Mapping[str, int]
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/runner/run.md b/tensorflow_gnn/docs/api_docs/python/runner/run.md
new file mode 100644
index 00000000..7c234650
--- /dev/null
+++ b/tensorflow_gnn/docs/api_docs/python/runner/run.md
@@ -0,0 +1,275 @@
+# runner.run
+
+
+
+
+ View source
+on GitHub
+
+Runs training (and validation) of a model on task(s) with the given data.
+
+
+runner.run(
+ *,
+ train_ds_provider: DatasetProvider,
+ model_fn: Callable[[GraphTensorSpec], tf.keras.Model],
+ optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
+ trainer: Trainer,
+ task: OneOrMappingOf[Task],
+ loss_weights: Optional[Mapping[str, float]] = None,
+ gtspec: GraphTensorSpec,
+ global_batch_size: int,
+ epochs: int = 1,
+ drop_remainder: bool = False,
+ export_dirs: Optional[Sequence[str]] = None,
+ model_exporters: Optional[Sequence[ModelExporter]] = None,
+ feature_processors: Optional[Sequence[GraphTensorProcessorFn]] = None,
+ valid_ds_provider: Optional[DatasetProvider] = None,
+ train_padding: Optional[GraphTensorPadding] = None,
+ valid_padding: Optional[GraphTensorPadding] = None,
+ tf_data_service_config: Optional[TFDataServiceConfig] = None,
+ steps_per_execution: Optional[int] = None,
+ run_eagerly: bool = False
+)
+
+
+
+
+This includes preprocessing the input data, appending any suitable head(s), and
+running training (and validation) with the requested distribution strategy.
+
+The input data is processed in multiple stages, starting from the contents of
+the datasets provided by `train_ds_provider` and `valid_ds_provider`:
+
+1. Input examples are batched.
+2. If necessary, input batches are parsed as `GraphTensor` values and merged
+ into components (see: `GraphTensor.merge_batch_to_components`).
+3. If set, `train_padding` and `valid_padding`, resp., are applied.
+4. The given `feature_processors` are applied in order for all non-trainable
+ feature transformations on CPU (as part of `tf.data.Dataset.map(...)`).
+5. The
+ Task.preprocess(...)
+ method is applied to extract training targets (for supervised learning, that
+ means: labels) and optionally transform the value of the preprocessed
+ `GraphTensor` into a model input (or multiple model inputs for tasks like
+ self-supervised contrastive losses).
+6. If the resulting `GraphTensor`s have any auxiliary pieces (as indicated by
+ `tfgnn.get_aux_type_prefix(...)`): all features (typically: labels) are
+ removed from those graph pieces.
+
+The base GNN (as built by `model_fn`) is run on all results from step (6).
+Task.predict(...) is called on the model outputs that correspond to the one or
+more graphs requested in step (5) by Task.preprocess(...).
+
+Trainable transformations of inputs (notably lookups in trainable embedding
+tables) are required to happen inside `model_fn`.
+
+For supervised learning, training labels enter the pipeline as features on the
+`GraphTensor` that undergo the `feature_processors` (shared by all `Task`s) and
+are read out of the `GraphTensor` by
+Task.preprocess(...).
+
+Users are strongly encouraged to take one of the following two approaches to
+prevent the leakage of label information into the training:
+
+* Store labels on the auxiliary `"_readout"` node set and let
+ Task.preprocess(...)
+ read them from there. (For library-supplied `Task`s, that means initializing
+ with `label_feature_name="..."`.) If that is not already true for the input
+ datasets, the label feature can be moved there by one of the
+ `feature_processors`, using `tfgnn.structured_readout_into_feature(...)` or
+ a similar helper function.
+* For single-Task training only: Let
+ Task.preprocess()
+ return modified `GraphTensor`s that no longer contain the separately
+ returned labels. (Library-supplied Tasks delegate this to the
+ `label_fn="..."` passed in initialization.)
+
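+Putting the stages above together, a hypothetical end-to-end sketch (the
+dataset provider arguments, the model builder `build_base_gnn`, the task
+configuration and the hyperparameters are illustrative assumptions, not
+prescriptions):
+
+```python
+import tensorflow as tf
+import tensorflow_gnn as tfgnn
+from tensorflow_gnn import runner
+
+graph_schema = tfgnn.read_schema("/tmp/graph_schema.pbtxt")
+
+runner.run(
+    train_ds_provider=runner.TFRecordDatasetProvider(file_pattern="/tmp/train*"),
+    model_fn=build_base_gnn,  # Callable[[GraphTensorSpec], tf.keras.Model]
+    optimizer_fn=tf.keras.optimizers.Adam,
+    trainer=runner.KerasTrainer(
+        strategy=tf.distribute.MirroredStrategy(), model_dir="/tmp/model"),
+    task=runner.RootNodeBinaryClassification(
+        node_set_name="nodes", label_feature_name="label"),
+    gtspec=tfgnn.create_graph_spec_from_schema_pb(graph_schema),
+    global_batch_size=32,
+    epochs=5)
+```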
+
+
+
+
+Args |
+
+
+
+train_ds_provider
+ |
+
+A DatasetProvider for training. The tf.data.Dataset
+is not batched and contains scalar GraphTensor values conforming to
+gtspec , possibly serialized as a tf.train.Example proto.
+ |
+
+
+model_fn
+ |
+
+Returns the base GNN tf.keras.Model for use in training and
+validation.
+ |
+
+
+optimizer_fn
+ |
+
+Returns a tf.keras.optimizers.Optimizer for use in training.
+ |
+
+
+trainer
+ |
+
+A Trainer .
+ |
+
+
+task
+ |
+
+A Task for single-Task training or a Mapping[str, Task] for
+multi-Task training. In multi-Task training, Task.preprocess(...)
+must return GraphTensors with the same spec as its inputs, only the
+values may change (so that there remains a single spec for model_fn ).
+ |
+
+
+loss_weights
+ |
+
+An optional Mapping[str, float] for multi-Task training. If
+given, this structure must match (with tf.nest.assert_same_structure )
+the structure of task . The mapping contains, for each task , a scalar
+coefficient to weight the loss contributions of that task .
+ |
+
+
+gtspec
+ |
+
+A GraphTensorSpec matching the elements of train and valid
+datasets. If train or valid contain tf.string elements, this
+GraphTensorSpec is used for parsing; otherwise, train or valid are
+expected to contain GraphTensor elements whose relaxed spec matches
+gtspec .
+ |
+
+
+global_batch_size
+ |
+
+The tf.data.Dataset global batch size for both training
+and validation.
+ |
+
+
+epochs
+ |
+
+The number of epochs to train.
+ |
+
+
+drop_remainder
+ |
+
+Whether to drop a tf.data.Dataset remainder at batching.
+ |
+
+
+export_dirs
+ |
+
+Optional directories for exports (SavedModels); if unset,
+default behavior is os.path.join(model_dir, "export") .
+ |
+
+
+model_exporters
+ |
+
+Zero or more ModelExporter for exporting (SavedModels) to
+export_dirs . If unset, default behavior is [KerasModelExporter()] .
+ |
+
+
+feature_processors
+ |
+
+A sequence of callables for feature processing with the
+Keras functional API. Each callable must accept and return a symbolic
+scalar GraphTensor . The callables are composed in order and may change
+the GraphTensorSpec (e.g., add/remove features). The resulting Keras
+model is executed on CPU as part of a tf.data.Dataset.map operation.
+ |
+
+
+valid_ds_provider
+ |
+
+A DatasetProvider for validation. The tf.data.Dataset
+is not batched and contains scalar GraphTensor values conforming to
+gtspec , possibly serialized as a tf.train.Example proto.
+ |
+
+
+train_padding
+ |
+
+GraphTensor padding for training. Required if training on
+TPU.
+ |
+
+
+valid_padding
+ |
+
+GraphTensor padding for validation. Required if training on
+TPU.
+ |
+
+
+tf_data_service_config
+ |
+
+Configuration for the tf.data service, which speeds up the tf.data input
+pipeline runtime, reducing input bottlenecks for model training. Consider
+enabling it particularly when training on accelerators. For more info please
+see: https://www.tensorflow.org/api_docs/python/tf/data/experimental/service.
+ |
+
+
+steps_per_execution
+ |
+
+The number of batches to run during each training
+iteration. If not set, defaults to 100 for TPU strategy and to None
+otherwise.
+ |
+
+
+run_eagerly
+ |
+
+Whether to compile the model in eager mode, primarily for
+debugging purposes. Note that the symbolic model will still be run twice,
+so if you use a breakpoint() you will have to Continue twice before you
+are in a real eager execution.
+ |
+
+
+
+
+
+
+
+Returns |
+
+
+A RunResult object containing models and information about this run.
+ |
+
+
+
diff --git a/tensorflow_gnn/docs/api_docs/python/tfgnn.md b/tensorflow_gnn/docs/api_docs/python/tfgnn.md
index 43d214c5..2d03c918 100644
--- a/tensorflow_gnn/docs/api_docs/python/tfgnn.md
+++ b/tensorflow_gnn/docs/api_docs/python/tfgnn.md
@@ -1,32 +1,20 @@
# Module: tfgnn
-[TOC]
-
-
-
-
+
+ View source
+on GitHub
Public interface for TensorFlow GNN package.
-
All the public symbols, data types and functions are provided from this
top-level package. To use the library, you should use a single import statement,
like this:
- import tensorflow_gnn as tfgnn
-
-The various data types provided by the GNN library have corresponding schemas
-similar to `tf.TensorSpec`. For example, a `FieldSpec` describes an instance of
-`Field`, and a `GraphTensorSpec` describes an instance of `GraphTensor`.
+```
+import tensorflow_gnn as tfgnn
+```
## Modules
@@ -35,43 +23,63 @@ of the public interface of TensorFlow GNN.
[`keras`](./tfgnn/keras.md) module: The tfgnn.keras package.
+[`proto`](./tfgnn/proto.md) module: The protocol message (protobuf) types
+defined by TensorFlow GNN.
+
[`sampler`](./tfgnn/sampler.md) module: Public interface for GNN Sampler.
## Classes
-[`class Adjacency`](./tfgnn/Adjacency.md): Stores how edges connect pairs of nodes from source and target node sets.
+[`class Adjacency`](./tfgnn/Adjacency.md): Stores how edges connect pairs of
+nodes from source and target node sets.
-[`class AdjacencySpec`](./tfgnn/AdjacencySpec.md): A type spec for tfgnn.Adjacency .
+[`class AdjacencySpec`](./tfgnn/AdjacencySpec.md): A type spec for
+tfgnn.Adjacency .
-[`class Context`](./tfgnn/Context.md): A composite tensor for graph context features.
+[`class Context`](./tfgnn/Context.md): A composite tensor for graph context
+features.
-[`class ContextSpec`](./tfgnn/ContextSpec.md): A type spec for tfgnn.Context .
+[`class ContextSpec`](./tfgnn/ContextSpec.md): A type spec for
+tfgnn.Context .
-[`class EdgeSet`](./tfgnn/EdgeSet.md): A composite tensor for edge set features, size and adjacency information.
+[`class EdgeSet`](./tfgnn/EdgeSet.md): A composite tensor for edge set features,
+size and adjacency information.
-[`class EdgeSetSpec`](./tfgnn/EdgeSetSpec.md): A type spec for tfgnn.EdgeSet .
+[`class EdgeSetSpec`](./tfgnn/EdgeSetSpec.md): A type spec for
+tfgnn.EdgeSet .
-[`class Feature`](./tfgnn/Feature.md): A schema for a single feature.
+[`class Feature`](./tfgnn/proto/Feature.md): The schema entry for a single
+feature.
-[`class FeatureDefaultValues`](./tfgnn/FeatureDefaultValues.md): Default values for graph context, node sets and edge sets features.
+[`class FeatureDefaultValues`](./tfgnn/FeatureDefaultValues.md): Default values
+for graph context, node sets and edge sets features.
-[`class GraphSchema`](./tfgnn/GraphSchema.md): A schema definition for graphs.
+[`class GraphSchema`](./tfgnn/proto/GraphSchema.md): The top-level container for
+the schema of a graph dataset.
-[`class GraphTensor`](./tfgnn/GraphTensor.md): A composite tensor for heterogeneous directed graphs with features.
+[`class GraphTensor`](./tfgnn/GraphTensor.md): A composite tensor for
+heterogeneous directed graphs with features.
-[`class GraphTensorSpec`](./tfgnn/GraphTensorSpec.md): A type spec for tfgnn.GraphTensor .
+[`class GraphTensorSpec`](./tfgnn/GraphTensorSpec.md): A type spec for
+tfgnn.GraphTensor .
-[`class HyperAdjacency`](./tfgnn/HyperAdjacency.md): Stores how (hyper-)edges connect tuples of nodes from incident node sets.
+[`class HyperAdjacency`](./tfgnn/HyperAdjacency.md): Stores how (hyper-)edges
+connect tuples of nodes from incident node sets.
-[`class HyperAdjacencySpec`](./tfgnn/HyperAdjacencySpec.md): A type spec for tfgnn.HyperAdjacency .
+[`class HyperAdjacencySpec`](./tfgnn/HyperAdjacencySpec.md): A type spec for
+tfgnn.HyperAdjacency .
-[`class NodeSet`](./tfgnn/NodeSet.md): A composite tensor for node set features plus size information.
+[`class NodeSet`](./tfgnn/NodeSet.md): A composite tensor for node set features
+plus size information.
-[`class NodeSetSpec`](./tfgnn/NodeSetSpec.md): A type spec for tfgnn.NodeSet .
+[`class NodeSetSpec`](./tfgnn/NodeSetSpec.md): A type spec for
+tfgnn.NodeSet .
-[`class SizeConstraints`](./tfgnn/SizeConstraints.md): Constraints on the number of entities in the graph.
+[`class SizeConstraints`](./tfgnn/SizeConstraints.md): Constraints on the number
+of entities in the graph.
-[`class ValidationError`](./tfgnn/ValidationError.md): A schema validation error.
+[`class ValidationError`](./tfgnn/ValidationError.md): A schema validation
+error.
## Functions
@@ -79,23 +87,29 @@ of the public interface of TensorFlow GNN.
Adds a readout structure equivalent to
tfgnn.gather_first_node() .
-[`add_self_loops(...)`](./tfgnn/add_self_loops.md): Adds self-loops for edge
-with name `edge_set_name` EVEN if already exist.
+[`add_self_loops(...)`](./tfgnn/add_self_loops.md): Adds self-loops for
+`edge_set_name` EVEN if they already exist.
-[`assert_constraints(...)`](./tfgnn/assert_constraints.md): Validate the shape constaints of a graph's features at runtime.
+[`assert_constraints(...)`](./tfgnn/assert_constraints.md): Validate the shape
+constraints of a graph's features at runtime.
-[`assert_satisfies_size_constraints(...)`](./tfgnn/assert_satisfies_size_constraints.md): Raises InvalidArgumentError if graph_tensor exceeds size_constraints.
+[`assert_satisfies_size_constraints(...)`](./tfgnn/assert_satisfies_size_constraints.md):
+Raises InvalidArgumentError if graph_tensor exceeds size_constraints.
-[`assert_satisfies_total_sizes(...)`](./tfgnn/assert_satisfies_size_constraints.md): Raises InvalidArgumentError if graph_tensor exceeds size_constraints.
+[`assert_satisfies_total_sizes(...)`](./tfgnn/assert_satisfies_size_constraints.md):
+Raises InvalidArgumentError if graph_tensor exceeds size_constraints.
[`broadcast(...)`](./tfgnn/broadcast.md): Broadcasts values from nodes to edges,
or from context to nodes or edges.
-[`broadcast_context_to_edges(...)`](./tfgnn/broadcast_context_to_edges.md): Broadcasts a context value to the `edge_set` edges.
+[`broadcast_context_to_edges(...)`](./tfgnn/broadcast_context_to_edges.md):
+Broadcasts a context value to the `edge_set` edges.
-[`broadcast_context_to_nodes(...)`](./tfgnn/broadcast_context_to_nodes.md): Broadcasts a context value to the `node_set` nodes.
+[`broadcast_context_to_nodes(...)`](./tfgnn/broadcast_context_to_nodes.md):
+Broadcasts a context value to the `node_set` nodes.
-[`broadcast_node_to_edges(...)`](./tfgnn/broadcast_node_to_edges.md): Broadcasts values from nodes to incident edges.
+[`broadcast_node_to_edges(...)`](./tfgnn/broadcast_node_to_edges.md): Broadcasts
+values from nodes to incident edges.
[`check_compatible_with_schema_pb(...)`](./tfgnn/check_compatible_with_schema_pb.md):
Checks that the given spec or value is compatible with the graph schema.
@@ -103,28 +117,47 @@ Checks that the given spec or value is compatible with the graph schema.
[`check_homogeneous_graph_tensor(...)`](./tfgnn/check_homogeneous_graph_tensor.md):
Raises ValueError when tfgnn.get_homogeneous_node_and_edge_set_name() does.
-[`check_required_features(...)`](./tfgnn/check_required_features.md): Checks the requirements of a given schema against another.
+[`check_required_features(...)`](./tfgnn/check_required_features.md): Checks the
+requirements of a given schema against another.
-[`check_scalar_graph_tensor(...)`](./tfgnn/check_scalar_graph_tensor.md)
+[`check_scalar_graph_tensor(...)`](./tfgnn/check_scalar_graph_tensor.md): Checks
+that graph tensor is scalar (has rank 0).
-[`combine_values(...)`](./tfgnn/combine_values.md): Combines a list of tensors into one (by concatenation or otherwise).
+[`combine_values(...)`](./tfgnn/combine_values.md): Combines a list of tensors
+into one (by concatenation or otherwise).
[`convert_to_line_graph(...)`](./tfgnn/convert_to_line_graph.md): Obtain a
graph's line graph.
-[`create_graph_spec_from_schema_pb(...)`](./tfgnn/create_graph_spec_from_schema_pb.md): Converts a graph schema proto message to a scalar GraphTensorSpec.
+[`create_graph_spec_from_schema_pb(...)`](./tfgnn/create_graph_spec_from_schema_pb.md):
+Converts a graph schema proto message to a scalar GraphTensorSpec.
[`create_schema_pb_from_graph_spec(...)`](./tfgnn/create_schema_pb_from_graph_spec.md):
Converts scalar GraphTensorSpec to a graph schema proto message.
-[`dataset_filter_with_summary(...)`](./tfgnn/dataset_filter_with_summary.md): Dataset filter with a summary for the fraction of dataset elements removed.
+[`dataset_filter_with_summary(...)`](./tfgnn/dataset_filter_with_summary.md):
+Dataset filter with a summary for the fraction of dataset elements removed.
[`dataset_from_generator(...)`](./tfgnn/dataset_from_generator.md): Creates
dataset from generator of any nest of scalar graph pieces.
-[`find_tight_size_constraints(...)`](./tfgnn/find_tight_size_constraints.md): Returns smallest possible size constraints that allow dataset padding.
+[`disable_graph_tensor_validation(...)`](./tfgnn/disable_graph_tensor_validation.md):
+Disables both static and runtime checks of graph tensors.
-[`gather_first_node(...)`](./tfgnn/gather_first_node.md): Gathers feature value from the first node of each graph component.
+[`disable_graph_tensor_validation_at_runtime(...)`](./tfgnn/disable_graph_tensor_validation_at_runtime.md):
+Disables runtime checks (`tf.debugging.Assert`) of graph tensors.
+
+[`enable_graph_tensor_validation(...)`](./tfgnn/enable_graph_tensor_validation.md):
+Enables static checks of graph tensors.
+
+[`enable_graph_tensor_validation_at_runtime(...)`](./tfgnn/enable_graph_tensor_validation_at_runtime.md):
+Enables both static and runtime checks of graph tensors.
+
+[`find_tight_size_constraints(...)`](./tfgnn/find_tight_size_constraints.md):
+Returns smallest possible size constraints that allow dataset padding.
+
+[`gather_first_node(...)`](./tfgnn/gather_first_node.md): Gathers feature value
+from the first node of each graph component.
[`get_aux_type_prefix(...)`](./tfgnn/get_aux_type_prefix.md): Returns type
prefix of aux node or edge set names, or `None` if non-aux.
@@ -132,11 +165,14 @@ prefix of aux node or edge set names, or `None` if non-aux.
[`get_homogeneous_node_and_edge_set_name(...)`](./tfgnn/get_homogeneous_node_and_edge_set_name.md):
Returns the sole `node_set_name, edge_set_name` or raises `ValueError`.
-[`get_io_spec(...)`](./tfgnn/get_io_spec.md): Returns tf.io parsing features for `GraphTensorSpec` type spec.
+[`get_io_spec(...)`](./tfgnn/get_io_spec.md): Returns tf.io parsing features for
+`GraphTensorSpec` type spec.
-[`get_registered_reduce_operation_names(...)`](./tfgnn/get_registered_reduce_operation_names.md): Returns the registered list of supported reduce operation names.
+[`get_registered_reduce_operation_names(...)`](./tfgnn/get_registered_reduce_operation_names.md):
+Returns the registered list of supported reduce operation names.
-[`graph_tensor_to_values(...)`](./tfgnn/graph_tensor_to_values.md): Convert an eager `GraphTensor` to a mapping of mappings of PODTs.
+[`graph_tensor_to_values(...)`](./tfgnn/graph_tensor_to_values.md): Convert an
+eager `GraphTensor` to a mapping of mappings of PODTs.
[`homogeneous(...)`](./tfgnn/homogeneous.md): Constructs a homogeneous
`GraphTensor` with node features and one edge_set.
@@ -144,16 +180,20 @@ Returns the sole `node_set_name, edge_set_name` or raises `ValueError`.
[`is_dense_tensor(...)`](./tfgnn/is_dense_tensor.md): Returns whether a tensor
(TF or Keras) is a Tensor.
-[`is_graph_tensor(...)`](./tfgnn/is_graph_tensor.md): Returns whether `value` is a GraphTensor (possibly wrapped for Keras).
+[`is_graph_tensor(...)`](./tfgnn/is_graph_tensor.md): Returns whether `value` is
+a GraphTensor (possibly wrapped for Keras).
[`is_ragged_tensor(...)`](./tfgnn/is_ragged_tensor.md): Returns whether a tensor
(TF or Keras) is a RaggedTensor.
-[`iter_features(...)`](./tfgnn/iter_features.md): Utility function to iterate over the features of a graph schema.
+[`iter_features(...)`](./tfgnn/iter_features.md): Utility function to iterate
+over the features of a graph schema.
-[`iter_sets(...)`](./tfgnn/iter_sets.md): Utility function to iterate over all the sets present in a graph schema.
+[`iter_sets(...)`](./tfgnn/iter_sets.md): Utility function to iterate over all
+the sets present in a graph schema.
-[`learn_fit_or_skip_size_constraints(...)`](./tfgnn/learn_fit_or_skip_size_constraints.md): Learns the optimal size constraints for the fixed size batching with retry.
+[`learn_fit_or_skip_size_constraints(...)`](./tfgnn/learn_fit_or_skip_size_constraints.md):
+Learns the optimal size constraints for the fixed size batching with retry.
[`mask_edges(...)`](./tfgnn/mask_edges.md): Creates a GraphTensor after applying
edge_mask over the specified edge-set.
@@ -161,35 +201,53 @@ edge_mask over the specified edge-set.
[`node_degree(...)`](./tfgnn/node_degree.md): Returns the degree of each node
w.r.t. one side of an edge set.
-[`pad_to_total_sizes(...)`](./tfgnn/pad_to_total_sizes.md): Pads graph tensor to the total sizes by inserting fake graph components.
+[`pad_to_total_sizes(...)`](./tfgnn/pad_to_total_sizes.md): Pads graph tensor to
+the total sizes by inserting fake graph components.
-[`parse_example(...)`](./tfgnn/parse_example.md): Parses a batch of serialized Example protos into a single `GraphTensor`.
+[`parse_example(...)`](./tfgnn/parse_example.md): Parses a batch of serialized
+Example protos into a single `GraphTensor`.
-[`parse_schema(...)`](./tfgnn/parse_schema.md): Parse a schema from text-formatted protos.
+[`parse_schema(...)`](./tfgnn/parse_schema.md): Parse a schema from
+text-formatted protos.
-[`parse_single_example(...)`](./tfgnn/parse_single_example.md): Parses a single serialized Example proto into a single `GraphTensor`.
+[`parse_single_example(...)`](./tfgnn/parse_single_example.md): Parses a single
+serialized Example proto into a single `GraphTensor`.
[`pool(...)`](./tfgnn/pool.md): Pools values from edges to nodes, or from nodes
or edges to context.
-[`pool_edges_to_context(...)`](./tfgnn/pool_edges_to_context.md): Aggregates (pools) edge values to graph context.
+[`pool_edges_to_context(...)`](./tfgnn/pool_edges_to_context.md): Aggregates
+(pools) edge values to graph context.
+
+[`pool_edges_to_node(...)`](./tfgnn/pool_edges_to_node.md): Aggregates (pools)
+edge values to incident nodes.
-[`pool_edges_to_node(...)`](./tfgnn/pool_edges_to_node.md): Aggregates (pools) edge values to incident nodes.
+[`pool_neighbors_to_node(...)`](./tfgnn/pool_neighbors_to_node.md): Aggregates
+(pools) neighbor node values along one or more edge sets.
-[`pool_nodes_to_context(...)`](./tfgnn/pool_nodes_to_context.md): Aggregates (pools) node values to graph context.
+[`pool_neighbors_to_node_feature(...)`](./tfgnn/pool_neighbors_to_node_feature.md):
+Aggregates (pools) a sender node feature into a receiver node feature.
-[`random_graph_tensor(...)`](./tfgnn/random_graph_tensor.md): Generate a graph tensor from a schema, with random features.
+[`pool_nodes_to_context(...)`](./tfgnn/pool_nodes_to_context.md): Aggregates
+(pools) node values to graph context.
-[`read_schema(...)`](./tfgnn/read_schema.md): Read a proto schema from a file with text-formatted contents.
+[`random_graph_tensor(...)`](./tfgnn/random_graph_tensor.md): Generate a graph
+tensor from a spec, with random features.
+
+[`read_schema(...)`](./tfgnn/read_schema.md): Read a proto schema from a file
+with text-formatted contents.
[`reorder_nodes(...)`](./tfgnn/reorder_nodes.md): Reorders nodes within node
sets according to indices.
-[`reverse_tag(...)`](./tfgnn/reverse_tag.md): Flips tfgnn.SOURCE to tfgnn.TARGET and vice versa.
+[`reverse_tag(...)`](./tfgnn/reverse_tag.md): Flips tfgnn.SOURCE to tfgnn.TARGET
+and vice versa.
-[`satisfies_size_constraints(...)`](./tfgnn/satisfies_size_constraints.md): Returns whether the input `graph_tensor` satisfies `total_sizes`.
+[`satisfies_size_constraints(...)`](./tfgnn/satisfies_size_constraints.md):
+Returns whether the input `graph_tensor` satisfies `total_sizes`.
-[`satisfies_total_sizes(...)`](./tfgnn/satisfies_size_constraints.md): Returns whether the input `graph_tensor` satisfies `total_sizes`.
+[`satisfies_total_sizes(...)`](./tfgnn/satisfies_size_constraints.md): Returns
+whether the input `graph_tensor` satisfies `total_sizes`.
[`shuffle_features_globally(...)`](./tfgnn/shuffle_features_globally.md):
Shuffles context, node set and edge set features of a scalar GraphTensor.
@@ -197,9 +255,11 @@ Shuffles context, node set and edge set features of a scalar GraphTensor.
[`shuffle_nodes(...)`](./tfgnn/shuffle_nodes.md): Randomly reorders nodes of
given node sets, within each graph component.
-[`softmax(...)`](./tfgnn/softmax.md): Computes softmax over a many-to-one relationship in a GraphTensor.
+[`softmax(...)`](./tfgnn/softmax.md): Computes softmax over a many-to-one
+relationship in a GraphTensor.
-[`softmax_edges_per_node(...)`](./tfgnn/softmax_edges_per_node.md): Returns softmax() of edge values per common `node_tag` node.
+[`softmax_edges_per_node(...)`](./tfgnn/softmax_edges_per_node.md): Returns
+softmax() of edge values per common `node_tag` node.
[`structured_readout(...)`](./tfgnn/structured_readout.md): Reads out a feature
value from select nodes (or edges) in a graph.
@@ -213,11 +273,14 @@ Checks `graph` supports `structured_readout()` from `required_keys`.
[`validate_graph_tensor_spec_for_readout(...)`](./tfgnn/validate_graph_tensor_spec_for_readout.md):
Checks `graph_spec` supports `structured_readout()` from `required_keys`.
-[`validate_schema(...)`](./tfgnn/validate_schema.md): Validates the correctness of a graph schema instance.
+[`validate_schema(...)`](./tfgnn/validate_schema.md): Validates the correctness
+of a graph schema instance.
-[`write_example(...)`](./tfgnn/write_example.md): Encode an eager `GraphTensor` to a tf.train.Example proto.
+[`write_example(...)`](./tfgnn/write_example.md): Encode an eager `GraphTensor`
+to a tf.train.Example proto.
-[`write_schema(...)`](./tfgnn/write_schema.md): Write a `GraphSchema` to a text-formatted proto file.
+[`write_schema(...)`](./tfgnn/write_schema.md): Write a `GraphSchema` to a
+text-formatted proto file.
## Type Aliases
@@ -233,9 +296,8 @@ Checks `graph_spec` supports `structured_readout()` from `required_keys`.
[`IncidentNodeOrContextTag`](./tfgnn/IncidentNodeOrContextTag.md)
-
-
+
Other Members |
@@ -245,70 +307,70 @@ Checks `graph_spec` supports `structured_readout()` from `required_keys`.
CONTEXT
-`'context'`
+'context'
|
EDGES
|
-`'edges'`
+'edges'
|
HIDDEN_STATE
|
-`'hidden_state'`
+'hidden_state'
|
NODES
|
-`'nodes'`
+'nodes'
|
SIZE_NAME
|
-`'#size'`
+'#size'
|
SOURCE
|
-`0`
+0
|
SOURCE_NAME
|
-`'#source'`
+'#source'
|
TARGET
|
-`1`
+1
|
TARGET_NAME
|
-`'#target'`
+'#target'
|
**version**
|
-`'0.6.0.dev1'`
+'1.0.0.dev2'
|
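
To give a quick sense of how the module-level operations listed above fit together, here is a minimal sketch (editor-added, not part of the generated pages) that pools an edge feature onto receiver nodes and computes node degrees. It assumes a scalar `GraphTensor` with an edge set named `"edges"` carrying a `tfgnn.HIDDEN_STATE` feature; both names are placeholders, so substitute the names from your own schema.

```python
import tensorflow_gnn as tfgnn


def summarize_edges(graph: tfgnn.GraphTensor):
  """Pools edge states onto target nodes and counts incoming edges.

  Assumes a scalar GraphTensor with an edge set named "edges" that carries a
  tfgnn.HIDDEN_STATE feature; both names are placeholders for illustration.
  """
  # Sum each edge's hidden state into the node it points at (tfgnn.TARGET).
  pooled = tfgnn.pool_edges_to_node(
      graph,
      edge_set_name="edges",
      node_tag=tfgnn.TARGET,
      reduce_type="sum",
      feature_name=tfgnn.HIDDEN_STATE)
  # Number of incoming edges per node on the tfgnn.TARGET side of "edges".
  in_degree = tfgnn.node_degree(graph, "edges", tfgnn.TARGET)
  return pooled, in_degree
```
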
diff --git a/tensorflow_gnn/docs/api_docs/python/tfgnn/Adjacency.md b/tensorflow_gnn/docs/api_docs/python/tfgnn/Adjacency.md
index f6caa09c..69c0c2c1 100644
--- a/tensorflow_gnn/docs/api_docs/python/tfgnn/Adjacency.md
+++ b/tensorflow_gnn/docs/api_docs/python/tfgnn/Adjacency.md
@@ -1,17 +1,10 @@
# tfgnn.Adjacency
-[TOC]
-
-
+
+ View source
+on GitHub
Stores how edges connect pairs of nodes from source and target node sets.
@@ -19,103 +12,99 @@ Inherits From: [`HyperAdjacency`](../tfgnn/HyperAdjacency.md)
tfgnn.Adjacency(
- data: Data, spec: 'GraphPieceSpecBase', validate: bool = False
+ data: Data, spec: 'GraphPieceSpecBase'
)
-
-
-Each hyper-edge connect one node from the source node set with one node from
-the target node sets. The source and target node sets could be the same.
-The adjacency information is a pair of integer tensors containing indices of
-nodes in source and target node sets. Those tensors are indexed by
-edges, have the same type spec and shape of `[*graph_shape, num_edges]`,
-where `num_edges` is the number of edges in the edge set (could be potentially
-ragged). The index tensors are of `tf.Tensor` type if `num_edges` is not
-`None` or `graph_shape.rank = 0` and of`tf.RaggedTensor` type otherwise.
+Each hyper-edge connects one node from the source node set with one node from
+the target node set. The source and target node sets may be the same. The
+adjacency information is a pair of integer tensors containing indices of nodes
+in the source and target node sets. Those tensors are indexed by edge, share
+the same type spec, and have shape `[*graph_shape, num_edges]`, where
+`num_edges` is the number of edges in the edge set (and may be ragged). The
+index tensors are of `tf.Tensor` type if `num_edges` is not `None` or
+`graph_shape.rank = 0`, and of `tf.RaggedTensor` type otherwise.
The Adjacency is a composite tensor and a special case of tfgnn.HyperAdjacency
-class with tfgnn.SOURCE and tfgnn.TARGET node tags used for the source and
-target nodes correspondingly.
+class with the tfgnn.SOURCE and
+tfgnn.TARGET node tags used for
+the source and target nodes, respectively.
+
Args |
-`data`
+data
|
Nest of Field or subclasses of GraphPieceBase.
|
-`spec`
+spec
|
-A subclass of GraphPieceSpecBase with a `_data_spec` that matches
-`data`.
- |
-
-
-`validate`
- |
-
-if set, checks that data and spec are aligned, compatible and
-supported.
+A subclass of GraphPieceSpecBase with a _data_spec that matches
+data .
|
+
| | | | |
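
To complement the low-level `(data, spec)` constructor documented above, here is a small editor-added sketch of the more common way to build an `Adjacency` in user code, via the `from_indices` factory. The node-set names and index values are hypothetical, chosen only to illustrate the shape of the call.

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

# Hypothetical node-set names and index values, purely for illustration.
# Three edges: paper 0 -> author 1, paper 1 -> author 0, paper 2 -> author 1.
adjacency = tfgnn.Adjacency.from_indices(
    source=("papers", tf.constant([0, 1, 2])),
    target=("authors", tf.constant([1, 0, 1])))

print(adjacency.source_name)  # "papers"
print(adjacency.target)       # the target node indices, one per edge
```
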