diff --git a/mne_connectivity/tests/test_connectivity.py b/mne_connectivity/tests/test_connectivity.py index fe767ee8..f0c2c87e 100644 --- a/mne_connectivity/tests/test_connectivity.py +++ b/mne_connectivity/tests/test_connectivity.py @@ -56,7 +56,7 @@ def _make_test_epochs(): def _prep_correct_connectivity_input( - conn_cls, n_nodes=3, symmetric=False, n_epochs=4, indices=None + conn_cls, n_nodes=3, symmetric=False, n_epochs=4, indices=None, n_components=0 ): correct_numpy_shape = [] @@ -72,6 +72,10 @@ def _prep_correct_connectivity_input( else: correct_numpy_shape.append(len(indices[0])) + if n_components: + correct_numpy_shape.append(n_components) + extra_kwargs["components"] = np.arange(n_components) + 1 + if conn_cls in ( SpectralConnectivity, SpectroTemporalConnectivity, @@ -105,7 +109,8 @@ def _prep_correct_connectivity_input( EpochSpectroTemporalConnectivity, ], ) -def test_connectivity_containers(conn_cls): +@pytest.mark.parametrize("n_components", [0, 2]) +def test_connectivity_containers(conn_cls, n_components): """Test connectivity classes.""" n_epochs = 4 n_nodes = 3 @@ -114,14 +119,18 @@ def test_connectivity_containers(conn_cls): [3, 4, 5], [0, 1, 2], ] - bad_numpy_input = np.zeros((3, 3, 4, 5)) + bad_numpy_input = np.zeros((3, 3, 4, 5, 6)) bad_indices = ([1, 0], [2]) if conn_cls.is_epoched: - bad_numpy_input = np.zeros((3, 3, 3, 4, 5)) + bad_numpy_input = np.zeros((3, 3, 3, 4, 5, 6)) correct_numpy_shape, extra_kwargs = _prep_correct_connectivity_input( - conn_cls, n_nodes=n_nodes, symmetric=False, n_epochs=n_epochs + conn_cls, + n_nodes=n_nodes, + symmetric=False, + n_epochs=n_epochs, + n_components=n_components, ) correct_numpy_input = np.ones(correct_numpy_shape) @@ -146,6 +155,19 @@ def test_connectivity_containers(conn_cls): # test that get_data works as intended with pytest.raises(ValueError, match="Invalid value for the 'output' parameter"): conn.get_data(output="blah") + with pytest.raises( + ValueError, match="cannot return multivariate connectivity data in a dense form" + ): + multivar_conn = conn_cls( + data=correct_numpy_input, + n_nodes=n_nodes, + indices=( + [[ind] for ind in range(n_nodes**2)], + [[ind] for ind in range(n_nodes**2)], + ), + **extra_kwargs, + ) + multivar_conn.get_data(output="dense") assert conn.shape == tuple(correct_numpy_shape) assert conn.get_data(output="raveled").shape == tuple(correct_numpy_shape) @@ -224,6 +246,9 @@ def test_connectivity_containers(conn_cls): ) +test_connectivity_containers(Connectivity, 0) + + @pytest.mark.parametrize( "conn_cls", [