Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add support for ragged connections with multivariate methods with padding #142

Merged
merged 48 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
73792c3
added support for ragged connections
tsbinns Jul 24, 2023
a3354cd
added author
tsbinns Jul 24, 2023
63f2c39
bug fix ragged indices comparison
tsbinns Jul 24, 2023
d2f3de8
bug fix ragged indices comparison
tsbinns Jul 24, 2023
ec36983
bug fix ragged indices comparison
tsbinns Jul 24, 2023
f70ed11
Merge branch 'main' into pr-mvc_padding
adam2392 Jul 24, 2023
02f5ad0
added extra multivariate indices unit test
tsbinns Jul 25, 2023
684e35f
Merge branch 'main' into pr-mvc_padding
tsbinns Jul 31, 2023
c229d4e
updated utils tests and docs
tsbinns Aug 14, 2023
ffca93f
bug fix utils doc update
tsbinns Aug 14, 2023
c9855b7
bug fix utils doc update
tsbinns Aug 14, 2023
8c35917
bug fix utils doc update
tsbinns Aug 14, 2023
9bd2513
Merge branch 'main' into pr-mvc_padding
tsbinns Aug 28, 2023
409c2c6
updated spectral tests
tsbinns Aug 28, 2023
fe0fe68
added note for refactoring
tsbinns Aug 31, 2023
b7fcf12
updated spectral tests
tsbinns Sep 1, 2023
96a0dcf
Update ignore words
adam2392 Sep 1, 2023
f837889
added error message
tsbinns Sep 8, 2023
ff33e56
Added formatting suggestions
tsbinns Sep 9, 2023
95cd3c2
added max_n_chans suggestion
tsbinns Sep 9, 2023
42085e1
Merge branch 'pr-mvc_padding' of https://github.com/tsbinns/mne-conne…
tsbinns Sep 9, 2023
402cfaa
updated epochs docstring
tsbinns Sep 9, 2023
ebe0a69
added test suggestion
tsbinns Sep 9, 2023
1aa45d4
fixed style errors
tsbinns Sep 9, 2023
5497b55
Merge branch 'mne-tools:main' into pr-mvc_padding
tsbinns Oct 21, 2023
5bad5fb
Squashed commit of the following:
tsbinns Oct 24, 2023
238c0de
try fix ci error
tsbinns Oct 24, 2023
6fe682f
bug fix missing refactoring for example
tsbinns Oct 24, 2023
921cf9d
switch to masked arrays for indices
tsbinns Oct 26, 2023
27fac57
fix spelling error
tsbinns Oct 26, 2023
e0ebf2d
try fix codespell error
tsbinns Oct 26, 2023
8799ee2
Merge branch 'main' into pr-mvc_padding
larsoner Oct 27, 2023
f54cecc
Revert "bug fix missing refactoring for example"
tsbinns Nov 2, 2023
610ab7d
Revert "Squashed commit of the following:"
tsbinns Nov 2, 2023
f106d07
Revert "Squashed commit of the following:"
tsbinns Nov 2, 2023
c985886
switched to masked indices for multivariate conn
tsbinns Nov 2, 2023
a6a58e4
Revert "switch to masked arrays for indices"
tsbinns Nov 2, 2023
034dafe
Revert "bug fix missing refactoring for example"
tsbinns Nov 2, 2023
4907106
Revert "Squashed commit of the following:"
tsbinns Nov 2, 2023
736e642
Merge branch 'pr-mvc_padding_revert' into pr-mvc_padding
tsbinns Nov 2, 2023
cd69d65
updated time
tsbinns Nov 2, 2023
d6ba398
switched to masked indices for multivariate conn
tsbinns Nov 2, 2023
0732d22
removed redundant ignored word
tsbinns Nov 2, 2023
cd0b90f
removed redundant list creation
tsbinns Nov 2, 2023
883e816
updated default non-zero rank tolerance
tsbinns Nov 2, 2023
ff3b9e9
Merge branch 'pr-mvc_padding' of https://github.com/tsbinns/mne-conne…
tsbinns Nov 2, 2023
f5c1c3e
switched to array indices & added inline comments
tsbinns Nov 3, 2023
bac5000
fixed grammar mistake
tsbinns Nov 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ Post-processing on connectivity

degree
seed_target_indices
seed_target_multivariate_indices
check_indices
check_multivariate_indices
select_order

Visualization functions
Expand Down
32 changes: 11 additions & 21 deletions examples/granger_causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@

###############################################################################
# We will focus on connectivity between sensors over the parietal and occipital
# cortices, with 20 parietal sensors designated as group A, and 20 occipital
# cortices, with 20 parietal sensors designated as group A, and 22 occipital
# sensors designated as group B.

# %%
Expand All @@ -157,17 +157,8 @@
signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if
ch_info['ch_name'][2] == 'O']

# XXX: Currently ragged indices are not supported, so we only consider a single
# list of indices with an equal number of seeds and targets
min_n_chs = min(len(signals_a), len(signals_b))
signals_a = signals_a[:min_n_chs]
signals_b = signals_b[:min_n_chs]

indices_ab = (np.array(signals_a), np.array(signals_b)) # A => B
indices_ba = (np.array(signals_b), np.array(signals_a)) # B => A

signals_a_names = [epochs.info['ch_names'][idx] for idx in signals_a]
signals_b_names = [epochs.info['ch_names'][idx] for idx in signals_b]
indices_ab = (np.array([signals_a]), np.array([signals_b])) # A => B
indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A

# compute Granger causality
gc_ab = spectral_connectivity_epochs(
Expand All @@ -181,8 +172,8 @@

###############################################################################
# Plotting the results, we see that there is a flow of information from our
# parietal sensors (group A) to our occipital sensors (group B) with noticeable
# peaks at around 8, 18, and 26 Hz.
# parietal sensors (group A) to our occipital sensors (group B) with a
# noticeable peak at ~8 Hz, and smaller peaks at 18 and 26 Hz.

# %%

Expand All @@ -208,8 +199,7 @@
#
# Doing so, we see that the flow of information across the spectrum remains
# dominant from parietal to occipital sensors (indicated by the positive-valued
# Granger scores). However, the pattern of connectivity is altered, such as
# around 10 and 12 Hz where peaks of net information flow are now present.
# Granger scores), with similar peaks around 10, 18, and 26 Hz.

# %%

Expand Down Expand Up @@ -289,8 +279,8 @@
# Plotting the TRGC results, reveals a very different picture compared to net
# GC. For one, there is now a dominance of information flow ~6 Hz from
# occipital to parietal sensors (indicated by the negative-valued Granger
# scores). Additionally, the peaks ~10 Hz are less dominant in the spectrum,
# with parietal to occipital information flow between 13-20 Hz being much more
# scores). Additionally, the peak ~10 Hz is less dominant in the spectrum, with
# parietal to occipital information flow between 13-20 Hz being much more
# prominent. The stark difference between net GC and TRGC results indicates
# that the net GC spectrum was contaminated by spurious connectivity resulting
# from source mixing or correlated noise in the recordings. Altogether, the use
Expand Down Expand Up @@ -366,8 +356,8 @@

# gets the singular values of the data
s = np.linalg.svd(raw.get_data(), compute_uv=False)
# finds how many singular values are "close" to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria
# finds how many singular values are 'close' to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria

###############################################################################
# Nonethless, even in situations where you specify an appropriate rank, it is
Expand All @@ -387,7 +377,7 @@
try:
spectral_connectivity_epochs(
epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None,
gc_n_lags=20) # A => B
gc_n_lags=20, verbose=False) # A => B
print('Success!')
except RuntimeError as error:
print('\nCaught the following error:\n' + repr(error))
Comment on lines 377 to 383
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want, you could add this to expected_failing_examples then you don't have to try/except at all here. Just let it actually fail and sphinx-gallery will print a nicely formatted traceback for you.

https://sphinx-gallery.github.io/stable/configuration.html#dont-fail-exit

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The traceback in the file works really nicely, but is it possible to overwrite the "BROKEN" thumbnail?
image
image

E.g. setting # sphinx_gallery_thumbnail_path = '_static/granger_causality_gallery_thumbnail.png' within the example did not have an effect.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm not at the moment. Feel free to open an issue at https://github.com/sphinx-gallery/sphinx-gallery about adding this possibility

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened! sphinx-gallery/sphinx-gallery#1220

If the others agree, I would stick with the try-except approach until this behaviour in sphinx gallery is changed, so as not to give the impression from the thumbnail that the entire example is about failing code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine as is already, no need to wait to merge for this. It's an easy follow up PR after this is merged and SG has the machinery it needs

Expand Down
154 changes: 154 additions & 0 deletions examples/handling_ragged_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
=========================================================
Working with ragged indices for multivariate connectivity
=========================================================

This example demonstrates how multivariate connectivity involving different
numbers of seeds and targets can be handled in MNE-Connectivity.
"""

# Author: Thomas S. Binns <t.s.binns@outlook.com>
# License: BSD (3-clause)

# %%

import numpy as np

from mne_connectivity import spectral_connectivity_epochs

###############################################################################
# Background
# ----------
#
# With multivariate connectivity, interactions between multiple signals can be
# considered together, and the number of signals designated as seeds and
# targets does not have to be equal within or across connections. Issues can
# arise from this when storing information associated with connectivity in
# arrays, as the number of entries within each dimension can vary within and
# across connections depending on the number of seeds and targets. Such arrays
# are 'ragged', and support for ragged arrays is limited in NumPy to the
# ``object`` datatype. Not only is working with ragged arrays is cumbersome,
# but saving arrays with ``dtype='object'`` is not supported by the h5netcdf
# engine used to save connectivity objects. The workaround used in
# MNE-Connectivity is to pad ragged arrays with some known values according to
# the largest number of entries in each dimension, such that there is an equal
# amount of information across and within connections for each dimension of the
# arrays.
#
# As an example, consider we have 5 channels and want to compute 2 connections:
# the first between channels in indices 0 and 1 with those in indices 2, 3,
# and 4; and the second between channels 0, 1, 2, and 3 with channel 4. The
# seed and target indices can be written as such::
#
# seeds = [[0, 1 ], [0, 1, 2, 3]]
# targets = [[2, 3, 4], [4 ]]
#
# The ``indices`` parameter passed to
# :func:`~mne_connectivity.spectral_connectivity_epochs` and
# :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of
# array-likes, meaning
# that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays.
# Examples of how ``indices`` can be formed are shown below::
#
# # tuple of lists
# ragged_indices = ([[0, 1 ], [0, 1, 2, 3]],
# [[2, 3, 4], [4 ]])
#
# # tuple of tuples
# ragged_indices = (((0, 1 ), (0, 1, 2, 3)),
# ((2, 3, 4), (4 )))
#
# # tuple of arrays
# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'),
# np.array([[2, 3, 4], [4 ]], dtype='object'))
#
# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be
# specified.**
#
# Just as for bivariate connectivity, the length of ``indices[0]`` and
# ``indices[1]`` is equal (i.e. the number of connections), however information
# about the multiple channel indices for each connection is stored in a nested
# array. Importantly, these indices are ragged, as the first connection will be
# computed between 2 seed and 3 target channels, and the second connection
# between 4 seed and 1 target channel. The connectivity functions will
# recognise the indices as being ragged, and pad them accordingly to make them
# easier to work with and compatible with the h5netcdf saving engine. The known
# value used to pad the arrays is ``-1``, an invalid channel index. The above
# indices would be padded to::
#
# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]),
# np.array([[2, 3, 4, -1], [4, -1, -1, -1]]))
#
# These indices are what is stored in the connectivity object, and is also the
# format of indices returned from the helper functions
# :func:`~mne_connectivity.check_multivariate_indices` and
# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also
# possible to pass the padded indices to the connectivity functions directly.
#
# For the connectivity results themselves, the methods available in
# MNE-Connectivity combine information across the different channels into a
# single (time-)frequency-resolved connectivity spectrum, regardless of the
# number of seed and target channels, so ragged arrays are not a concern here.
# However, the maximised imaginary part of coherency (MIC) method also returns
# spatial patterns of connectivity, which show the contribution of each channel
# to the dimensionality-reduced connectivity estimate (explained in more detail
# in :doc:`mic_mim`). Because these patterns are returned for each channel,
# their shape can vary depending on the number of seeds and targets in each
# connection, making them ragged. To avoid this, the patterns are padded along
# the channel axis with the known and invalid entry ``np.nan``, in line with
# that applied to ``indices``. Extracting only the valid spatial patterns from
# the connectivity object is trivial, as shown below:

# %%

# create random data
data = np.random.randn(10, 5, 200) # epochs x channels x times
sfreq = 50
ragged_indices = ([[0, 1], [0, 1, 2, 3]], # seeds
[[2, 3, 4], [4]]) # targets

# compute connectivity
con = spectral_connectivity_epochs(
data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30,
verbose=False)
patterns = np.array(con.attrs['patterns'])
padded_indices = con.indices
n_freqs = con.get_data().shape[-1]
n_cons = len(ragged_indices[0])
max_n_chans = max(
[len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])])

# show that the padded indices entries are all -1
assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels
assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels
assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels
assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels

# patterns have shape [seeds/targets x cons x max channels x freqs (x times)]
assert patterns.shape == (2, n_cons, max_n_chans, n_freqs)

# show that the padded patterns entries are all np.nan
assert np.all(np.isnan(patterns[0, 0, 2:])) # 2 padded channels
assert np.all(np.isnan(patterns[1, 0, 3:])) # 1 padded channels
assert not np.any(np.isnan(patterns[0, 1])) # 0 padded channels
assert np.all(np.isnan(patterns[1, 1, 1:])) # 3 padded channels

# extract patterns for first connection using the ragged indices
seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])]
target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])]

# extract patterns for second connection using the padded indices (pad = -1)
seed_patterns_con2 = (
patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)])
target_patterns_con2 = (
patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)])

# show that shapes of patterns are correct
assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1)
assert target_patterns_con1.shape == (3, n_freqs) # channels (2, 3, 4)
assert seed_patterns_con2.shape == (4, n_freqs) # channels (0, 1, 2, 3)
assert target_patterns_con2.shape == (1, n_freqs) # channels (4)

print('Assertions completed successfully!')

# %%
43 changes: 18 additions & 25 deletions examples/mic_mim.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
###############################################################################
# We will focus on connectivity between sensors over the left and right
# hemispheres, with 75 sensors in the left hemisphere designated as seeds, and
# 75 sensors in the right hemisphere designated as targets.
# 76 sensors in the right hemisphere designated as targets.

# %%

Expand All @@ -81,13 +81,7 @@
targets = [idx for idx, ch_info in enumerate(epochs.info['chs']) if
ch_info['loc'][0] > 0]

# XXX: Currently ragged indices are not supported, so we only consider a single
# list of indices with an equal number of seeds and targets
min_n_chs = min(len(seeds), len(targets))
seeds = seeds[:min_n_chs]
targets = targets[:min_n_chs]

multivar_indices = (np.array(seeds), np.array(targets))
multivar_indices = (np.array([seeds]), np.array([targets]))

seed_names = [epochs.info['ch_names'][idx] for idx in seeds]
target_names = [epochs.info['ch_names'][idx] for idx in targets]
Expand Down Expand Up @@ -171,12 +165,11 @@
#
# Here, we average across the patterns in the 13-18 Hz range. Plotting the
# patterns shows that the greatest connectivity between the left and right
# hemispheres occurs at the posteromedial regions, based on the regions with
# the largest absolute values. Using the signs of the values, we can infer the
# existence of a dipole source in the central regions of the left hemisphere
# which may account for the connectivity contributions seen for the left
# posteromedial and frontolateral areas (represented on the plot as a green
# line).
# hemispheres occurs at the left and right posterior and left central regions,
# based on the areas with the largest absolute values. Using the signs of the
# values, we can infer the existence of a dipole source between the central and
# posterior regions of the left hemisphere accounting for the connectivity
# contributions (represented on the plot as a green line).

# %%

Expand All @@ -185,9 +178,9 @@
fband_idx = [mic.freqs.index(freq) for freq in fband]

# patterns have shape [seeds/targets x cons x channels x freqs (x times)]
patterns = np.array(mic.attrs["patterns"])
seed_pattern = patterns[0]
target_pattern = patterns[1]
patterns = np.array(mic.attrs['patterns'])
seed_pattern = patterns[0, :, :len(seeds)]
target_pattern = patterns[1, :, :len(targets)]
# average across frequencies
seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0]:fband_idx[1] + 1],
axis=1)
Expand Down Expand Up @@ -217,7 +210,7 @@

# plot the left hemisphere dipole example
axes[0].plot(
[-0.1, -0.05], [-0.075, -0.03], color='lime', linewidth=2,
[-0.01, -0.07], [-0.07, -0.03], color='lime', linewidth=2,
path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()])

plt.show()
Expand Down Expand Up @@ -268,7 +261,7 @@
axis.set_ylabel('Absolute connectivity (A.U.)')
fig.suptitle('Multivariate interaction measure')

n_channels = len(np.unique([*multivar_indices[0], *multivar_indices[1]]))
n_channels = len(seeds) + len(targets)
normalised_mim = mim.get_data()[0] / n_channels
print(f'Normalised MIM has a maximum value of {normalised_mim.max():.2f}')

Expand Down Expand Up @@ -296,7 +289,7 @@

# %%

indices = (np.array([*seeds, *targets]), np.array([*seeds, *targets]))
indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]]))
gim = spectral_connectivity_epochs(
epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None,
verbose=False)
Expand All @@ -307,7 +300,7 @@
axis.set_ylabel('Connectivity (A.U.)')
fig.suptitle('Global interaction measure')

n_channels = len(np.unique([*indices[0], *indices[1]]))
n_channels = len(seeds) + len(targets)
normalised_gim = gim.get_data()[0] / n_channels
print(f'Normalised GIM has a maximum value of {normalised_gim.max():.2f}')

Expand Down Expand Up @@ -369,9 +362,9 @@

# no. channels equal with and without projecting to rank subspace for patterns
assert (patterns[0, 0].shape[0] ==
np.array(mic_red.attrs["patterns"])[0, 0].shape[0])
np.array(mic_red.attrs['patterns'])[0, 0].shape[0])
assert (patterns[1, 0].shape[0] ==
np.array(mic_red.attrs["patterns"])[1, 0].shape[0])
np.array(mic_red.attrs['patterns'])[1, 0].shape[0])


###############################################################################
Expand All @@ -392,8 +385,8 @@

# gets the singular values of the data
s = np.linalg.svd(raw.get_data(), compute_uv=False)
# finds how many singular values are "close" to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria
# finds how many singular values are 'close' to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria


###############################################################################
Expand Down
3 changes: 2 additions & 1 deletion mne_connectivity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
from .io import read_connectivity
from .spectral import spectral_connectivity_time, spectral_connectivity_epochs
from .vector_ar import vector_auto_regression, select_order
from .utils import check_indices, degree, seed_target_indices
from .utils import (check_indices, check_multivariate_indices, degree,
seed_target_indices, seed_target_multivariate_indices)
9 changes: 8 additions & 1 deletion mne_connectivity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,14 @@ def _prepare_xarray(self, data, names, indices, n_nodes, method,

# set method, indices and n_nodes
if isinstance(indices, tuple):
new_indices = (list(indices[0]), list(indices[1]))
if all([isinstance(inds, np.ndarray) for inds in indices]):
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
# leave multivariate indices as arrays for easier indexing
if all([inds.ndim > 1 for inds in indices]):
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
new_indices = (indices[0], indices[1])
else:
new_indices = (list(indices[0]), list(indices[1]))
else:
new_indices = (list(indices[0]), list(indices[1]))
indices = new_indices
kwargs['method'] = method
kwargs['indices'] = indices
Expand Down
Loading
Loading