Commit 4f71dd2: Fix tests
TimKoornstra committed Aug 23, 2024
1 parent 9f7ee18

Showing 4 changed files with 31 additions and 15 deletions.
src/model/vgsl_model_generator.py (9 changes: 5 additions & 4 deletions)

@@ -186,7 +186,8 @@ def init_model_from_string(self,
             setattr(self, f"lstm_{index}", self.lstm_generator(layer))
             self.history.append(f"lstm_{index}")
         elif layer.startswith('F'):
-            setattr(self, f"dense{index}", self.fully_connected_generator(layer))
+            setattr(self, f"dense{index}",
+                    self.fully_connected_generator(layer))
             self.history.append(f"dense{index}")
         elif layer.startswith('B'):
             setattr(self, f"bidirectional_{index}",
@@ -387,10 +388,10 @@ def get_stride_spec(strides: tuple) -> str:
     # This is only the case where we have a model created with the Keras
     # functional API
     if isinstance(model.layers[0], tf.keras.layers.InputLayer):
-        input_shape = model.layers[0].input_shape[0]
+        input_shape = model.layers[0].output.shape
         start_idx = 1
     else:
-        input_shape = model.layers[0].input_shape
+        input_shape = model.layers[0].input.shape
         start_idx = 0

     if not (len(input_shape) == 4 and
@@ -442,7 +443,7 @@ def get_stride_spec(strides: tuple) -> str:
                 f"{get_dropout(layer.dropout, layer.recurrent_dropout)}")

         elif isinstance(layer, layers.Bidirectional):
-            wrapped_layer = layer.layer
+            wrapped_layer = layer.forward_layer
             cell_type = 'l' if isinstance(
                 wrapped_layer, tf.keras.layers.LSTM) else 'g'
             dropout = get_dropout(wrapped_layer.dropout,
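
The input-shape change tracks the Keras 3 API: layers no longer carry the
`input_shape` attribute that tf.keras 2 provided (the old code indexed `[0]`
because that attribute returned a nested structure on InputLayer), so the
shape is now read from the symbolic `output`/`input` tensors instead. A
minimal sketch of the two accessors, assuming a TensorFlow build with Keras 3
semantics; the Conv2D layer is only an illustrative stand-in:

import tensorflow as tf

# Functional-API model, so layers[0] is an explicit InputLayer.
inputs = tf.keras.Input(shape=(64, None, 1))
outputs = tf.keras.layers.Conv2D(8, 3, padding="same")(inputs)
model = tf.keras.Model(inputs, outputs)

first = model.layers[0]
if isinstance(first, tf.keras.layers.InputLayer):
    input_shape = first.output.shape   # was: first.input_shape[0]
else:
    input_shape = first.input.shape    # was: first.input_shape

print(input_shape)  # (None, 64, None, 1)
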
tests/test_model_creation.py (32 changes: 22 additions & 10 deletions)

@@ -554,36 +554,48 @@ def test_bidirectional_layer(self):
         model_generator = self.VGSLModelGenerator(vgsl_spec_string)
         model = model_generator.build()
         self.assertIsInstance(model.layers[2], layers.Bidirectional)
-        self.assertIsInstance(model.layers[2].layer, layers.GRU)
-        self.assertEqual(model.layers[2].layer.units, 128)
+        self.assertIsInstance(model.layers[2].forward_layer, layers.GRU)
+        self.assertIsInstance(model.layers[2].backward_layer, layers.GRU)
+        self.assertEqual(model.layers[2].forward_layer.units, 128)
+        self.assertEqual(model.layers[2].backward_layer.units, 128)

         vgsl_spec_string = "None,64,None,1 Rc Bl128 O1s10"
         model_generator = self.VGSLModelGenerator(vgsl_spec_string)
         model = model_generator.build()
         self.assertIsInstance(model.layers[2], layers.Bidirectional)
-        self.assertIsInstance(model.layers[2].layer, layers.LSTM)
-        self.assertEqual(model.layers[2].layer.units, 128)
+        self.assertIsInstance(model.layers[2].forward_layer, layers.LSTM)
+        self.assertIsInstance(model.layers[2].backward_layer, layers.LSTM)
+        self.assertEqual(model.layers[2].forward_layer.units, 128)
+        self.assertEqual(model.layers[2].backward_layer.units, 128)

         vgsl_spec_string = "None,64,None,1 Rc Bl128,D50 O1s10"
         model_generator = self.VGSLModelGenerator(vgsl_spec_string)
         model = model_generator.build()

-        self.assertEqual(model.layers[2].layer.dropout, 0.50)
-        self.assertEqual(model.layers[2].layer.recurrent_dropout, 0)
+        self.assertEqual(model.layers[2].forward_layer.dropout, 0.50)
+        self.assertEqual(model.layers[2].backward_layer.dropout, 0.50)
+        self.assertEqual(model.layers[2].forward_layer.recurrent_dropout, 0)
+        self.assertEqual(model.layers[2].backward_layer.recurrent_dropout, 0)

         vgsl_spec_string = "None,64,None,1 Rc Bl128,Rd50 O1s10"
         model_generator = self.VGSLModelGenerator(vgsl_spec_string)
         model = model_generator.build()

-        self.assertEqual(model.layers[2].layer.dropout, 0)
-        self.assertEqual(model.layers[2].layer.recurrent_dropout, 0.50)
+        self.assertEqual(model.layers[2].forward_layer.dropout, 0)
+        self.assertEqual(model.layers[2].backward_layer.dropout, 0)
+        self.assertEqual(model.layers[2].forward_layer.recurrent_dropout, 0.50)
+        self.assertEqual(model.layers[2].backward_layer.recurrent_dropout,
+                         0.50)

         vgsl_spec_string = "None,64,None,1 Rc Bl128,D42,Rd34 O1s10"
         model_generator = self.VGSLModelGenerator(vgsl_spec_string)
         model = model_generator.build()

-        self.assertEqual(model.layers[2].layer.dropout, 0.42)
-        self.assertEqual(model.layers[2].layer.recurrent_dropout, 0.34)
+        self.assertEqual(model.layers[2].forward_layer.dropout, 0.42)
+        self.assertEqual(model.layers[2].backward_layer.dropout, 0.42)
+        self.assertEqual(model.layers[2].forward_layer.recurrent_dropout, 0.34)
+        self.assertEqual(model.layers[2].backward_layer.recurrent_dropout,
+                         0.34)

     def test_bidirectional_error_handling(self):
         # Invalid format
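
The test updates apply the same migration to the Bidirectional wrapper: the
wrapped RNN is materialized as two independent copies, exposed as
`forward_layer` and `backward_layer`, so units, dropout, and recurrent
dropout are now asserted on both directions instead of on the wrapped
`.layer`. A minimal sketch of the pattern, mirroring the `Bl128,D42,Rd34`
spec above (assuming TensorFlow with Keras 3 semantics):

import tensorflow as tf

bidir = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(128, return_sequences=True,
                         dropout=0.42, recurrent_dropout=0.34))

# One copy of the LSTM per direction, each with the same hyperparameters.
assert isinstance(bidir.forward_layer, tf.keras.layers.LSTM)
assert isinstance(bidir.backward_layer, tf.keras.layers.LSTM)
assert bidir.forward_layer.units == bidir.backward_layer.units == 128
assert bidir.forward_layer.dropout == bidir.backward_layer.dropout == 0.42
assert bidir.forward_layer.recurrent_dropout == 0.34
assert bidir.backward_layer.recurrent_dropout == 0.34
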
tests/test_model_replacement.py (2 changes: 1 addition & 1 deletion)

@@ -232,7 +232,7 @@ def test_rnn_replacement_multiple(self):
                 found_gru = True
             elif isinstance(layer, layers.Bidirectional):
                 self.assertEqual(
-                    layer.layer.units, 32, "Unexpected number of units in "
+                    layer.forward_layer.units, 32, "Unexpected number of units in "
                     "Bidirectional LSTM layer")
                 found_bidir = True
tests/test_model_to_vgsl.py (3 changes: 3 additions & 0 deletions)

@@ -386,8 +386,11 @@ def compare_model_configs(self, model, generated_model):
         # If the layer is Bidirectional, update the inner layer's config
         if isinstance(original_layer, tf.keras.layers.Bidirectional):
             for key in keys_to_ignore:
+                # NOTE: why is 'layer' not 'forward_layer'?
                 original_config['layer']['config'].pop(key, None)
+                original_config['backward_layer']['config'].pop(key, None)
                 generated_config['layer']['config'].pop(key, None)
+                generated_config['backward_layer']['config'].pop(key, None)
         for key in keys_to_ignore:
             original_config.pop(key, None)
             generated_config.pop(key, None)
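
As for the NOTE in the added line: `forward_layer` is a runtime attribute,
not a config key. When the wrapper is serialized, the forward copy is stored
under the key 'layer' (inherited from the Wrapper base class) and the
backward copy under 'backward_layer', which is why the test pops exactly
those two keys. A small sketch, assuming Keras 3 serialization behavior (key
presence can vary across Keras versions):

import tensorflow as tf

bidir = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(64))
config = bidir.get_config()

# The forward copy serializes under 'layer'; there is no 'forward_layer' key.
print('layer' in config, 'backward_layer' in config)  # True True
print('forward_layer' in config)                      # False
print(config['layer']['config']['units'])             # 64
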
