Skip to content

Commit

Permalink
[Tests] Check if Gram matrices are semi-definite.
Browse files Browse the repository at this point in the history
  • Loading branch information
jajupmochi committed Jan 22, 2024
1 parent c865eef commit 3a655de
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions gklearn/tests/test_graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ def assert_equality(compute_fun, **kwargs):
assert np.array_equal(lst[i], lst[i + 1])


def assert_semidefinite(gram_matrix):
"""Check if a matrix is positive semi-definite.
"""
eigvals = np.linalg.eigvals(gram_matrix)
assert np.all(eigvals >= 10e-9), "Gram matrix is not positive semi-definite."


##############################################################################


@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'AIDS'])
@pytest.mark.parametrize('weight,compute_method', [(0.01, 'geo'), (1, 'exp')])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
Expand Down Expand Up @@ -143,6 +153,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -198,6 +210,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -249,6 +263,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -319,6 +335,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -388,6 +406,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -441,6 +461,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -558,6 +580,8 @@ def compute(parallel=None, fcsp=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -633,6 +657,8 @@ def compute(parallel=None, fcsp=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -691,6 +717,8 @@ def compute(parallel=None, compute_method=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -742,6 +770,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down Expand Up @@ -798,6 +828,8 @@ def compute(parallel=None):
verbose=True
)

assert_semidefinite(gram_matrix)

except Exception as exception:
print(repr(exception))
assert False, exception
Expand Down

0 comments on commit 3a655de

Please sign in to comment.