-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
240 changed files
with
15,611 additions
and
4,234 deletions.
There are no files selected for viewing
69 changes: 69 additions & 0 deletions
69
tensorflow_gnn/docs/api_docs/python/models/contrastive_losses.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Module: contrastive_losses | ||
|
||
<!-- Insert buttons and diff --> | ||
|
||
<a target="_blank" href="https://github.com/tensorflow/gnn/tree/master/tensorflow_gnn/models/contrastive_losses/__init__.py"> | ||
<img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" /> View source | ||
on GitHub </a> | ||
|
||
Contrastive losses. | ||
|
||
Users of TF-GNN can use these layers by importing them next to the core library: | ||
|
||
```python | ||
import tensorflow_gnn as tfgnn | ||
from tensorflow_gnn.models import contrastive_losses | ||
``` | ||
|
||
## Classes | ||
|
||
[`class AllSvdMetrics`](./contrastive_losses/AllSvdMetrics.md): Computes | ||
multiple metrics for representations using one SVD call. | ||
|
||
[`class BarlowTwinsTask`](./contrastive_losses/BarlowTwinsTask.md): A Barlow | ||
Twins (BT) Task. | ||
|
||
[`class ContrastiveLossTask`](./contrastive_losses/ContrastiveLossTask.md): Base | ||
class for unsupervised contrastive representation learning tasks. | ||
|
||
[`class CorruptionSpec`](./contrastive_losses/CorruptionSpec.md): Class for | ||
defining corruption specification. | ||
|
||
[`class Corruptor`](./contrastive_losses/Corruptor.md): Base class for graph | ||
corruptor. | ||
|
||
[`class DeepGraphInfomaxLogits`](./contrastive_losses/DeepGraphInfomaxLogits.md): | ||
Computes clean and corrupted logits for Deep Graph Infomax (DGI). | ||
|
||
[`class DeepGraphInfomaxTask`](./contrastive_losses/DeepGraphInfomaxTask.md): A | ||
Deep Graph Infomax (DGI) Task. | ||
|
||
[`class DropoutFeatures`](./contrastive_losses/DropoutFeatures.md): Base class | ||
for graph corruptor. | ||
|
||
[`class ShuffleFeaturesGlobally`](./contrastive_losses/ShuffleFeaturesGlobally.md): | ||
A corruptor that shuffles features. | ||
|
||
[`class TripletEmbeddingSquaredDistances`](./contrastive_losses/TripletEmbeddingSquaredDistances.md): | ||
Computes embeddings distance between positive and negative pairs. | ||
|
||
[`class TripletLossTask`](./contrastive_losses/TripletLossTask.md): The triplet | ||
loss task. | ||
|
||
[`class VicRegTask`](./contrastive_losses/VicRegTask.md): A VICReg Task. | ||
|
||
## Functions | ||
|
||
[`coherence(...)`](./contrastive_losses/coherence.md): Coherence metric | ||
implementation. | ||
|
||
[`numerical_rank(...)`](./contrastive_losses/numerical_rank.md): Numerical rank | ||
implementation. | ||
|
||
[`pseudo_condition_number(...)`](./contrastive_losses/pseudo_condition_number.md): | ||
Pseudo-condition number metric implementation. | ||
|
||
[`rankme(...)`](./contrastive_losses/rankme.md): RankMe metric implementation. | ||
|
||
[`self_clustering(...)`](./contrastive_losses/self_clustering.md): | ||
Self-clustering metric implementation. |
202 changes: 202 additions & 0 deletions
202
tensorflow_gnn/docs/api_docs/python/models/contrastive_losses/AllSvdMetrics.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
# contrastive_losses.AllSvdMetrics | ||
|
||
<!-- Insert buttons and diff --> | ||
|
||
<a target="_blank" href="https://github.com/tensorflow/gnn/tree/master/tensorflow_gnn/models/contrastive_losses/metrics.py#L337-L348"> | ||
<img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" /> View source | ||
on GitHub </a> | ||
|
||
Computes multiple metrics for representations using one SVD call. | ||
|
||
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link"> | ||
<code>contrastive_losses.AllSvdMetrics( | ||
*args, **kwargs | ||
) | ||
</code></pre> | ||
|
||
<!-- Placeholder for "Used in" --> | ||
|
||
Refer to https://arxiv.org/abs/2305.16562 for more details. | ||
|
||
<!-- Tabular view --> | ||
|
||
<table class="responsive fixed orange"> | ||
<colgroup><col width="214px"><col></colgroup> | ||
<tr><th colspan="2"><h2 class="add-link">Args</h2></th></tr> | ||
|
||
<tr> | ||
<td> | ||
<code>fns</code><a id="fns"></a> | ||
</td> | ||
<td> | ||
a mapping from a metric name to a <code>Callable</code> that accepts | ||
representations as well as the result of their SVD decomposition. | ||
Currently only singular values are passed. | ||
</td> | ||
</tr><tr> | ||
<td> | ||
<code>y_pred_transform_fn</code><a id="y_pred_transform_fn"></a> | ||
</td> | ||
<td> | ||
a function to extract clean representations | ||
from model predictions. By default, no transformation is applied. | ||
</td> | ||
</tr><tr> | ||
<td> | ||
<code>name</code><a id="name"></a> | ||
</td> | ||
<td> | ||
Name for the metric class, used for Keras bookkeeping. | ||
</td> | ||
</tr> | ||
</table> | ||
|
||
## Methods | ||
|
||
<h3 id="merge_state"><code>merge_state</code></h3> | ||
|
||
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link"> | ||
<code>merge_state( | ||
metrics | ||
) | ||
</code></pre> | ||
|
||
Merges the state from one or more metrics. | ||
|
||
This method can be used by distributed systems to merge the state computed by | ||
different metric instances. Typically the state will be stored in the form of | ||
the metric's weights. For example, a tf.keras.metrics.Mean metric contains a | ||
list of two weight values: a total and a count. If there were two instances of a | ||
tf.keras.metrics.Accuracy that each independently aggregated partial state for | ||
an overall accuracy calculation, these two metric's states could be combined as | ||
follows: | ||
|
||
``` | ||
>>> m1 = tf.keras.metrics.Accuracy() | ||
>>> _ = m1.update_state([[1], [2]], [[0], [2]]) | ||
``` | ||
|
||
``` | ||
>>> m2 = tf.keras.metrics.Accuracy() | ||
>>> _ = m2.update_state([[3], [4]], [[3], [4]]) | ||
``` | ||
|
||
``` | ||
>>> m2.merge_state([m1]) | ||
>>> m2.result().numpy() | ||
0.75 | ||
``` | ||
|
||
<!-- Tabular view --> | ||
|
||
<table class="responsive fixed orange"> | ||
<colgroup><col width="214px"><col></colgroup> | ||
<tr><th colspan="2">Args</th></tr> | ||
|
||
<tr> | ||
<td> | ||
<code>metrics</code> | ||
</td> | ||
<td> | ||
an iterable of metrics. The metrics must have compatible | ||
state. | ||
</td> | ||
</tr> | ||
</table> | ||
|
||
<!-- Tabular view --> | ||
|
||
<table class="responsive fixed orange"> | ||
<colgroup><col width="214px"><col></colgroup> | ||
<tr><th colspan="2">Raises</th></tr> | ||
|
||
<tr> | ||
<td> | ||
<code>ValueError</code> | ||
</td> | ||
<td> | ||
If the provided iterable does not contain metrics matching | ||
the metric's required specifications. | ||
</td> | ||
</tr> | ||
</table> | ||
|
||
<h3 id="reset_state"><code>reset_state</code></h3> | ||
|
||
<a target="_blank" class="external" href="https://github.com/tensorflow/gnn/tree/master/tensorflow_gnn/models/contrastive_losses/metrics.py#L321-L323">View | ||
source</a> | ||
|
||
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link"> | ||
<code>reset_state() -> None | ||
</code></pre> | ||
|
||
Resets all of the metric state variables. | ||
|
||
This function is called between epochs/steps, when a metric is evaluated during | ||
training. | ||
|
||
<h3 id="result"><code>result</code></h3> | ||
|
||
<a target="_blank" class="external" href="https://github.com/tensorflow/gnn/tree/master/tensorflow_gnn/models/contrastive_losses/metrics.py#L333-L334">View | ||
source</a> | ||
|
||
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link"> | ||
<code>result() -> Mapping[str, tf.Tensor] | ||
</code></pre> | ||
|
||
Computes and returns the scalar metric value tensor or a dict of scalars. | ||
|
||
Result computation is an idempotent operation that simply calculates the metric | ||
value using the state variables. | ||
|
||
<!-- Tabular view --> | ||
|
||
<table class="responsive fixed orange"> | ||
<colgroup><col width="214px"><col></colgroup> | ||
<tr><th colspan="2">Returns</th></tr> | ||
<tr class="alt"> | ||
<td colspan="2"> | ||
A scalar tensor, or a dictionary of scalar tensors. | ||
</td> | ||
</tr> | ||
|
||
</table> | ||
|
||
<h3 id="update_state"><code>update_state</code></h3> | ||
|
||
<a target="_blank" class="external" href="https://github.com/tensorflow/gnn/tree/master/tensorflow_gnn/models/contrastive_losses/metrics.py#L325-L331">View | ||
source</a> | ||
|
||
<pre class="devsite-click-to-copy prettyprint lang-py tfo-signature-link"> | ||
<code>update_state( | ||
_, y_pred: tf.Tensor, sample_weight=None | ||
) -> None | ||
</code></pre> | ||
|
||
Accumulates statistics for the metric. | ||
|
||
Note: This function is executed as a graph function in graph mode. This means: | ||
a) Operations on the same resource are executed in textual order. This should | ||
make it easier to do things like add the updated value of a variable to another, | ||
for example. b) You don't need to worry about collecting the update ops to | ||
execute. All update ops added to the graph by this function will be executed. As | ||
a result, code should generally work the same way with graph or eager execution. | ||
|
||
<!-- Tabular view --> | ||
|
||
<table class="responsive fixed orange"> | ||
<colgroup><col width="214px"><col></colgroup> | ||
<tr><th colspan="2">Args</th></tr> | ||
|
||
<tr> <td> <code>*args</code> </td> <td> | ||
|
||
</td> | ||
</tr><tr> | ||
<td> | ||
<code>**kwargs</code> | ||
</td> | ||
<td> | ||
A mini-batch of inputs to the Metric. | ||
</td> | ||
</tr> | ||
</table> |
Oops, something went wrong.