From 2cbb191ea4daa6c45726b30a5d58de4200cb5233 Mon Sep 17 00:00:00 2001 From: Lonneke Scheffer Date: Wed, 24 Apr 2024 20:36:16 +0200 Subject: [PATCH] fix failing KerasSequenceCNN test on github --- test/ml_methods/test_kerasSequenceCNN.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/ml_methods/test_kerasSequenceCNN.py b/test/ml_methods/test_kerasSequenceCNN.py index 5b84d075c..1efeb595c 100644 --- a/test/ml_methods/test_kerasSequenceCNN.py +++ b/test/ml_methods/test_kerasSequenceCNN.py @@ -16,8 +16,6 @@ from immuneML.util.PathBuilder import PathBuilder - - class TestKerasSequenceCNN(TestCase): maxDiff = None @@ -32,6 +30,14 @@ def test_if_keras_installed(self): except ImportError as e: print("Test ignored since keras is not installed.") + def _recursive_convert_lists_to_tuples(self, model_description): + if isinstance(model_description, dict): + return {key: self._recursive_convert_lists_to_tuples(value) for key, value in + model_description.items()} + elif isinstance(model_description, (list, tuple, set)): + return tuple([self._recursive_convert_lists_to_tuples(item) for item in model_description]) + else: + return model_description def _test_fit(self): import keras @@ -82,6 +88,11 @@ def _test_fit(self): for item, value in cnn_params.items(): if isinstance(value, Label): self.assertDictEqual(vars(value), (vars(cnn2_params[item]))) + elif item == "model": + model1_params = self._recursive_convert_lists_to_tuples(cnn_params["model"]) + model2_params = self._recursive_convert_lists_to_tuples(cnn2_params["model"]) + + self.assertDictEqual(model1_params, model2_params) else: self.assertEqual(value, cnn2_params[item])