From 3a655dee751c9f4a98be4dc976e38c62cf6f4787 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Mon, 22 Jan 2024 11:41:00 +0100 Subject: [PATCH] [Tests] Check if Gram matrices are semi-definite. --- gklearn/tests/test_graph_kernels.py | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/gklearn/tests/test_graph_kernels.py b/gklearn/tests/test_graph_kernels.py index 08091cb80b..8105a3c5d9 100644 --- a/gklearn/tests/test_graph_kernels.py +++ b/gklearn/tests/test_graph_kernels.py @@ -102,6 +102,16 @@ def assert_equality(compute_fun, **kwargs): assert np.array_equal(lst[i], lst[i + 1]) +def assert_semidefinite(gram_matrix): + """Check if a matrix is positive semi-definite. + """ + eigvals = np.linalg.eigvals(gram_matrix) + assert np.all(eigvals >= 10e-9), "Gram matrix is not positive semi-definite." + + +############################################################################## + + @pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'AIDS']) @pytest.mark.parametrize('weight,compute_method', [(0.01, 'geo'), (1, 'exp')]) # @pytest.mark.parametrize('parallel', ['imap_unordered', None]) @@ -143,6 +153,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -198,6 +210,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -249,6 +263,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -319,6 +335,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -388,6 +406,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -441,6 +461,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -558,6 +580,8 @@ def compute(parallel=None, fcsp=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -633,6 +657,8 @@ def compute(parallel=None, fcsp=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -691,6 +717,8 @@ def compute(parallel=None, compute_method=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -742,6 +770,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception @@ -798,6 +828,8 @@ def compute(parallel=None): verbose=True ) + assert_semidefinite(gram_matrix) + except Exception as exception: print(repr(exception)) assert False, exception