Skip to content

Commit

Permalink
Update container tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jul 15, 2024
1 parent 8d7e7eb commit 13c85e2
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions mne_connectivity/tests/test_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _make_test_epochs():


def _prep_correct_connectivity_input(
conn_cls, n_nodes=3, symmetric=False, n_epochs=4, indices=None
conn_cls, n_nodes=3, symmetric=False, n_epochs=4, indices=None, n_components=0
):
correct_numpy_shape = []

Expand All @@ -72,6 +72,10 @@ def _prep_correct_connectivity_input(
else:
correct_numpy_shape.append(len(indices[0]))

if n_components:
correct_numpy_shape.append(n_components)
extra_kwargs["components"] = np.arange(n_components) + 1

if conn_cls in (
SpectralConnectivity,
SpectroTemporalConnectivity,
Expand Down Expand Up @@ -105,7 +109,8 @@ def _prep_correct_connectivity_input(
EpochSpectroTemporalConnectivity,
],
)
def test_connectivity_containers(conn_cls):
@pytest.mark.parametrize("n_components", [0, 2])
def test_connectivity_containers(conn_cls, n_components):
"""Test connectivity classes."""
n_epochs = 4
n_nodes = 3
Expand All @@ -114,14 +119,18 @@ def test_connectivity_containers(conn_cls):
[3, 4, 5],
[0, 1, 2],
]
bad_numpy_input = np.zeros((3, 3, 4, 5))
bad_numpy_input = np.zeros((3, 3, 4, 5, 6))
bad_indices = ([1, 0], [2])

if conn_cls.is_epoched:
bad_numpy_input = np.zeros((3, 3, 3, 4, 5))
bad_numpy_input = np.zeros((3, 3, 3, 4, 5, 6))

correct_numpy_shape, extra_kwargs = _prep_correct_connectivity_input(
conn_cls, n_nodes=n_nodes, symmetric=False, n_epochs=n_epochs
conn_cls,
n_nodes=n_nodes,
symmetric=False,
n_epochs=n_epochs,
n_components=n_components,
)

correct_numpy_input = np.ones(correct_numpy_shape)
Expand All @@ -146,6 +155,19 @@ def test_connectivity_containers(conn_cls):
# test that get_data works as intended
with pytest.raises(ValueError, match="Invalid value for the 'output' parameter"):
conn.get_data(output="blah")
with pytest.raises(
ValueError, match="cannot return multivariate connectivity data in a dense form"
):
multivar_conn = conn_cls(
data=correct_numpy_input,
n_nodes=n_nodes,
indices=(
[[ind] for ind in range(n_nodes**2)],
[[ind] for ind in range(n_nodes**2)],
),
**extra_kwargs,
)
multivar_conn.get_data(output="dense")

assert conn.shape == tuple(correct_numpy_shape)
assert conn.get_data(output="raveled").shape == tuple(correct_numpy_shape)
Expand Down Expand Up @@ -224,6 +246,9 @@ def test_connectivity_containers(conn_cls):
)


test_connectivity_containers(Connectivity, 0)


@pytest.mark.parametrize(
"conn_cls",
[
Expand Down

0 comments on commit 13c85e2

Please sign in to comment.