From bbece5a5282876f2d1488211d9d30487afe09ac7 Mon Sep 17 00:00:00 2001 From: Martin Billinger Date: Tue, 26 Apr 2016 10:08:12 +0200 Subject: [PATCH] Fixed Yule-Walker fitting and added tests --- mne_sandbox/connectivity/mvar.py | 2 +- mne_sandbox/connectivity/tests/test_mvar.py | 29 ++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mne_sandbox/connectivity/mvar.py b/mne_sandbox/connectivity/mvar.py index da99c0a..4caaf40 100644 --- a/mne_sandbox/connectivity/mvar.py +++ b/mne_sandbox/connectivity/mvar.py @@ -25,7 +25,7 @@ def _acm(x, l): a = x[:, l:] b = x[:, 0:-l] - return np.dot(a[:, :], b[:, :].T) / a.shape[1] + return np.dot(a[:, :], b[:, :].T).T / a.shape[1] def _epoch_autocorrelations(epoch, max_lag): diff --git a/mne_sandbox/connectivity/tests/test_mvar.py b/mne_sandbox/connectivity/tests/test_mvar.py index 7f37065..81470d4 100644 --- a/mne_sandbox/connectivity/tests/test_mvar.py +++ b/mne_sandbox/connectivity/tests/test_mvar.py @@ -3,10 +3,13 @@ # License: BSD (3-clause) import numpy as np -from numpy.testing import assert_array_almost_equal, assert_array_less +from numpy.testing import (assert_array_equal, assert_array_almost_equal, + assert_array_less) from nose.tools import assert_raises, assert_equal +from copy import deepcopy from mne_sandbox.connectivity import mvar_connectivity +from mne_sandbox.connectivity.mvar import _fit_mvar_lsq, _fit_mvar_yw def _make_data(var_coef, n_samples, n_epochs): @@ -136,3 +139,27 @@ def test_mvar_connectivity(): assert_array_less(p_vals[0][i, j, 0], 0.05) else: assert_array_less(0.05, p_vals[0][i, j, 0]) + + +def test_fit_mvar(): + """Test MVAR model fitting""" + np.random.seed(0) + + n_sigs = 3 + n_epochs = 50 + n_samples = 200 + + var_coef = np.zeros((1, n_sigs, n_sigs)) + var_coef[0, :, :] = [[0.9, 0, 0], + [1, 0.5, 0], + [2, 0, -0.5]] + data = _make_data(var_coef, n_samples, n_epochs) + data0 = deepcopy(data) + + var = _fit_mvar_lsq(data, pmin=1, pmax=1, delta=0, n_jobs=1, verbose=0) + assert_array_equal(data, data0) + assert_array_almost_equal(var_coef[0], var.coef, decimal=2) + + var = _fit_mvar_yw(data, pmin=1, pmax=1, n_jobs=1, verbose=0) + assert_array_equal(data, data0) + assert_array_almost_equal(var_coef[0], var.coef, decimal=2)