Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement automatic ridge regression #124

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 65 additions & 19 deletions mne_connectivity/vector_ar/var.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import scipy
from scipy.linalg import sqrtm
from sklearn.linear_model import RidgeCV
from tqdm import tqdm
from mne import BaseEpochs

Expand All @@ -13,7 +14,7 @@
@verbose
@fill_doc
def vector_auto_regression(
data, times=None, names=None, lags=1, l2_reg=0.0,
data, times=None, names=None, lags=1, l2_reg='auto',
compute_fb_operator=False, model='dynamic', n_jobs=1, verbose=None):
"""Compute vector auto-regresssive (VAR) model.

Expand All @@ -29,8 +30,14 @@ def vector_auto_regression(
%(names)s
lags : int, optional
Autoregressive model order, by default 1.
l2_reg : float, optional
Ridge penalty (l2-regularization) parameter, by default 0.0.
l2_reg : str | array-like, shape=(n_alphas,) | float | None, optional
Ridge penalty (l2-regularization) parameter, by default 'auto'. If
``data`` has condition number greater than 1e6, then ``data`` will undergo
automatic regularization using RidgeCV with a pre-defined array of
witherscp marked this conversation as resolved.
Show resolved Hide resolved
alphas. Alternatively, a user-defined array of alphas (must be positive
floats) can be passed, or a float value to fix the Ridge penalty
(l2-regularization) parameter. If ``l2_reg`` is set to 0 or None, then no
will be performed.
compute_fb_operator : bool
Whether to compute the backwards operator and average with
the forward operator. Addresses bias in the least-square
Expand Down Expand Up @@ -151,10 +158,31 @@ def vector_auto_regression(
# 1. determine shape of the window of data
n_epochs, n_nodes, _ = data.shape

cv_alphas = None
if isinstance(l2_reg, str):
if l2_reg == 'auto':
witherscp marked this conversation as resolved.
Show resolved Hide resolved

# determine condition of matrix across all epochs
conds = np.linalg.cond(data)
if np.any(conds > 1e6):
# matrix is ill-conditioned, so regularization must be used with
# cross-validated alpha values
cv_alphas = np.logspace(-15,5,11)

# TODO: Add message letting user know that matrix is ill-conditioned
# and the above alpha set will be searched
witherscp marked this conversation as resolved.
Show resolved Hide resolved

elif isinstance(l2_reg, (list, tuple, set, np.ndarray)):
cv_alphas = l2_reg

model_params = {
'lags': lags,
'l2_reg': l2_reg,
'cv_alphas': cv_alphas
}

# reset l2_reg for downstream functions
if cv_alphas is not None:
l2_reg = 0

if verbose:
logger.info(f'Running {model} vector autoregression with parameters: '
Expand All @@ -165,12 +193,15 @@ def vector_auto_regression(
# sample of the multivariate time-series of interest
# ordinary least squares or regularized least squares
# (ridge regression)
X, Y = _construct_var_eqns(data, **model_params)

b, res, rank, s = scipy.linalg.lstsq(X, Y)
X, Y = _construct_var_eqns(data, lags=lags, l2_reg=l2_reg)

# get the coefficients
coef = b.transpose()
if cv_alphas is not None:
reg = RidgeCV(alphas=cv_alphas, cv=5).fit(X, Y)
coef = reg.coef_
else:
b, res, rank, s = scipy.linalg.lstsq(X, Y)
coef = b.transpose()

# create connectivity
coef = coef.flatten()
Expand All @@ -187,8 +218,9 @@ def vector_auto_regression(
# linear system
A_mats = _system_identification(
data=data, lags=lags,
l2_reg=l2_reg, n_jobs=n_jobs,
compute_fb_operator=compute_fb_operator)
l2_reg=l2_reg, cv_alphas=cv_alphas,
n_jobs=n_jobs, compute_fb_operator=compute_fb_operator
)
# create connectivity
if lags > 1:
conn = EpochTemporalConnectivity(data=A_mats,
Expand Down Expand Up @@ -261,7 +293,7 @@ def _construct_var_eqns(data, lags, l2_reg=None):
X[:n, i * lags + k -
1] = np.reshape(data[:, i, lags - k:-k].T, n)

if l2_reg is not None:
if l2_reg:
np.fill_diagonal(X[n:, :], l2_reg)

# Construct vectors yi (response variables for each channel i)
Expand All @@ -272,7 +304,7 @@ def _construct_var_eqns(data, lags, l2_reg=None):
return X, Y


def _system_identification(data, lags, l2_reg=0,
def _system_identification(data, lags, l2_reg=0, cv_alphas=None,
n_jobs=-1, compute_fb_operator=False):
"""Solve system identification using least-squares over all epochs.

Expand All @@ -290,6 +322,7 @@ def _system_identification(data, lags, l2_reg=0,
model_params = {
'l2_reg': l2_reg,
'lags': lags,
'cv_alphas': cv_alphas,
'compute_fb_operator': compute_fb_operator
}

Expand Down Expand Up @@ -346,7 +379,7 @@ def _system_identification(data, lags, l2_reg=0,
return A_mats


def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
def _compute_lds_func(data, lags, l2_reg, cv_alphas, compute_fb_operator):
"""Compute linear system using VAR model.

Allows for parallelization over epochs.
Expand All @@ -372,20 +405,21 @@ def _compute_lds_func(data, lags, l2_reg, compute_fb_operator):
# get time-shifted versions
X = data[:, :]
A, resid, omega = _estimate_var(X, lags=lags, offset=0,
l2_reg=l2_reg)
l2_reg=l2_reg, cv_alphas=cv_alphas)

if compute_fb_operator:
# compute backward linear operator
# original method
back_A, back_resid, back_omega = _estimate_var(
X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg)
X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg, cv_alphas=cv_alphas
)
A = sqrtm(A.dot(np.linalg.inv(back_A)))
A = A.real # remove numerical noise

return A, resid, omega


def _estimate_var(X, lags, offset=0, l2_reg=0):
def _estimate_var(X, lags, offset=0, l2_reg=0, cv_alphas=None):
"""Estimate a VAR model.

Parameters
Expand All @@ -397,8 +431,10 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
offset : int, optional
Periods to drop from the beginning of the time-series, by default 0.
Used for order selection, so it's an apples-to-apples comparison
l2_reg : int
l2_reg : int, optional
The amount of l2-regularization to use. Default of 0.
cv_alphas : array-like | None, optional
RidgeCV regularization cross-validation alpha values. Defaults to None.

Returns
-------
Expand Down Expand Up @@ -432,10 +468,20 @@ def _estimate_var(X, lags, offset=0, l2_reg=0):
y_sample = endog[lags:]
del endog, X
# Lütkepohl p75, about 5x faster than stated formula

if l2_reg != 0:
params = np.linalg.lstsq(z.T @ z + l2_reg * np.eye(n_equations * lags),
z.T @ y_sample, rcond=1e-15)[0]
# use pre-specified l2 regularization value
params = np.linalg.lstsq(
z.T @ z + l2_reg * np.eye(n_equations * lags),
z.T @ y_sample,
rcond=1e-15
)[0]
elif cv_alphas is not None:
# use ridge regression with built-in cross validation of alpha values
reg = RidgeCV(alphas=cv_alphas, cv=5).fit(z, y_sample)
params = reg.coef_.T
else:
# use OLS regression
params = np.linalg.lstsq(z, y_sample, rcond=1e-15)[0]

# (n_samples - lags, n_channels)
Expand Down