Use TensorFlow public APIs in attention_wrapper_test (tensorflow#472)
guillaumekln authored and seanpmorgan committed Sep 3, 2019
1 parent 68e8bb9 commit 28ec920
Showing 1 changed file with 17 additions and 19 deletions.
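
The change replaces imports from the private tensorflow.python namespace with their public tf.keras equivalents. A minimal sketch of the substitution pattern, assuming a TF 2.x install (the unit size below is illustrative and not taken from the test):

import tensorflow as tf

# Previously (private API, liable to break between TensorFlow releases):
#   from tensorflow.python import keras
#   from tensorflow.python.keras import initializers
#   cell = keras.layers.LSTMCell(32)
#   init = initializers.glorot_uniform(seed=1337)

# Now, through public entry points only:
cell = tf.keras.layers.LSTMCell(32)
# The seeded glorot_uniform goes through the compat shim, matching the diff below.
init = tf.compat.v1.keras.initializers.glorot_uniform(seed=1337)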
36 changes: 17 additions & 19 deletions tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -28,10 +28,6 @@
 from tensorflow_addons.seq2seq import basic_decoder
 from tensorflow_addons.seq2seq import sampler as sampler_py
 
-# TODO: Find public API alternatives to these
-from tensorflow.python import keras
-from tensorflow.python.keras import initializers
-
 
 @test_utils.run_all_in_graph_and_eager_modes
 class AttentionMechanismTest(tf.test.TestCase, parameterized.TestCase):
@@ -131,22 +127,22 @@ def test_passing_memory_from_call(self, attention_cls):
     def test_save_load_layer(self, attention_cls):
         vocab = 20
         embedding_dim = 6
-        inputs = keras.layers.Input(shape=[self.timestep])
-        encoder_input = keras.layers.Embedding(
+        inputs = tf.keras.Input(shape=[self.timestep])
+        encoder_input = tf.keras.layers.Embedding(
             vocab, embedding_dim, mask_zero=True)(inputs)
-        encoder_output = keras.layers.LSTM(
+        encoder_output = tf.keras.layers.LSTM(
             self.memory_size, return_sequences=True)(encoder_input)
 
         attention = attention_cls(self.units, encoder_output)
-        query = keras.layers.Input(shape=[self.units])
-        state = keras.layers.Input(shape=[self.timestep])
+        query = tf.keras.Input(shape=[self.units])
+        state = tf.keras.Input(shape=[self.timestep])
 
         score = attention([query, state])
 
         x = np.random.randint(vocab, size=(self.batch, self.timestep))
         x_test = np.random.randint(vocab, size=(self.batch, self.timestep))
         y = np.random.randn(self.batch, self.timestep)
-        model = keras.models.Model([inputs, query, state], score)
+        model = tf.keras.Model([inputs, query, state], score)
         # Fall back to v1 style Keras training loop until issue with
         # using outputs of a layer in another layer's constructor.
         model.compile("rmsprop", "mse", experimental_run_tf_function=False)
@@ -155,7 +151,7 @@ def test_save_load_layer(self, attention_cls):
 
         config = model.get_config()
         weights = model.get_weights()
-        loaded_model = keras.models.Model.from_config(
+        loaded_model = tf.keras.Model.from_config(
             config, custom_objects={attention_cls.__name__: attention_cls})
         loaded_model.set_weights(weights)
 
@@ -337,11 +333,12 @@ def _testWithMaybeMultiAttention(self,
             # Create a memory layer with deterministic initializer to avoid
             # randomness in the test between graph and eager.
             if create_query_layer:
-                create_attention_kwargs["query_layer"] = keras.layers.Dense(
+                create_attention_kwargs["query_layer"] = tf.keras.layers.Dense(
                     depth, kernel_initializer="ones", use_bias=False)
             if create_memory_layer:
-                create_attention_kwargs["memory_layer"] = keras.layers.Dense(
-                    depth, kernel_initializer="ones", use_bias=False)
+                create_attention_kwargs["memory_layer"] = (
+                    tf.keras.layers.Dense(
+                        depth, kernel_initializer="ones", use_bias=False))
 
             attention_mechanisms.append(
                 creator(
@@ -358,7 +355,7 @@ def _testWithMaybeMultiAttention(self,
             attention_layer_size = attention_layer_size[0]
         if attention_layer is not None:
             attention_layer = attention_layer[0]
-        cell = keras.layers.LSTMCell(
+        cell = tf.keras.layers.LSTMCell(
             cell_depth,
             recurrent_activation="sigmoid",
             kernel_initializer="ones",
@@ -371,8 +368,9 @@ def _testWithMaybeMultiAttention(self,
             attention_layer=attention_layer)
         if cell._attention_layers is not None:
             for layer in cell._attention_layers:
-                layer.kernel_initializer = initializers.glorot_uniform(
-                    seed=1337)
+                layer.kernel_initializer = (
+                    tf.compat.v1.keras.initializers.glorot_uniform(
+                        seed=1337))
 
         sampler = sampler_py.TrainingSampler()
         my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
@@ -476,7 +474,7 @@ def testBahdanauNormalizedDType(self, dtype):
             memory_sequence_length=self.encoder_sequence_length,
             normalize=True,
             dtype=dtype)
-        cell = keras.layers.LSTMCell(
+        cell = tf.keras.layers.LSTMCell(
             self.units, recurrent_activation="sigmoid", dtype=dtype)
         cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
 
@@ -505,7 +503,7 @@ def testLuongScaledDType(self, dtype):
             scale=True,
             dtype=dtype,
         )
-        cell = keras.layers.LSTMCell(
+        cell = tf.keras.layers.LSTMCell(
             self.units, recurrent_activation="sigmoid", dtype=dtype)
         cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
 
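For context, the save/load round trip exercised by test_save_load_layer now touches only public Keras APIs. A small standalone sketch of that pattern, with a stock Dense layer standing in for the Addons attention mechanism (the shapes and layer are illustrative placeholders):

import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=[4])
outputs = tf.keras.layers.Dense(2)(inputs)  # placeholder for the attention layer
model = tf.keras.Model(inputs, outputs)

config = model.get_config()
weights = model.get_weights()

# Rebuild the model from its config; a custom layer would be passed via
# custom_objects, as the test does with the attention class.
loaded = tf.keras.Model.from_config(config, custom_objects={})
loaded.set_weights(weights)

x = np.ones((1, 4), dtype="float32")
np.testing.assert_allclose(model.predict(x), loaded.predict(x))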