Skip to content

Commit

Permalink
From models/contrastive_losses, remove untested Keras deserialization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587699869
  • Loading branch information
arnoegw authored and tensorflower-gardener committed Dec 4, 2023
1 parent e165738 commit 6b8f354
Showing 1 changed file with 2 additions and 25 deletions.
27 changes: 2 additions & 25 deletions tensorflow_gnn/models/contrastive_losses/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import tensorflow as tf
import tensorflow_gnn as tfgnn

_PACKAGE = "GNN>models>contrastive_losses"
_WILDCARD = "*"
T = TypeVar("T")

Expand Down Expand Up @@ -136,23 +135,11 @@ def fn(inputs, *, node_set_name=None, edge_set_name=None):
super().__init__(fn, fn, fn, **kwargs)

def get_config(self):
return dict(
default=self._default,
corruption_fn=self._corruption_fn,
node_corruption_spec=self._node_corruption_spec,
edge_corruption_spec=self._edge_corruption_spec,
context_corruption_spec=self._context_corruption_spec,
**super().get_config(),
)
raise NotImplementedError()

@classmethod
def from_config(cls, config):
config["corruption_spec"] = CorruptionSpec(
config.pop("node_corruption_spec"),
config.pop("edge_corruption_spec"),
config.pop("context_corruption_spec"),
)
return cls(**config)
raise NotImplementedError()


def _seed_wrapper(
Expand All @@ -165,7 +152,6 @@ def wrapper_fn(tensor, rate):
return wrapper_fn


@tf.keras.utils.register_keras_serializable(package=_PACKAGE)
class ShuffleFeaturesGlobally(Corruptor[float]):
"""A corruptor that shuffles features.
Expand All @@ -178,21 +164,14 @@ def __init__(self, *args, seed: Optional[float] = None, **kwargs):
seeded_fn = _seed_wrapper(_shuffle_tensor, seed=seed)
super().__init__(*args, corruption_fn=seeded_fn, default=1.0, **kwargs)

def get_config(self):
return dict(seed=self._seed, **super().get_config())


@tf.keras.utils.register_keras_serializable(package=_PACKAGE)
class DropoutFeatures(Corruptor[float]):

def __init__(self, *args, seed: Optional[float] = None, **kwargs):
self._seed = seed
seeded_fn = _seed_wrapper(tf.nn.dropout, seed=seed)
super().__init__(*args, corruption_fn=seeded_fn, default=0.0, **kwargs)

def get_config(self):
return dict(seed=self._seed, **super().get_config())


def _ragged_dim_list(tensor: tf.RaggedTensor) -> List[Union[int, tf.Tensor]]:
"""Lists ragged tensor dimensions with a preference for static sizes."""
Expand Down Expand Up @@ -295,7 +274,6 @@ def _corrupt_features(
return output


@tf.keras.utils.register_keras_serializable(package=_PACKAGE)
class DeepGraphInfomaxLogits(tf.keras.layers.Layer):
"""Computes clean and corrupted logits for Deep Graph Infomax (DGI)."""

Expand Down Expand Up @@ -331,7 +309,6 @@ def call(self, inputs: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
return tf.keras.layers.Concatenate()((logits_clean, logits_corrupted))


@tf.keras.utils.register_keras_serializable(package=_PACKAGE)
class TripletEmbeddingSquaredDistances(tf.keras.layers.Layer):
"""Computes embeddings distance between positive and negative pairs."""

Expand Down

0 comments on commit 6b8f354

Please sign in to comment.