Skip to content

Commit

Permalink
added support for ragged connections
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jul 24, 2023
1 parent a7bb85b commit 73792c3
Show file tree
Hide file tree
Showing 12 changed files with 726 additions and 295 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ Post-processing on connectivity

degree
seed_target_indices
seed_target_multivariate_indices
check_indices
check_multivariate_indices
select_order

Visualization functions
Expand Down
32 changes: 11 additions & 21 deletions examples/granger_causality.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@

###############################################################################
# We will focus on connectivity between sensors over the parietal and occipital
# cortices, with 20 parietal sensors designated as group A, and 20 occipital
# cortices, with 20 parietal sensors designated as group A, and 22 occipital
# sensors designated as group B.

# %%
Expand All @@ -157,17 +157,8 @@
signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if
ch_info['ch_name'][2] == 'O']

# XXX: Currently ragged indices are not supported, so we only consider a single
# list of indices with an equal number of seeds and targets
min_n_chs = min(len(signals_a), len(signals_b))
signals_a = signals_a[:min_n_chs]
signals_b = signals_b[:min_n_chs]

indices_ab = (np.array(signals_a), np.array(signals_b)) # A => B
indices_ba = (np.array(signals_b), np.array(signals_a)) # B => A

signals_a_names = [epochs.info['ch_names'][idx] for idx in signals_a]
signals_b_names = [epochs.info['ch_names'][idx] for idx in signals_b]
indices_ab = (np.array([signals_a]), np.array([signals_b])) # A => B
indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A

# compute Granger causality
gc_ab = spectral_connectivity_epochs(
Expand All @@ -181,8 +172,8 @@

###############################################################################
# Plotting the results, we see that there is a flow of information from our
# parietal sensors (group A) to our occipital sensors (group B) with noticeable
# peaks at around 8, 18, and 26 Hz.
# parietal sensors (group A) to our occipital sensors (group B) with a
# noticeable peak at ~8 Hz, and smaller peaks at 18 and 26 Hz.

# %%

Expand All @@ -208,8 +199,7 @@
#
# Doing so, we see that the flow of information across the spectrum remains
# dominant from parietal to occipital sensors (indicated by the positive-valued
# Granger scores). However, the pattern of connectivity is altered, such as
# around 10 and 12 Hz where peaks of net information flow are now present.
# Granger scores), with similar peaks around 10, 18, and 26 Hz.

# %%

Expand Down Expand Up @@ -289,8 +279,8 @@
# Plotting the TRGC results reveals a very different picture compared to net
# GC. For one, there is now a dominance of information flow ~6 Hz from
# occipital to parietal sensors (indicated by the negative-valued Granger
# scores). Additionally, the peaks ~10 Hz are less dominant in the spectrum,
# with parietal to occipital information flow between 13-20 Hz being much more
# scores). Additionally, the peak ~10 Hz is less dominant in the spectrum, with
# parietal to occipital information flow between 13-20 Hz being much more
# prominent. The stark difference between net GC and TRGC results indicates
# that the net GC spectrum was contaminated by spurious connectivity resulting
# from source mixing or correlated noise in the recordings. Altogether, the use
Expand Down Expand Up @@ -366,8 +356,8 @@

# gets the singular values of the data
s = np.linalg.svd(raw.get_data(), compute_uv=False)
# finds how many singular values are "close" to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria
# finds how many singular values are 'close' to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria

###############################################################################
# Nonetheless, even in situations where you specify an appropriate rank, it is
Expand All @@ -387,7 +377,7 @@
try:
spectral_connectivity_epochs(
epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None,
gc_n_lags=20) # A => B
gc_n_lags=20, verbose=False) # A => B
print('Success!')
except RuntimeError as error:
print('\nCaught the following error:\n' + repr(error))
Expand Down
154 changes: 154 additions & 0 deletions examples/handling_ragged_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
=========================================================
Working with ragged indices for multivariate connectivity
=========================================================
This example demonstrates how multivariate connectivity involving different
numbers of seeds and targets can be handled in MNE-Connectivity.
"""

# Author: Thomas S. Binns <t.s.binns@outlook.com>
# License: BSD (3-clause)

# %%

import numpy as np

from mne_connectivity import spectral_connectivity_epochs

###############################################################################
# Background
# ----------
#
# With multivariate connectivity, interactions between multiple signals can be
# considered together, and the number of signals designated as seeds and
# targets does not have to be equal within or across connections. Issues can
# arise from this when storing information associated with connectivity in
# arrays, as the number of entries within each dimension can vary within and
# across connections depending on the number of seeds and targets. Such arrays
# are 'ragged', and support for ragged arrays is limited in NumPy to the
# ``object`` datatype. Not only is working with ragged arrays cumbersome,
# but saving arrays with ``dtype='object'`` is not supported by the h5netcdf
# engine used to save connectivity objects. The workaround used in
# MNE-Connectivity is to pad ragged arrays with some known values according to
# the largest number of entries in each dimension, such that there is an equal
# amount of information across and within connections for each dimension of the
# arrays.
#
# As an example, consider we have 5 channels and want to compute 2 connections:
# the first between channels in indices 0 and 1 with those in indices 2, 3,
# and 4; and the second between channels 0, 1, 2, and 3 with channel 4. The
# seed and target indices can be written as such::
#
# seeds = [[0, 1 ], [0, 1, 2, 3]]
# targets = [[2, 3, 4], [4 ]]
#
# The ``indices`` parameter passed to
# :func:`~mne_connectivity.spectral_connectivity_epochs` and
# :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of
# array-likes, meaning
# that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays.
# Examples of how ``indices`` can be formed are shown below::
#
# # tuple of lists
# ragged_indices = ([[0, 1 ], [0, 1, 2, 3]],
# [[2, 3, 4], [4 ]])
#
# # tuple of tuples
# ragged_indices = (((0, 1 ), (0, 1, 2, 3)),
# ((2, 3, 4), (4,)))
#
# # tuple of arrays
# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'),
# np.array([[2, 3, 4], [4 ]], dtype='object'))
#
# **N.B. When forming ragged arrays in NumPy, dtype='object' must be
# specified.**
#
# Just as for bivariate connectivity, the length of ``indices[0]`` and
# ``indices[1]`` is equal (i.e. the number of connections), however information
# about the multiple channel indices for each connection is stored in a nested
# array. Importantly, these indices are ragged, as the first connection will be
# computed between 2 seed and 3 target channels, and the second connection
# between 4 seed and 1 target channel. The connectivity functions will
# recognise the indices as being ragged, and pad them accordingly to make them
# easier to work with and compatible with the h5netcdf saving engine. The known
# value used to pad the arrays is ``-1``, an invalid channel index. The above
# indices would be padded to::
#
# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]),
# np.array([[2, 3, 4, -1], [4, -1, -1, -1]]))
#
# These indices are what is stored in the connectivity object, and are also the
# format of indices returned from the helper functions
# :func:`~mne_connectivity.check_multivariate_indices` and
# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also
# possible to pass the padded indices to the connectivity functions directly.
#
# For the connectivity results themselves, the methods available in
# MNE-Connectivity combine information across the different channels into a
# single (time-)frequency-resolved connectivity spectrum, regardless of the
# number of seed and target channels, so ragged arrays are not a concern here.
# However, the maximised imaginary part of coherency (MIC) method also returns
# spatial patterns of connectivity, which show the contribution of each channel
# to the dimensionality-reduced connectivity estimate (explained in more detail
# in :doc:`mic_mim`). Because these patterns are returned for each channel,
# their shape can vary depending on the number of seeds and targets in each
# connection, making them ragged. To avoid this, the patterns are padded along
# the channel axis with the known and invalid entry ``np.nan``, in line with
# that applied to ``indices``. Extracting only the valid spatial patterns from
# the connectivity object is trivial, as shown below:

# %%

# Worked example: compute MIC for two connections with different numbers of
# seed and target channels, then check how the returned indices and spatial
# patterns are padded and how to recover only the valid (unpadded) entries.

# create random data
data = np.random.randn(10, 5, 200) # epochs x channels x times
sfreq = 50
# two connections: (seeds [0, 1] -> targets [2, 3, 4]) and
# (seeds [0, 1, 2, 3] -> targets [4]); the channel counts differ, so the
# indices are ragged
ragged_indices = ([[0, 1], [0, 1, 2, 3]], # seeds
[[2, 3, 4], [4]]) # targets

# compute connectivity
con = spectral_connectivity_epochs(
data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30,
verbose=False)
# spatial patterns are read from the attributes of the connectivity object
patterns = np.array(con.attrs['patterns'])
# indices as stored on the connectivity object (checked below to be padded)
padded_indices = con.indices
# number of frequencies, taken from the last axis of the connectivity data
n_freqs = con.get_data().shape[-1]
# number of connections (one entry of seeds per connection)
n_cons = len(ragged_indices[0])
# largest number of channels in any single seed or target group; the shape
# assertion below checks that this is the size of the padded channel axis
max_n_chans = max(
[len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])])

# show that the padded indices entries are all -1
assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels
assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channel
assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels
assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels

# patterns have shape [seeds/targets x cons x max channels x freqs (x times)]
assert patterns.shape == (2, n_cons, max_n_chans, n_freqs)

# show that the padded patterns entries are all np.nan
# (valid channels come first along the channel axis, padding follows)
assert np.all(np.isnan(patterns[0, 0, 2:])) # 2 padded channels
assert np.all(np.isnan(patterns[1, 0, 3:])) # 1 padded channel
assert not np.any(np.isnan(patterns[0, 1])) # 0 padded channels
assert np.all(np.isnan(patterns[1, 1, 1:])) # 3 padded channels

# extract patterns for first connection using the ragged indices
# (the number of valid channels is the length of the original index lists)
seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])]
target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])]

# extract patterns for second connection using the padded indices (pad = -1)
# (the number of valid channels is the count of entries that are not -1)
seed_patterns_con2 = (
patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)])
target_patterns_con2 = (
patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)])

# show that shapes of patterns are correct
assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1)
assert target_patterns_con1.shape == (3, n_freqs) # channels (2, 3, 4)
assert seed_patterns_con2.shape == (4, n_freqs) # channels (0, 1, 2, 3)
assert target_patterns_con2.shape == (1, n_freqs) # channels (4)

print('Assertions completed successfully!')

# %%
43 changes: 18 additions & 25 deletions examples/mic_mim.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
###############################################################################
# We will focus on connectivity between sensors over the left and right
# hemispheres, with 75 sensors in the left hemisphere designated as seeds, and
# 75 sensors in the right hemisphere designated as targets.
# 76 sensors in the right hemisphere designated as targets.

# %%

Expand All @@ -81,13 +81,7 @@
targets = [idx for idx, ch_info in enumerate(epochs.info['chs']) if
ch_info['loc'][0] > 0]

# XXX: Currently ragged indices are not supported, so we only consider a single
# list of indices with an equal number of seeds and targets
min_n_chs = min(len(seeds), len(targets))
seeds = seeds[:min_n_chs]
targets = targets[:min_n_chs]

multivar_indices = (np.array(seeds), np.array(targets))
multivar_indices = (np.array([seeds]), np.array([targets]))

seed_names = [epochs.info['ch_names'][idx] for idx in seeds]
target_names = [epochs.info['ch_names'][idx] for idx in targets]
Expand Down Expand Up @@ -171,12 +165,11 @@
#
# Here, we average across the patterns in the 13-18 Hz range. Plotting the
# patterns shows that the greatest connectivity between the left and right
# hemispheres occurs at the posteromedial regions, based on the regions with
# the largest absolute values. Using the signs of the values, we can infer the
# existence of a dipole source in the central regions of the left hemisphere
# which may account for the connectivity contributions seen for the left
# posteromedial and frontolateral areas (represented on the plot as a green
# line).
# hemispheres occurs at the left and right posterior and left central regions,
# based on the areas with the largest absolute values. Using the signs of the
# values, we can infer the existence of a dipole source between the central and
# posterior regions of the left hemisphere accounting for the connectivity
# contributions (represented on the plot as a green line).

# %%

Expand All @@ -185,9 +178,9 @@
fband_idx = [mic.freqs.index(freq) for freq in fband]

# patterns have shape [seeds/targets x cons x channels x freqs (x times)]
patterns = np.array(mic.attrs["patterns"])
seed_pattern = patterns[0]
target_pattern = patterns[1]
patterns = np.array(mic.attrs['patterns'])
seed_pattern = patterns[0, :, :len(seeds)]
target_pattern = patterns[1, :, :len(targets)]
# average across frequencies
seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0]:fband_idx[1] + 1],
axis=1)
Expand Down Expand Up @@ -217,7 +210,7 @@

# plot the left hemisphere dipole example
axes[0].plot(
[-0.1, -0.05], [-0.075, -0.03], color='lime', linewidth=2,
[-0.01, -0.07], [-0.07, -0.03], color='lime', linewidth=2,
path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()])

plt.show()
Expand Down Expand Up @@ -268,7 +261,7 @@
axis.set_ylabel('Absolute connectivity (A.U.)')
fig.suptitle('Multivariate interaction measure')

n_channels = len(np.unique([*multivar_indices[0], *multivar_indices[1]]))
n_channels = len(seeds) + len(targets)
normalised_mim = mim.get_data()[0] / n_channels
print(f'Normalised MIM has a maximum value of {normalised_mim.max():.2f}')

Expand Down Expand Up @@ -296,7 +289,7 @@

# %%

indices = (np.array([*seeds, *targets]), np.array([*seeds, *targets]))
indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]]))
gim = spectral_connectivity_epochs(
epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None,
verbose=False)
Expand All @@ -307,7 +300,7 @@
axis.set_ylabel('Connectivity (A.U.)')
fig.suptitle('Global interaction measure')

n_channels = len(np.unique([*indices[0], *indices[1]]))
n_channels = len(seeds) + len(targets)
normalised_gim = gim.get_data()[0] / n_channels
print(f'Normalised GIM has a maximum value of {normalised_gim.max():.2f}')

Expand Down Expand Up @@ -369,9 +362,9 @@

# no. channels equal with and without projecting to rank subspace for patterns
assert (patterns[0, 0].shape[0] ==
np.array(mic_red.attrs["patterns"])[0, 0].shape[0])
np.array(mic_red.attrs['patterns'])[0, 0].shape[0])
assert (patterns[1, 0].shape[0] ==
np.array(mic_red.attrs["patterns"])[1, 0].shape[0])
np.array(mic_red.attrs['patterns'])[1, 0].shape[0])


###############################################################################
Expand All @@ -392,8 +385,8 @@

# gets the singular values of the data
s = np.linalg.svd(raw.get_data(), compute_uv=False)
# finds how many singular values are "close" to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria
# finds how many singular values are 'close' to the largest singular value
rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria


###############################################################################
Expand Down
3 changes: 2 additions & 1 deletion mne_connectivity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
from .io import read_connectivity
from .spectral import spectral_connectivity_time, spectral_connectivity_epochs
from .vector_ar import vector_auto_regression, select_order
from .utils import check_indices, degree, seed_target_indices
from .utils import (check_indices, check_multivariate_indices, degree,
seed_target_indices, seed_target_multivariate_indices)
9 changes: 8 additions & 1 deletion mne_connectivity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,14 @@ def _prepare_xarray(self, data, names, indices, n_nodes, method,

# set method, indices and n_nodes
if isinstance(indices, tuple):
new_indices = (list(indices[0]), list(indices[1]))
if all([isinstance(inds, np.ndarray) for inds in indices]):
# leave multivariate indices as arrays for easier indexing
if all([inds.ndim > 1 for inds in indices]):
new_indices = (indices[0], indices[1])
else:
new_indices = (list(indices[0]), list(indices[1]))
else:
new_indices = (list(indices[0]), list(indices[1]))
indices = new_indices
kwargs['method'] = method
kwargs['indices'] = indices
Expand Down
Loading

0 comments on commit 73792c3

Please sign in to comment.