Implement automatic ridge regression #124

Open: wants to merge 9 commits into base: main
Changes from 2 commits
74 changes: 74 additions & 0 deletions mne_connectivity/vector_ar/tests/test_var.py
@@ -25,6 +25,42 @@ def bivariate_var_data():
    return y


def illconditioned_data(
    n=12,
    m=100,
    add_noise=False,
    sigma=1e-4,
    random_state=12345,
):
    """Simulate an ill-conditioned linear system with known eigenvalues."""
    rng = np.random.RandomState(random_state)

    if add_noise:
        mu = 0.0
        noise = rng.normal(mu, sigma, m)  # gaussian noise

    # create upper triangle A matrix
    A = np.triu(rng.uniform(0, 1, (n, n)))
    A[-1, -1] = 1e-6  # matrix is ill-conditioned

    # compute true eigenvalues
    true_eigvals = np.linalg.eigvals(A)

    X = np.zeros((n, m))
    X[:, 0] = rng.uniform(0, 1, n)
    # evolve the system and perturb the data with noise
    for k in range(1, m):
        X[:, k] = A.dot(X[:, k - 1])

        if add_noise:
            X[:, k - 1] += noise[k - 1]

    # data must be ill-conditioned
    assert np.linalg.cond(X) > 1e6

    return X, true_eigvals, A


def create_noisy_data(
    add_noise,
    sigma=1e-4,
@@ -316,3 +352,41 @@ def test_vector_auto_regression():
    big_epoch_data = rng.randn(n_times * 2, n_signals, n_times)
    parr_conn = vector_auto_regression(big_epoch_data, times=times, n_jobs=-1)
    parr_conn.predict(big_epoch_data)


def test_auto_l2reg():
    """Test automatic l2 regularization of ill-conditioned data."""
    sample_data, sample_eigs, sample_A = illconditioned_data(
        add_noise=True
    )

    # create 3D array input
    sample_data = sample_data[np.newaxis, ...]

    # compute the model
    model = vector_auto_regression(sample_data, l2_reg='auto')

    # test that Ridge regression was used for ill-conditioned data
    assert model.xarray.attrs['use_ridge']

    # test the recovered model
    assert_array_almost_equal(
        model.get_data(output='dense').squeeze(), sample_A,
        decimal=1
    )

    # compute model without regularization
    noreg_model = vector_auto_regression(sample_data, l2_reg=None)
    assert noreg_model.xarray.attrs['use_ridge'] is False

    # test that the regularized model recovers the eigenvalues better
    eigs = np.linalg.eigvals(model.get_data(output='dense').squeeze())
    noreg_eigs = np.linalg.eigvals(
        noreg_model.get_data(output='dense').squeeze()
    )

    reg_diff = np.linalg.norm(eigs - sample_eigs)
    noreg_diff = np.linalg.norm(noreg_eigs - sample_eigs)

    assert reg_diff < noreg_diff
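
The hunks rendered on this page cover only the tests and the warning handling; the selection logic behind l2_reg='auto' appears to live in commits not shown here. As a rough, hypothetical sketch of the idea the test above exercises (the helper name _fit_with_auto_ridge, the condition-number cutoff, and the fixed alpha are illustrative assumptions, not the PR's actual code):

import numpy as np
from sklearn.linear_model import Ridge


def _fit_with_auto_ridge(X, Y, cond_limit=1e6, alpha=1.0):
    """Hypothetical sketch: fall back to ridge regression when X is ill-conditioned."""
    use_ridge = np.linalg.cond(X) > cond_limit
    if use_ridge:
        # regularize the ill-conditioned least-squares problem
        coef = Ridge(alpha=alpha).fit(X, Y).coef_.T
    else:
        # plain OLS is fine for well-conditioned data
        coef, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return coef, use_ridge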
16 changes: 14 additions & 2 deletions mne_connectivity/vector_ar/var.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 import scipy
 from scipy.linalg import sqrtm
@@ -199,7 +201,12 @@ def vector_auto_regression(
     X, Y = _construct_var_eqns(data, lags=lags, l2_reg=l2_reg)

     if cv_alphas is not None:
-        reg = RidgeCV(alphas=cv_alphas, cv=5).fit(X, Y)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                action='ignore',
+                message="Ill-conditioned matrix"
+            )
Comment on lines +204 to +208
Member: Why do we need this?

Author: RidgeCV tests out an array of alpha values, and some of them do not regularize the matrix enough to avoid the ill-conditioned matrix warning. If users see many of these messages pop up, they may think something is going wrong, when in fact the function is behaving as expected: RidgeCV will choose the best alpha value, and that alpha comes from a fit where the warning was not raised. (A minimal sketch of this behaviour is included after the second hunk below.)

+            reg = RidgeCV(alphas=cv_alphas, cv=5).fit(X, Y)
         coef = reg.coef_
     else:
         b, res, rank, s = scipy.linalg.lstsq(X, Y)
@@ -480,7 +487,12 @@ def _estimate_var(X, lags, offset=0, l2_reg=0, cv_alphas=None):
         )[0]
     elif cv_alphas is not None:
         # use ridge regression with built-in cross validation of alpha values
-        reg = RidgeCV(alphas=cv_alphas, cv=5).fit(z, y_sample)
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                action='ignore',
+                message="Ill-conditioned matrix"
+            )
+            reg = RidgeCV(alphas=cv_alphas, cv=5).fit(z, y_sample)
         params = reg.coef_.T
     else:
         # use OLS regression
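
For reference, a minimal, self-contained sketch of the behaviour described in the review thread above (the toy X, Y, and alpha grid are made up for illustration; only the warnings/RidgeCV usage mirrors the diff):

import warnings

import numpy as np
from sklearn.linear_model import RidgeCV

rng = np.random.RandomState(0)
X = rng.rand(100, 10)
X[:, -1] = X[:, 0] * 1e-10  # nearly dependent column -> ill-conditioned design
Y = rng.rand(100, 3)

# Small alphas may not regularize enough, so scipy can emit
# "Ill-conditioned matrix" warnings during the CV fits; the selected
# alpha_ is still the best-scoring one, so the warnings are suppressed.
with warnings.catch_warnings():
    warnings.filterwarnings(action='ignore', message="Ill-conditioned matrix")
    reg = RidgeCV(alphas=[1e-12, 1e-3, 1.0], cv=5).fit(X, Y)

print(reg.alpha_, reg.coef_.shape)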