Skip to content

Commit

Permalink
Fixed docstrings for wrapped functions: @disallow_keras_tensors and @…
Browse files Browse the repository at this point in the history
…delegate_keras_tensors

PiperOrigin-RevId: 588726635
  • Loading branch information
aferludin authored and tensorflower-gardener committed Dec 7, 2023
1 parent 9e1ac23 commit e8f18b1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
26 changes: 19 additions & 7 deletions tensorflow_gnn/keras/keras_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
`ValueError` or `TypeError` is raised when the wrapped callables are called
with an unexpected argument type.
"""
import functools
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import tensorflow as tf
from tensorflow_gnn.graph import adjacency as adj
from tensorflow_gnn.graph import graph_constants as const
Expand Down Expand Up @@ -437,13 +437,19 @@ def delegate_keras_tensors(target=None, name: Optional[str] = None):
target.
"""

def delegator(target=None):
return _TFGNNOpDispatcher(target, name)
def decorator(target=None):
impl = _TFGNNOpDispatcher(target, name)

@functools.wraps(target)
def fn(*argw, **kwargs):
return impl(*argw, **kwargs)

return fn

if target is None:
return delegator
return decorator
else:
return _TFGNNOpDispatcher(target, name)
return decorator(target)


def disallow_keras_tensors(
Expand All @@ -470,12 +476,18 @@ def disallow_keras_tensors(
"""

def decorator(target=None):
return _NotSupportedDispatcher(target, name=name, alternative=alternative)
impl = _NotSupportedDispatcher(target, name=name, alternative=alternative)

@functools.wraps(target)
def fn(*argw, **kwargs):
return impl(*argw, **kwargs)

return fn

if target is None:
return decorator
else:
return _NotSupportedDispatcher(target, name=name, alternative=alternative)
return decorator(target)


def _pack_args(
Expand Down
26 changes: 26 additions & 0 deletions tensorflow_gnn/keras/keras_tensors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for KerasTensor specializations for GraphTensor pieces."""

import functools
import inspect
import os
from typing import Mapping

Expand Down Expand Up @@ -806,6 +807,31 @@ def testOps(self, layer_name: str, transformation):
)


class DocStringsDelegationTest(_TestBase):

def testDelegateKerasTensors(self):
@kt.delegate_keras_tensors
def test_fn(t: tf.Tensor, a: int, *, b: str):
"""Delegates Keras tensors."""
del t, a, b

signature = inspect.signature(test_fn, follow_wrapped=True)
self.assertSequenceEqual(list(signature.parameters.keys()), ['t', 'a', 'b'])
self.assertEqual(inspect.getdoc(test_fn), 'Delegates Keras tensors.')

def testDisallowKerasTensors(self):
@kt.disallow_keras_tensors
def test_fn(t1: tf.Tensor, t2: tf.Tensor, *, x: str):
"""Deprecates Keras tensors."""
del t1, t2, x

signature = inspect.signature(test_fn, follow_wrapped=True)
self.assertSequenceEqual(
list(signature.parameters.keys()), ['t1', 't2', 'x']
)
self.assertEqual(inspect.getdoc(test_fn), 'Deprecates Keras tensors.')


class WrappedOpsSavingTest(_SaveAndLoadTestBase):

@tf.keras.utils.register_keras_serializable()
Expand Down

0 comments on commit e8f18b1

Please sign in to comment.