Skip to content

Commit

Permalink
Fixed Yule-Walker fitting and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mbillingr committed Apr 26, 2016
1 parent 2b426dd commit bbece5a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mne_sandbox/connectivity/mvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _acm(x, l):
a = x[:, l:]
b = x[:, 0:-l]

return np.dot(a[:, :], b[:, :].T) / a.shape[1]
return np.dot(a[:, :], b[:, :].T).T / a.shape[1]


def _epoch_autocorrelations(epoch, max_lag):
Expand Down
29 changes: 28 additions & 1 deletion mne_sandbox/connectivity/tests/test_mvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
# License: BSD (3-clause)

import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_less
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
assert_array_less)
from nose.tools import assert_raises, assert_equal
from copy import deepcopy

from mne_sandbox.connectivity import mvar_connectivity
from mne_sandbox.connectivity.mvar import _fit_mvar_lsq, _fit_mvar_yw


def _make_data(var_coef, n_samples, n_epochs):
Expand Down Expand Up @@ -136,3 +139,27 @@ def test_mvar_connectivity():
assert_array_less(p_vals[0][i, j, 0], 0.05)
else:
assert_array_less(0.05, p_vals[0][i, j, 0])


def test_fit_mvar():
"""Test MVAR model fitting"""
np.random.seed(0)

n_sigs = 3
n_epochs = 50
n_samples = 200

var_coef = np.zeros((1, n_sigs, n_sigs))
var_coef[0, :, :] = [[0.9, 0, 0],
[1, 0.5, 0],
[2, 0, -0.5]]
data = _make_data(var_coef, n_samples, n_epochs)
data0 = deepcopy(data)

var = _fit_mvar_lsq(data, pmin=1, pmax=1, delta=0, n_jobs=1, verbose=0)
assert_array_equal(data, data0)
assert_array_almost_equal(var_coef[0], var.coef, decimal=2)

var = _fit_mvar_yw(data, pmin=1, pmax=1, n_jobs=1, verbose=0)
assert_array_equal(data, data0)
assert_array_almost_equal(var_coef[0], var.coef, decimal=2)

0 comments on commit bbece5a

Please sign in to comment.