From 73792c3ee6c4a2cf2b2d963f47701481b5a19f39 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 24 Jul 2023 11:19:05 +0200 Subject: [PATCH] added support for ragged connections --- doc/api.rst | 2 + examples/granger_causality.py | 32 +- examples/handling_ragged_arrays.py | 154 ++++++++++ examples/mic_mim.py | 43 ++- mne_connectivity/__init__.py | 3 +- mne_connectivity/base.py | 9 +- mne_connectivity/spectral/epochs.py | 273 ++++++++++-------- .../spectral/tests/test_spectral.py | 167 +++++++---- mne_connectivity/spectral/time.py | 167 ++++++----- mne_connectivity/tests/test_utils.py | 73 ++++- mne_connectivity/utils/__init__.py | 3 +- mne_connectivity/utils/utils.py | 95 ++++++ 12 files changed, 726 insertions(+), 295 deletions(-) create mode 100644 examples/handling_ragged_arrays.py diff --git a/doc/api.rst b/doc/api.rst index 26ef14b6..c91f9c02 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -73,7 +73,9 @@ Post-processing on connectivity degree seed_target_indices + seed_target_multivariate_indices check_indices + check_multivariate_indices select_order Visualization functions diff --git a/examples/granger_causality.py b/examples/granger_causality.py index f5d8316d..64a657db 100644 --- a/examples/granger_causality.py +++ b/examples/granger_causality.py @@ -145,7 +145,7 @@ ############################################################################### # We will focus on connectivity between sensors over the parietal and occipital -# cortices, with 20 parietal sensors designated as group A, and 20 occipital +# cortices, with 20 parietal sensors designated as group A, and 22 occipital # sensors designated as group B. # %% @@ -157,17 +157,8 @@ signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if ch_info['ch_name'][2] == 'O'] -# XXX: Currently ragged indices are not supported, so we only consider a single -# list of indices with an equal number of seeds and targets -min_n_chs = min(len(signals_a), len(signals_b)) -signals_a = signals_a[:min_n_chs] -signals_b = signals_b[:min_n_chs] - -indices_ab = (np.array(signals_a), np.array(signals_b)) # A => B -indices_ba = (np.array(signals_b), np.array(signals_a)) # B => A - -signals_a_names = [epochs.info['ch_names'][idx] for idx in signals_a] -signals_b_names = [epochs.info['ch_names'][idx] for idx in signals_b] +indices_ab = (np.array([signals_a]), np.array([signals_b])) # A => B +indices_ba = (np.array([signals_b]), np.array([signals_a])) # B => A # compute Granger causality gc_ab = spectral_connectivity_epochs( @@ -181,8 +172,8 @@ ############################################################################### # Plotting the results, we see that there is a flow of information from our -# parietal sensors (group A) to our occipital sensors (group B) with noticeable -# peaks at around 8, 18, and 26 Hz. +# parietal sensors (group A) to our occipital sensors (group B) with a +# noticeable peak at ~8 Hz, and smaller peaks at 18 and 26 Hz. # %% @@ -208,8 +199,7 @@ # # Doing so, we see that the flow of information across the spectrum remains # dominant from parietal to occipital sensors (indicated by the positive-valued -# Granger scores). However, the pattern of connectivity is altered, such as -# around 10 and 12 Hz where peaks of net information flow are now present. +# Granger scores), with similar peaks around 10, 18, and 26 Hz. # %% @@ -289,8 +279,8 @@ # Plotting the TRGC results, reveals a very different picture compared to net # GC. 
For one, there is now a dominance of information flow ~6 Hz from # occipital to parietal sensors (indicated by the negative-valued Granger -# scores). Additionally, the peaks ~10 Hz are less dominant in the spectrum, -# with parietal to occipital information flow between 13-20 Hz being much more +# scores). Additionally, the peak ~10 Hz is less dominant in the spectrum, with +# parietal to occipital information flow between 13-20 Hz being much more # prominent. The stark difference between net GC and TRGC results indicates # that the net GC spectrum was contaminated by spurious connectivity resulting # from source mixing or correlated noise in the recordings. Altogether, the use @@ -366,8 +356,8 @@ # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) -# finds how many singular values are "close" to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria +# finds how many singular values are 'close' to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria ############################################################################### # Nonethless, even in situations where you specify an appropriate rank, it is @@ -387,7 +377,7 @@ try: spectral_connectivity_epochs( epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, - gc_n_lags=20) # A => B + gc_n_lags=20, verbose=False) # A => B print('Success!') except RuntimeError as error: print('\nCaught the following error:\n' + repr(error)) diff --git a/examples/handling_ragged_arrays.py b/examples/handling_ragged_arrays.py new file mode 100644 index 00000000..67f06a4e --- /dev/null +++ b/examples/handling_ragged_arrays.py @@ -0,0 +1,154 @@ +""" +========================================================= +Working with ragged indices for multivariate connectivity +========================================================= + +This example demonstrates how multivariate connectivity involving different +numbers of seeds and targets can be handled in MNE-Connectivity. +""" + +# Author: Thomas S. Binns +# License: BSD (3-clause) + +# %% + +import numpy as np + +from mne_connectivity import spectral_connectivity_epochs + +############################################################################### +# Background +# ---------- +# +# With multivariate connectivity, interactions between multiple signals can be +# considered together, and the number of signals designated as seeds and +# targets does not have to be equal within or across connections. Issues can +# arise from this when storing information associated with connectivity in +# arrays, as the number of entries within each dimension can vary within and +# across connections depending on the number of seeds and targets. Such arrays +# are 'ragged', and support for ragged arrays is limited in NumPy to the +# ``object`` datatype. Not only is working with ragged arrays is cumbersome, +# but saving arrays with ``dtype='object'`` is not supported by the h5netcdf +# engine used to save connectivity objects. The workaround used in +# MNE-Connectivity is to pad ragged arrays with some known values according to +# the largest number of entries in each dimension, such that there is an equal +# amount of information across and within connections for each dimension of the +# arrays. 
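+#
+# As a rough illustration of this padding idea (a conceptual sketch only, not
+# the exact code used internally; the variable names here are illustrative),
+# a ragged list of per-connection channel indices could be padded to a
+# rectangular array using ``-1`` -- the invalid channel index used for
+# padding indices, as described below -- like so::
+#
+#     ragged = [[0, 1], [0, 1, 2, 3]]
+#     max_len = max(len(entry) for entry in ragged)
+#     padded = np.array([entry + [-1] * (max_len - len(entry))
+#                        for entry in ragged])
+#     # padded == array([[ 0,  1, -1, -1],
+#     #                  [ 0,  1,  2,  3]])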
+# +# As an example, consider we have 5 channels and want to compute 2 connections: +# the first between channels in indices 0 and 1 with those in indices 2, 3, +# and 4; and the second between channels 0, 1, 2, and 3 with channel 4. The +# seed and target indices can be written as such:: +# +# seeds = [[0, 1 ], [0, 1, 2, 3]] +# targets = [[2, 3, 4], [4 ]] +# +# The ``indices`` parameter passed to +# :func:`~mne_connectivity.spectral_connectivity_epochs` and +# :func:`~mne_connectivity.spectral_connectivity_time` must be a tuple of +# array-likes, meaning +# that the indices can be passed as a tuple of: lists; tuples; or NumPy arrays. +# Examples of how ``indices`` can be formed are shown below:: +# +# # tuple of lists +# ragged_indices = ([[0, 1 ], [0, 1, 2, 3]], +# [[2, 3, 4], [4 ]]) +# +# # tuple of tuples +# ragged_indices = (((0, 1 ), (0, 1, 2, 3)), +# ((2, 3, 4), (4 ))) +# +# # tuple of arrays +# ragged_indices = (np.array([[0, 1 ], [0, 1, 2, 3]], dtype='object'), +# np.array([[2, 3, 4], [4 ]], dtype='object')) +# +# **N.B. Note that when forming ragged arrays in NumPy, dtype='object' must be +# specified.** +# +# Just as for bivariate connectivity, the length of ``indices[0]`` and +# ``indices[1]`` is equal (i.e. the number of connections), however information +# about the multiple channel indices for each connection is stored in a nested +# array. Importantly, these indices are ragged, as the first connection will be +# computed between 2 seed and 3 target channels, and the second connection +# between 4 seed and 1 target channel. The connectivity functions will +# recognise the indices as being ragged, and pad them accordingly to make them +# easier to work with and compatible with the h5netcdf saving engine. The known +# value used to pad the arrays is ``-1``, an invalid channel index. The above +# indices would be padded to:: +# +# padded_indices = (np.array([[0, 1, -1, -1], [0, 1, 2, 3]]), +# np.array([[2, 3, 4, -1], [4, -1, -1, -1]])) +# +# These indices are what is stored in the connectivity object, and is also the +# format of indices returned from the helper functions +# :func:`~mne_connectivity.check_multivariate_indices` and +# :func:`~mne_connectivity.seed_target_multivariate_indices`. It is also +# possible to pass the padded indices to the connectivity functions directly. +# +# For the connectivity results themselves, the methods available in +# MNE-Connectivity combine information across the different channels into a +# single (time-)frequency-resolved connectivity spectrum, regardless of the +# number of seed and target channels, so ragged arrays are not a concern here. +# However, the maximised imaginary part of coherency (MIC) method also returns +# spatial patterns of connectivity, which show the contribution of each channel +# to the dimensionality-reduced connectivity estimate (explained in more detail +# in :doc:`mic_mim`). Because these patterns are returned for each channel, +# their shape can vary depending on the number of seeds and targets in each +# connection, making them ragged. To avoid this, the patterns are padded along +# the channel axis with the known and invalid entry ``np.nan``, in line with +# that applied to ``indices``. 
Extracting only the valid spatial patterns from +# the connectivity object is trivial, as shown below: + +# %% + +# create random data +data = np.random.randn(10, 5, 200) # epochs x channels x times +sfreq = 50 +ragged_indices = ([[0, 1], [0, 1, 2, 3]], # seeds + [[2, 3, 4], [4]]) # targets + +# compute connectivity +con = spectral_connectivity_epochs( + data, method='mic', indices=ragged_indices, sfreq=sfreq, fmin=10, fmax=30, + verbose=False) +patterns = np.array(con.attrs['patterns']) +padded_indices = con.indices +n_freqs = con.get_data().shape[-1] +n_cons = len(ragged_indices[0]) +max_n_chans = max( + [len(inds) for inds in ([*ragged_indices[0], *ragged_indices[1]])]) + +# show that the padded indices entries are all -1 +assert np.count_nonzero(padded_indices[0][0] == -1) == 2 # 2 padded channels +assert np.count_nonzero(padded_indices[1][0] == -1) == 1 # 1 padded channels +assert np.count_nonzero(padded_indices[0][1] == -1) == 0 # 0 padded channels +assert np.count_nonzero(padded_indices[1][1] == -1) == 3 # 3 padded channels + +# patterns have shape [seeds/targets x cons x max channels x freqs (x times)] +assert patterns.shape == (2, n_cons, max_n_chans, n_freqs) + +# show that the padded patterns entries are all np.nan +assert np.all(np.isnan(patterns[0, 0, 2:])) # 2 padded channels +assert np.all(np.isnan(patterns[1, 0, 3:])) # 1 padded channels +assert not np.any(np.isnan(patterns[0, 1])) # 0 padded channels +assert np.all(np.isnan(patterns[1, 1, 1:])) # 3 padded channels + +# extract patterns for first connection using the ragged indices +seed_patterns_con1 = patterns[0, 0, :len(ragged_indices[0][0])] +target_patterns_con1 = patterns[1, 0, :len(ragged_indices[1][0])] + +# extract patterns for second connection using the padded indices (pad = -1) +seed_patterns_con2 = ( + patterns[0, 1, :np.count_nonzero(padded_indices[0][1] != -1)]) +target_patterns_con2 = ( + patterns[1, 1, :np.count_nonzero(padded_indices[1][1] != -1)]) + +# show that shapes of patterns are correct +assert seed_patterns_con1.shape == (2, n_freqs) # channels (0, 1) +assert target_patterns_con1.shape == (3, n_freqs) # channels (2, 3, 4) +assert seed_patterns_con2.shape == (4, n_freqs) # channels (0, 1, 2, 3) +assert target_patterns_con2.shape == (1, n_freqs) # channels (4) + +print('Assertions completed successfully!') + +# %% diff --git a/examples/mic_mim.py b/examples/mic_mim.py index 179ea620..87111586 100644 --- a/examples/mic_mim.py +++ b/examples/mic_mim.py @@ -70,7 +70,7 @@ ############################################################################### # We will focus on connectivity between sensors over the left and right # hemispheres, with 75 sensors in the left hemisphere designated as seeds, and -# 75 sensors in the right hemisphere designated as targets. +# 76 sensors in the right hemisphere designated as targets. # %% @@ -81,13 +81,7 @@ targets = [idx for idx, ch_info in enumerate(epochs.info['chs']) if ch_info['loc'][0] > 0] -# XXX: Currently ragged indices are not supported, so we only consider a single -# list of indices with an equal number of seeds and targets -min_n_chs = min(len(seeds), len(targets)) -seeds = seeds[:min_n_chs] -targets = targets[:min_n_chs] - -multivar_indices = (np.array(seeds), np.array(targets)) +multivar_indices = (np.array([seeds]), np.array([targets])) seed_names = [epochs.info['ch_names'][idx] for idx in seeds] target_names = [epochs.info['ch_names'][idx] for idx in targets] @@ -171,12 +165,11 @@ # # Here, we average across the patterns in the 13-18 Hz range. 
Plotting the # patterns shows that the greatest connectivity between the left and right -# hemispheres occurs at the posteromedial regions, based on the regions with -# the largest absolute values. Using the signs of the values, we can infer the -# existence of a dipole source in the central regions of the left hemisphere -# which may account for the connectivity contributions seen for the left -# posteromedial and frontolateral areas (represented on the plot as a green -# line). +# hemispheres occurs at the left and right posterior and left central regions, +# based on the areas with the largest absolute values. Using the signs of the +# values, we can infer the existence of a dipole source between the central and +# posterior regions of the left hemisphere accounting for the connectivity +# contributions (represented on the plot as a green line). # %% @@ -185,9 +178,9 @@ fband_idx = [mic.freqs.index(freq) for freq in fband] # patterns have shape [seeds/targets x cons x channels x freqs (x times)] -patterns = np.array(mic.attrs["patterns"]) -seed_pattern = patterns[0] -target_pattern = patterns[1] +patterns = np.array(mic.attrs['patterns']) +seed_pattern = patterns[0, :, :len(seeds)] +target_pattern = patterns[1, :, :len(targets)] # average across frequencies seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0]:fband_idx[1] + 1], axis=1) @@ -217,7 +210,7 @@ # plot the left hemisphere dipole example axes[0].plot( - [-0.1, -0.05], [-0.075, -0.03], color='lime', linewidth=2, + [-0.01, -0.07], [-0.07, -0.03], color='lime', linewidth=2, path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()]) plt.show() @@ -268,7 +261,7 @@ axis.set_ylabel('Absolute connectivity (A.U.)') fig.suptitle('Multivariate interaction measure') -n_channels = len(np.unique([*multivar_indices[0], *multivar_indices[1]])) +n_channels = len(seeds) + len(targets) normalised_mim = mim.get_data()[0] / n_channels print(f'Normalised MIM has a maximum value of {normalised_mim.max():.2f}') @@ -296,7 +289,7 @@ # %% -indices = (np.array([*seeds, *targets]), np.array([*seeds, *targets])) +indices = (np.array([[*seeds, *targets]]), np.array([[*seeds, *targets]])) gim = spectral_connectivity_epochs( epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, verbose=False) @@ -307,7 +300,7 @@ axis.set_ylabel('Connectivity (A.U.)') fig.suptitle('Global interaction measure') -n_channels = len(np.unique([*indices[0], *indices[1]])) +n_channels = len(seeds) + len(targets) normalised_gim = gim.get_data()[0] / n_channels print(f'Normalised GIM has a maximum value of {normalised_gim.max():.2f}') @@ -369,9 +362,9 @@ # no. 
channels equal with and without projecting to rank subspace for patterns assert (patterns[0, 0].shape[0] == - np.array(mic_red.attrs["patterns"])[0, 0].shape[0]) + np.array(mic_red.attrs['patterns'])[0, 0].shape[0]) assert (patterns[1, 0].shape[0] == - np.array(mic_red.attrs["patterns"])[1, 0].shape[0]) + np.array(mic_red.attrs['patterns'])[1, 0].shape[0]) ############################################################################### @@ -392,8 +385,8 @@ # gets the singular values of the data s = np.linalg.svd(raw.get_data(), compute_uv=False) -# finds how many singular values are "close" to the largest singular value -rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria +# finds how many singular values are 'close' to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the 'closeness' criteria ############################################################################### diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 57aeff7f..c2f03a6c 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -17,4 +17,5 @@ from .io import read_connectivity from .spectral import spectral_connectivity_time, spectral_connectivity_epochs from .vector_ar import vector_auto_regression, select_order -from .utils import check_indices, degree, seed_target_indices +from .utils import (check_indices, check_multivariate_indices, degree, + seed_target_indices, seed_target_multivariate_indices) diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index 88951529..672448d1 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -483,7 +483,14 @@ def _prepare_xarray(self, data, names, indices, n_nodes, method, # set method, indices and n_nodes if isinstance(indices, tuple): - new_indices = (list(indices[0]), list(indices[1])) + if all([isinstance(inds, np.ndarray) for inds in indices]): + # leave multivariate indices as arrays for easier indexing + if all([inds.ndim > 1 for inds in indices]): + new_indices = (indices[0], indices[1]) + else: + new_indices = (list(indices[0]), list(indices[1])) + else: + new_indices = (list(indices[0]), list(indices[1])) indices = new_indices kwargs['method'] = method kwargs['indices'] = indices diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index eb766f06..c8499341 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -24,7 +24,7 @@ ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) -from ..utils import fill_doc, check_indices +from ..utils import fill_doc, check_indices, check_multivariate_indices def _compute_freqs(n_times, sfreq, cwt_freqs, mode): @@ -92,40 +92,40 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, times = times_in[tmin_idx:tmax_idx] n_times = len(times) + if any(this_method in _multivariate_methods for this_method in method): + multivariate_con = True + else: + multivariate_con = False + if indices is None: - if any(this_method in _multivariate_methods for this_method in method): + if multivariate_con: if any(this_method in _gc_methods for this_method in method): raise ValueError( 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') else: logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int), - np.arange(n_signals, dtype=int)) + 
indices_use = (np.arange(n_signals, dtype=int)[np.newaxis, :], + np.arange(n_signals, dtype=int)[np.newaxis, :]) else: logger.info('only using indices for lower-triangular matrix') # only compute r for lower-triangular region indices_use = np.tril_indices(n_signals, -1) else: - if any(this_method in _gc_methods for this_method in method): - if set(indices[0]).intersection(indices[1]): - raise ValueError( - 'seed and target indices must not intersect when computing' - 'Granger causality') - indices_use = check_indices(indices) + if multivariate_con: + indices_use = check_multivariate_indices(indices) # pad with -1 + if any(this_method in _gc_methods for this_method in method): + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + else: + indices_use = check_indices(indices) # number of connectivities to compute - if any(this_method in _multivariate_methods for this_method in method): - if ( - len(np.unique(indices_use[0])) != len(indices_use[0]) or - len(np.unique(indices_use[1])) != len(indices_use[1]) - ): - raise ValueError( - 'seed and target indices cannot contain repeated channels for ' - 'multivariate connectivity') - n_cons = 1 # UNTIL RAGGED ARRAYS SUPPORTED - else: - n_cons = len(indices_use[0]) + n_cons = len(indices_use[0]) logger.info(' computing connectivity for %d connections' % n_cons) @@ -189,6 +189,46 @@ def _prepare_connectivity(epoch_block, times_in, tmin, tmax, n_signals, indices_use, warn_times) +def _check_rank_input(rank, data, indices): + """Check the rank argument is appropriate and compute rank if missing.""" + sv_tol = 1e-10 # tolerance for non-zero singular val (rel to largest) + if rank is None: + rank = np.zeros((2, len(indices[0])), dtype=int) + + if isinstance(data, BaseEpochs): + data_arr = data.get_data() + else: + data_arr = data + + for group_i in range(2): # seeds and targets + for con_i, con_idcs in enumerate(indices[group_i]): + con_idcs = con_idcs[con_idcs != -1] # -1 is padded value + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + rank[group_i][con_i] = np.min( + [np.count_nonzero(epoch >= epoch[0] * sv_tol) + for epoch in s]) + + logger.info('Estimated data ranks:') + con_i = 1 + for seed_rank, target_rank in zip(rank[0], rank[1]): + logger.info(' connection %i - seeds (%i); targets (%i)' + % (con_i, seed_rank, target_rank, )) + con_i += 1 + rank = tuple((np.array(rank[0]), np.array(rank[1]))) + + else: + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], rank[0], rank[1]): + if not (0 < seed_rank <= len(seed_idcs) and + 0 < target_rank <= len(target_idcs)): + raise ValueError( + 'ranks for seeds and targets must be > 0 and <= the ' + 'number of channels in the seeds and targets, ' + 'respectively, for each connection') + + return rank + + def _assemble_spectral_params(mode, n_times, mt_adaptive, mt_bandwidth, sfreq, mt_low_bias, cwt_n_cycles, cwt_freqs, freqs, freq_mask): @@ -422,22 +462,23 @@ def compute_con(self, indices, ranks, n_epochs=1): if self.name == 'MIC': self.patterns = np.full( - (2, self.n_cons, len(indices[0]), self.n_freqs, n_times), + (2, self.n_cons, indices[0].shape[1], self.n_freqs, n_times), np.nan) con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - [indices[0]], [indices[1]], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1]): 
self._log_connection_number(con_i) - n_seeds = len(seed_idcs) + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] # Eqs. 32 & 33 C_bar, U_bar_aa, U_bar_bb = self._csd_svd( - C, n_seeds, seed_rank, target_rank) + C, seed_idcs, seed_rank, target_rank) # Eqs. 3 & 4 E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3]) @@ -452,10 +493,11 @@ def compute_con(self, indices, ranks, n_epochs=1): self.reshape_results() - def _csd_svd(self, csd, n_seeds, seed_rank, target_rank): + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): """Dimensionality reduction of CSD with SVD.""" n_times = csd.shape[0] - n_targets = csd.shape[2] - n_seeds + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds C_aa = csd[..., :n_seeds, :n_seeds] C_ab = csd[..., :n_seeds, n_seeds:] @@ -505,8 +547,9 @@ def _compute_e(self, csd, n_seeds): for block_i in ProgressBar( range(self.n_steps), mesg="frequency blocks"): freqs = self._get_block_indices(block_i, self.n_freqs) - parallel(parallel_compute_t( + T[:, freqs] = np.array(parallel(parallel_compute_t( C_r[:, f], T[:, f], n_seeds) for f in freqs) + ).transpose(1, 0, 2, 3) if not np.isreal(T).all() or not np.isfinite(T).all(): raise RuntimeError( @@ -526,6 +569,7 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, U_bar_bb, con_i): """Compute MIC and the associated spatial patterns.""" n_seeds = len(seed_idcs) + n_targets = len(target_idcs) times = np.arange(n_times) freqs = np.arange(self.n_freqs) @@ -564,12 +608,12 @@ def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa, beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)] # Eq. 46 (seed spatial patterns) - self.patterns[0, con_i] = (np.matmul( + self.patterns[0, con_i, :n_seeds] = (np.matmul( np.real(C[..., :n_seeds, :n_seeds]), np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T # Eq. 47 (target spatial patterns) - self.patterns[1, con_i] = (np.matmul( + self.patterns[1, con_i, :n_targets] = (np.matmul( np.real(C[..., n_seeds:, n_seeds:]), np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T @@ -586,7 +630,7 @@ def _compute_mim(self, E, seed_idcs, target_idcs, con_i): E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T # Eq. 
15 - if all(np.unique(seed_idcs) == np.unique(target_idcs)): + if np.all(np.unique(seed_idcs) == np.unique(target_idcs)): self.con_scores[con_i] *= 0.5 def reshape_results(self): @@ -598,7 +642,7 @@ def reshape_results(self): def _mic_mim_compute_t(C, T, n_seeds): - """Compute T in place for a single frequency (used for MIC and MIM).""" + """Compute T for a single frequency (used for MIC and MIM).""" for time_i in range(C.shape[0]): T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power( C[time_i, :n_seeds, :n_seeds], -0.5 @@ -607,6 +651,8 @@ def _mic_mim_compute_t(C, T, n_seeds): C[time_i, n_seeds:, n_seeds:], -0.5 ) + return T + class _MICEst(_MultivariateCohEstBase): """Multivariate imaginary part of coherency (MIC) estimator.""" @@ -889,17 +935,16 @@ def compute_con(self, indices, ranks, n_epochs=1): con_i = 0 for seed_idcs, target_idcs, seed_rank, target_rank in zip( - [indices[0]], [indices[1]], ranks[0], ranks[1]): + indices[0], indices[1], ranks[0], ranks[1]): self._log_connection_number(con_i) + seed_idcs = seed_idcs[seed_idcs != -1] + target_idcs = target_idcs[target_idcs != -1] con_idcs = [*seed_idcs, *target_idcs] - C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - con_seeds = np.arange(len(seed_idcs)) - con_targets = np.arange(len(target_idcs)) + len(seed_idcs) + C = csd[np.ix_(times, freqs, con_idcs, con_idcs)] - C_bar = self._csd_svd( - C, con_seeds, con_targets, seed_rank, target_rank) + C_bar = self._csd_svd(C, seed_idcs, seed_rank, target_rank) n_signals = seed_rank + target_rank con_seeds = np.arange(seed_rank) con_targets = np.arange(target_rank) + seed_rank @@ -921,13 +966,13 @@ def compute_con(self, indices, ranks, n_epochs=1): self.reshape_results() - def _csd_svd(self, csd, seeds, targets, seed_rank, target_rank): + def _csd_svd(self, csd, seed_idcs, seed_rank, target_rank): """Dimensionality reduction of CSD with SVD on the covariance.""" # sum over times and epochs to get cov. from CSD cov = csd.sum(axis=(0, 1)) - n_seeds = len(seeds) - n_targets = len(targets) + n_seeds = len(seed_idcs) + n_targets = csd.shape[3] - n_seeds cov_aa = cov[:n_seeds, :n_seeds] cov_bb = cov[n_seeds:, n_seeds:] @@ -1202,7 +1247,7 @@ def _gc_compute_H(A, C, K, z_k, I_n, I_m): See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI: 10.1103/PhysRevE.91.040101, Eq. 4. """ - from scipy import linalg # is this necessary??? + from scipy import linalg # XXX: is this necessary??? 
H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128) for t in range(A.shape[0]): H[t] = I_n + np.matmul( @@ -1231,16 +1276,15 @@ class _GCTREst(_GCEstBase): def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, method, mode, window_fun, eigvals, wavelets, - freq_mask, mt_adaptive, idx_map, block_size, - psd, accumulate_psd, con_method_types, - con_methods, n_signals, n_signals_use, - n_times, gc_n_lags, accumulate_inplace=True): + freq_mask, mt_adaptive, idx_map, n_cons, + block_size, psd, accumulate_psd, + con_method_types, con_methods, n_signals, + n_signals_use, n_times, gc_n_lags, + accumulate_inplace=True): """Estimate connectivity for one epoch (see spectral_connectivity).""" if any(this_method in _multivariate_methods for this_method in method): - n_cons = 1 # UNTIL RAGGED ARRAYS SUPPORTED n_con_signals = n_signals_use ** 2 else: - n_cons = len(idx_map[0]) n_con_signals = n_cons if wavelets is not None: @@ -1513,10 +1557,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. indices : tuple of array | None Two arrays with indices of connections for which to compute - connectivity. If a multivariate method is called, the indices are for a - single connection between all seeds and all targets. If None, all - connections are computed, unless a Granger causality method is called, - in which case an error is raised. + connectivity. If a multivariate method is called, each array for the + seeds and targets should contain a nested array of channel indices for + the individual connections. If None, connections between all channels + are computed, unless a Granger causality method is called, in which + case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. @@ -1582,14 +1627,13 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, con : array | list of array Computed connectivity measure(s). Either an instance of ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of each connectivity dataset is either - (n_signals ** 2, n_freqs) mode: 'multitaper' or 'fourier' - (n_signals ** 2, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is None, or - (n_con, n_freqs) mode: 'multitaper' or 'fourier' - (n_con, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is specified and "n_con = len(indices[0])". If a - multivariate method is called "n_con = 1" even if "indices" is None. + The shape of each connectivity dataset is either: + (n_cons, n_freqs) mode: 'multitaper' or 'fourier'; or + (n_cons, n_freqs, n_times) mode: 'cwt_morlet'. When "indices" is None + and a bivariate method is called, "n_cons = n_signals ** 2", or if a + multivariate method is called "n_cons = 1". When "indices" is + specified, "n_con = len(indices[0])" for bivariate and multivariate + methods. See Also -------- @@ -1635,13 +1679,19 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, are in the same order as defined indices. For multivariate methods, this is handled differently. If "indices" is - None, connectivity between all signals will attempt to be computed (this is - not possible if a Granger causality method is called). If "indices" is - specified, the seeds and targets are treated as a single connection. 
For - example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5, - one would use the same approach as above, however the signals would all be - considered for a single connection and the connectivity scores would have - the shape (1, n_freqs). + None, connectivity between all signals will be computed and a single + connectivity spectrum will be returned (this is not possible if a Granger + causality method is called). If "indices" is specified, seed and target + indices for each connection should be specified as nested array-likes. For + example, to compute the connectivity between signals (0, 1) -> (2, 3) and + (0, 1) -> (4, 5), indices should be specified as:: + + indices = (np.array([[0, 1], [0, 1]]), # seeds + np.array([[2, 3], [4, 5]])) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. **Supported Connectivity Measures** @@ -1834,11 +1884,15 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # check rank input and compute data ranks if necessary if multivariate_con: - rank = _check_rank_input(rank, data, sfreq, indices_use) + rank = _check_rank_input(rank, data, indices_use) else: rank = None gc_n_lags = None + # make sure padded indices are stored in the connectivity object + if multivariate_con and indices is not None: + indices = tuple(np.array(indices_use)) # create a copy + # get the window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, n_tapers) = _assemble_spectral_params( @@ -1848,16 +1902,33 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, cwt_freqs=cwt_freqs, freqs=freqs, freq_mask=freq_mask) # unique signals for which we actually need to compute PSD etc. 
- sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) + if multivariate_con: + sig_idx = np.unique(np.concatenate(np.concatenate( + indices_use))) + sig_idx = sig_idx[sig_idx != -1] + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(sig_idx)} + remapping[-1] = -1 + remapped_inds = (indices_use[0].copy(), indices_use[1].copy()) + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + remapped_inds[0][con_i] = np.array([ + remapping[idx] for idx in seed]) + remapped_inds[1][con_i] = np.array([ + remapping[idx] for idx in target]) + con_i += 1 + remapped_sig = [remapping[idx] for idx in sig_idx] + else: + sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) n_signals_use = len(sig_idx) # map indices to unique indices - idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use] if multivariate_con: - indices_use = idx_map - idx_map = np.unique([*idx_map[0], *idx_map[1]]) - idx_map = [np.sort(np.repeat(idx_map, len(sig_idx))), - np.tile(idx_map, len(sig_idx))] + indices_use = remapped_inds # use remapped seeds & targets + idx_map = [np.sort(np.repeat(remapped_sig, len(sig_idx))), + np.tile(remapped_sig, len(sig_idx))] + else: + idx_map = [ + np.searchsorted(sig_idx, ind) for ind in indices_use] # allocate space to accumulate PSD if accumulate_psd: @@ -1894,7 +1965,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, call_params = dict( sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, - block_size=block_size, + n_cons=n_cons, block_size=block_size, psd=psd, accumulate_psd=accumulate_psd, mt_adaptive=mt_adaptive, con_method_types=con_method_types, @@ -1978,8 +2049,8 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, this_con = this_con_bands if this_patterns is not None: - patterns_shape = ((2, n_cons, len(indices[0]), n_bands) + - this_patterns.shape[4:]) + patterns_shape = list(this_patterns.shape) + patterns_shape[3] = n_bands this_patterns_bands = np.empty(patterns_shape, dtype=this_patterns.dtype) for band_idx in range(n_bands): @@ -2023,11 +2094,6 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # number of nodes in the original data n_nodes = n_signals - if multivariate_con: - # UNTIL RAGGED ARRAYS SUPPORTED - indices = tuple( - [[np.array(indices_use[0])], [np.array(indices_use[1])]]) - # create a list of connectivity containers conn_list = [] for _con, _patterns, _method in zip(con, patterns, method): @@ -2054,46 +2120,3 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, conn_list = conn_list[0] return conn_list - - -def _check_rank_input(rank, data, sfreq, indices): - """Check the rank argument is appropriate and compute rank if missing.""" - # UNTIL RAGGED ARRAYS SUPPORTED - indices = np.array([[indices[0]], [indices[1]]]) - - if rank is None: - - rank = np.zeros((2, len(indices[0])), dtype=int) - - if isinstance(data, BaseEpochs): - data_arr = data.get_data() - else: - data_arr = data - - for group_i in range(2): - for con_i, con_idcs in enumerate(indices[group_i]): - s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) - rank[group_i][con_i] = np.min( - [np.count_nonzero(epoch >= epoch[0] * 1e-10) - for epoch in s]) - - logger.info('Estimated data ranks:') - con_i = 1 - for seed_rank, target_rank in zip(rank[0], rank[1]): - logger.info(' connection %i - seeds (%i); targets (%i)' - % (con_i, seed_rank, target_rank, )) - con_i += 1 - - 
rank = tuple((np.array(rank[0]), np.array(rank[1]))) - - else: - for seed_idcs, target_idcs, seed_rank, target_rank in zip( - indices[0], indices[1], rank[0], rank[1]): - if not (0 < seed_rank <= len(seed_idcs) and - 0 < target_rank <= len(target_idcs)): - raise ValueError( - 'ranks for seeds and targets must be > 0 and <= the ' - 'number of channels in the seeds and targets, ' - 'respectively, for each connection') - - return rank diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index fa8cf44d..a436aec8 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -423,7 +423,9 @@ def test_spectral_connectivity_epochs_multivariate(method): trans_bandwidth = 2.0 # Hz delay = 10 # samples (non-zero delay needed for ImCoh and GC to be >> 0) - indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + indices = (np.arange(n_seeds)[np.newaxis, :], + np.arange(n_seeds)[np.newaxis, :] + n_seeds) + n_targets = n_seeds # 15-25 Hz connectivity fstart, fend = 15.0, 25.0 @@ -494,8 +496,17 @@ def test_spectral_connectivity_epochs_multivariate(method): if method in ['mic', 'mim']: con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=None, sfreq=sfreq) - assert (np.array(con.indices).tolist() == - [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + assert con.indices is None + assert con.n_nodes == n_signals + if method == 'mic': + assert np.array(con.attrs['patterns']).shape[2] == n_signals + + # check ragged indices padded correctly + ragged_indices = (np.array([[0]]), np.array([[1, 2]])) + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=ragged_indices, sfreq=sfreq) + assert np.all(np.array(con.indices) == + np.array([np.array([[0, -1]]), np.array([[1, 2]])])) # check shape of MIC patterns if method == 'mic': @@ -507,12 +518,12 @@ def test_spectral_connectivity_epochs_multivariate(method): if mode == 'cwt_morlet': patterns_shape = ( - (len(indices[0]), len(con.freqs), len(con.times)), - (len(indices[1]), len(con.freqs), len(con.times))) + (n_seeds, len(con.freqs), len(con.times)), + (n_targets, len(con.freqs), len(con.times))) else: patterns_shape = ( - (len(indices[0]), len(con.freqs)), - (len(indices[1]), len(con.freqs))) + (n_seeds, len(con.freqs)), + (n_targets, len(con.freqs))) assert np.shape(con.attrs["patterns"][0][0]) == patterns_shape[0] assert np.shape(con.attrs["patterns"][1][0]) == patterns_shape[1] @@ -532,10 +543,22 @@ def test_spectral_connectivity_epochs_multivariate(method): con = spectral_connectivity_epochs( data, method=method, mode=mode, indices=indices, sfreq=sfreq, rank=rank) - assert (np.shape(con.attrs["patterns"][0][0])[0] == - len(indices[0])) - assert (np.shape(con.attrs["patterns"][1][0])[0] == - len(indices[1])) + assert (np.shape(con.attrs["patterns"][0][0])[0] == n_seeds) + assert (np.shape(con.attrs["patterns"][1][0])[0] == n_targets) + + # check patterns padded correctly + ragged_indices = (np.array([[0]]), np.array([[1, 2]])) + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=ragged_indices, + sfreq=sfreq) + patterns = np.array(con.attrs["patterns"]) + patterns_shape = ( + (n_seeds, len(con.freqs)), (n_targets, len(con.freqs))) + assert patterns[0, 0].shape == patterns_shape[0] + assert patterns[1, 0].shape == patterns_shape[1] + assert not np.any(np.isnan(patterns[0, 0, 0])) + assert np.all(np.isnan(patterns[0, 0, 1])) + assert not np.any(np.isnan(patterns[1, 0])) def 
test_multivariate_spectral_connectivity_epochs_regression(): @@ -558,7 +581,7 @@ def test_multivariate_spectral_connectivity_epochs_regression(): data = pd.read_pickle( os.path.join(fpath, 'data', 'example_multivariate_data.pkl')) sfreq = 100 - indices = tuple([[0, 1], [2, 3]]) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) methods = ['mic', 'mim', 'gc', 'gc_tr'] con = spectral_connectivity_epochs( data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, @@ -587,13 +610,21 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) cwt_freqs = np.arange(10, 25 + 1) - # check bad indices with repeated channels + # check bad indices without nested array caught + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=non_nested_indices, + sfreq=sfreq, gc_n_lags=10) + + # check bad indices with repeated channels caught with pytest.raises(ValueError, - match='seed and target indices cannot contain'): - repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) spectral_connectivity_epochs( data, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq, gc_n_lags=10) @@ -644,7 +675,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): sfreq=sfreq, rank=(np.array([2]), np.array([2])), cwt_freqs=cwt_freqs) - # only check these once for speed + # only check these once (e.g. 
only with multitaper) for speed if method == 'gc' and mode == 'multitaper': # check bad n_lags caught frange = (5, 10) @@ -662,7 +693,7 @@ def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): cwt_freqs=cwt_freqs) # check intersecting indices caught - bad_indices = (np.array([0, 1]), np.array([0, 2])) + bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): spectral_connectivity_epochs(data, method=method, mode=mode, @@ -695,7 +726,7 @@ def test_multivar_spectral_connectivity_parallel(method): n_times = 256 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) spectral_connectivity_epochs( data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, @@ -854,7 +885,7 @@ def test_spectral_connectivity_time_delayed(): trans_bandwidth = 2.0 # Hz delay = 5 # samples (non-zero delay needed for GC to be >> 0) - indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) # 20-30 Hz connectivity fstart, fend = 20.0, 30.0 @@ -1058,7 +1089,8 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): n_times = 500 rng = np.random.RandomState(0) data = rng.randn(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) + n_cons = len(indices[0]) freqs = np.arange(10, 25 + 1) con_shape = [1] @@ -1077,34 +1109,60 @@ def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): # check shape of MIC patterns are correct if method == 'mic': - patterns_shape = [len(indices[0])] - if faverage: - patterns_shape.append(1) - else: - patterns_shape.append(len(freqs)) - if not average: - patterns_shape = [n_epochs, *patterns_shape] - patterns_shape = [2, *patterns_shape] - assert np.array(con.attrs['patterns']).shape == tuple(patterns_shape) + for indices_type in ['full', 'ragged']: + if indices_type == 'full': + indices = (np.array([[0, 1]]), np.array([[2, 3]])) + else: + indices = (np.array([[0, 1]]), np.array([[2]])) + max_n_chans = 2 + patterns_shape = [n_cons, max_n_chans] + if faverage: + patterns_shape.append(1) + else: + patterns_shape.append(len(freqs)) + if not average: + patterns_shape = [n_epochs, *patterns_shape] + patterns_shape = [2, *patterns_shape] + con = spectral_connectivity_time( + data, freqs, indices=indices, method=method, sfreq=sfreq, + faverage=faverage, average=average, gc_n_lags=10) -@pytest.mark.parametrize( - 'method', ['mic', 'mim', 'gc', 'gc_tr']) -def test_multivar_spectral_connectivity_time_error_catch(method): + patterns = np.array(con.attrs['patterns']) + # 2 (x epochs) x cons x channels x freqs|fbands + assert (patterns.shape == tuple(patterns_shape)) + if indices_type == 'ragged': + assert not np.any(np.isnan(patterns[0, ..., :, :])) + assert not np.any(np.isnan(patterns[0, ..., 0, :])) + assert np.all(np.isnan(patterns[1, ..., 1, :])) # padded entry + assert np.all(np.array(con.indices) == np.array( + (np.array([[0, 1]]), np.array([[2, -1]])))) + + +@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize('mode', ['multitaper', 'cwt_morlet']) +def test_multivar_spectral_connectivity_time_error_catch(method, mode): """Test error catching for time-resolved multivar. connectivity methods.""" sfreq = 50. n_signals = 4 # Do not change! 
n_epochs = 8 n_times = 256 data = np.random.rand(n_epochs, n_signals, n_times) - indices = (np.arange(0, 2), np.arange(2, 4)) + indices = (np.array([[0, 1]]), np.array([[2, 3]])) freqs = np.arange(10, 25 + 1) - # check bad indices with repeated channels + # check bad indices without nested array caught + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + spectral_connectivity_time(data, freqs, method=method, mode=mode, + indices=non_nested_indices, sfreq=sfreq) + + # check bad indices with repeated channels caught with pytest.raises(ValueError, - match='seed and target indices cannot contain'): - repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) - spectral_connectivity_time(data, freqs, method=method, + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=repeated_indices, sfreq=sfreq) # check mixed methods caught @@ -1112,7 +1170,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method): match='bivariate and multivariate connectivity'): mixed_methods = [method, 'coh'] spectral_connectivity_time(data, freqs, method=mixed_methods, - indices=indices, sfreq=sfreq) + mode=mode, indices=indices, sfreq=sfreq) # check bad rank args caught too_low_rank = (np.array([0]), np.array([0])) @@ -1120,38 +1178,40 @@ def test_multivar_spectral_connectivity_time_error_catch(method): match='ranks for seeds and targets must be'): spectral_connectivity_time( data, freqs, method=method, indices=indices, sfreq=sfreq, - rank=too_low_rank) + mode=mode, rank=too_low_rank) too_high_rank = (np.array([3]), np.array([3])) with pytest.raises(ValueError, match='ranks for seeds and targets must be'): spectral_connectivity_time( data, freqs, method=method, indices=indices, sfreq=sfreq, - rank=too_high_rank) + mode=mode, rank=too_high_rank) # check all-to-all conn. 
computed for MIC/MIM when no indices given if method in ['mic', 'mim']: - con = spectral_connectivity_epochs( - data, freqs, method=method, indices=None, sfreq=sfreq) - assert (np.array(con.indices).tolist() == - [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + con = spectral_connectivity_time( + data, freqs, method=method, indices=None, sfreq=sfreq, mode=mode) + assert con.indices is None + assert con.n_nodes == n_signals + if method == 'mic': + assert np.array(con.attrs['patterns']).shape[3] == n_signals if method in ['gc', 'gc_tr']: # check no indices caught with pytest.raises(ValueError, match='indices must be specified'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=None, sfreq=sfreq) # check intersecting indices caught - bad_indices = (np.array([0, 1]), np.array([0, 2])) + bad_indices = (np.array([[0, 1]]), np.array([[0, 2]])) with pytest.raises(ValueError, match='seed and target indices must not intersect'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=bad_indices, sfreq=sfreq) # check bad fmin/fmax caught with pytest.raises(ValueError, match='computing Granger causality on multiple'): - spectral_connectivity_time(data, freqs, method=method, + spectral_connectivity_time(data, freqs, method=method, mode=mode, indices=indices, sfreq=sfreq, fmin=(5., 15.), fmax=(15., 30.)) @@ -1159,7 +1219,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method): def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) - n_epochs, n_chs, n_times, sfreq, f = 10, 2, 2000, 1000., 20. + n_epochs, n_chs, n_times, sfreq, f = 10, 3, 2000, 1000., 20. data = rng.randn(n_epochs, n_chs, n_times) sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000) data[:, :, 500:1500] += sig @@ -1171,3 +1231,10 @@ def test_save(tmp_path): epochs, fmin=(4, 8, 13, 30), fmax=(8, 13, 30, 45), faverage=True) conn.save(tmp_path / 'foo.nc') + + # multivariate connectivity + # use ragged indices & MIC to test padding of indices and patterns + indices = (np.array([[0, 1]]), np.array([[2]])) + conn_mvc = spectral_connectivity_epochs( + epochs, method="mic", indices=indices, sfreq=sfreq, fmin=10, fmax=40) + conn_mvc.save(tmp_path / 'foo_mvc.nc') diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 7c1aabe6..6b5eb000 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -16,7 +16,7 @@ from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, _check_rank_input) from .smooth import _create_kernel, _smooth_spectra -from ..utils import check_indices, fill_doc +from ..utils import check_indices, check_multivariate_indices, fill_doc _multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] @@ -70,10 +70,11 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None Two arrays with indices of connections for which to compute - connectivity. If a multivariate method is called, the indices are for a - single connection between all seeds and all targets. If None, all - connections are computed, unless a Granger causality method is called, - in which case an error is raised. + connectivity. 
If a multivariate method is called, each array for the + seeds and targets should contain a nested array of channel indices for + the individual connections. If None, connections between all channels + are computed, unless a Granger causality method is called, in which + case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. @@ -144,11 +145,11 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, :class:`EpochSpectralConnectivity`, :class:`SpectralConnectivity` or a list of instances corresponding to connectivity measures if several connectivity measures are specified. - The shape of each connectivity dataset is - (n_epochs, n_signals, n_signals, n_freqs) when ``indices`` is `None`, - (n_epochs, n_nodes, n_nodes, n_freqs) when ``indices`` is specified - and ``n_nodes = len(indices[0])``, or (n_epochs, 1, 1, n_freqs) when a - multi-variate method is called regardless of "indices". + The shape of each connectivity dataset is (n_epochs, n_cons, n_freqs). + When "indices" is None and a bivariate method is called, + "n_cons = n_signals ** 2", or if a multivariate method is called + "n_cons = 1". When "indices" is specified, "n_con = len(indices[0])" + for bivariate and multivariate methods. See Also -------- @@ -202,13 +203,19 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, scores are in the same order as defined indices. For multivariate methods, this is handled differently. If "indices" is - None, connectivity between all signals will attempt to be computed (this is - not possible if a Granger causality method is called). If "indices" is - specified, the seeds and targets are treated as a single connection. For - example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5, - one would use the same approach as above, however the signals would all be - considered for a single connection and the connectivity scores would have - the shape (1, n_freqs). + None, connectivity between all signals will be computed and a single + connectivity spectrum will be returned (this is not possible if a Granger + causality method is called). If "indices" is specified, seed and target + indices for each connection should be specified as nested array-likes. For + example, to compute the connectivity between signals (0, 1) -> (2, 3) and + (0, 1) -> (4, 5), indices should be specified as:: + + indices = (np.array([[0, 1], [0, 1]]), # seeds + np.array([[2, 3], [4, 5]])) # targets + + More information on working with multivariate indices and handling + connections where the number of seeds and targets are not equal can be + found in the :doc:`../auto_examples/handling_ragged_arrays` example. 
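+
+    If the numbers of seeds and targets differ across connections, the nested
+    arrays will be ragged. As an illustrative sketch of what this can look
+    like (``dtype='object'`` is required by NumPy for ragged arrays)::
+
+        indices = (np.array([[0, 1], [0, 1, 2, 3]], dtype='object'),  # seeds
+                   np.array([[2, 3, 4], [4]], dtype='object'))  # targets
+
+    Such ragged indices are padded with the invalid channel index ``-1``
+    before being stored in the returned connectivity object.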
**Supported Connectivity Measures** @@ -398,36 +405,51 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, 'indices must be specified when computing Granger ' 'causality, as all-to-all connectivity is not supported') logger.info('using all indices for multivariate connectivity') - indices_use = (np.arange(n_signals, dtype=int), - np.arange(n_signals, dtype=int)) + indices_use = (np.array([np.arange(n_signals, dtype=np.int32)]), + np.array([np.arange(n_signals, dtype=np.int32)])) else: logger.info('only using indices for lower-triangular matrix') indices_use = np.tril_indices(n_signals, k=-1) else: if multivariate_con: - if ( - len(np.unique(indices[0])) != len(indices[0]) or - len(np.unique(indices[1])) != len(indices[1]) - ): - raise ValueError( - 'seed and target indices cannot contain repeated ' - 'channels for multivariate connectivity') + indices_use = check_multivariate_indices(indices) # pad with -1 if any(this_method in _gc_methods for this_method in method): - if set(indices[0]).intersection(indices[1]): - raise ValueError( - 'seed and target indices must not intersect when ' - 'computing Granger causality') - indices_use = check_indices(indices) - source_idx = indices_use[0] - target_idx = indices_use[1] - n_pairs = len(source_idx) if not multivariate_con else 1 + for seed, target in zip(indices[0], indices[1]): + intersection = np.intersect1d(seed, target) + if np.any(intersection != -1): # ignore padded entries + raise ValueError( + 'seed and target indices must not intersect when ' + 'computing Granger causality') + # make sure padded indices are stored in the connectivity object + indices = tuple(np.array(indices_use)) # create a copy + else: + indices_use = check_indices(indices) + # create copies of indices_use for independent manipulation + source_idx = np.array(indices_use[0]) + target_idx = np.array(indices_use[1]) + n_cons = len(source_idx) # unique signals for which we actually need to compute the CSD of - signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + if multivariate_con: + signals_use = np.unique(np.concatenate(np.concatenate(indices_use))) + signals_use = signals_use[signals_use != -1] + remapping = {ch_i: sig_i for sig_i, ch_i in enumerate(signals_use)} + remapping[-1] = -1 + # multivariate functions expect seed/target remapping + con_i = 0 + for seed, target in zip(indices_use[0], indices_use[1]): + source_idx[con_i] = np.array([remapping[idx] for idx in seed]) + target_idx[con_i] = np.array([remapping[idx] for idx in target]) + con_i += 1 + max_n_channels = len(indices_use[0][0]) + else: + # no indices remapping required for bivariate functions + signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + max_n_channels = len(indices_use[0]) # check rank input and compute data ranks if necessary if multivariate_con: - rank = _check_rank_input(rank, data, sfreq, indices_use) + rank = _check_rank_input(rank, data, indices_use) else: rank = None gc_n_lags = None @@ -479,9 +501,10 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, conn = dict() conn_patterns = dict() for m in method: - conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) - conn_patterns[m] = np.full((n_epochs, 2, len(source_idx), n_freqs), - np.nan) + conn[m] = np.zeros((n_epochs, n_cons, n_freqs)) + # patterns shape of [epochs x seeds/targets x cons x channels x freqs] + conn_patterns[m] = np.full( + (n_epochs, 2, n_cons, max_n_channels, n_freqs), np.nan) logger.info('Connectivity computation...') # parameters to pass to the 
connectivity function @@ -505,8 +528,8 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, if np.isnan(conn_patterns[m]).all(): conn_patterns[m] = None else: - # epochs x 2 x n_channels x n_freqs - conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3)) + # transpose to [seeds/targets x epochs x cons x channels x freqs] + conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3, 4)) if indices is None and not multivariate_con: conn_flat = conn @@ -520,11 +543,6 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, conn_flat[m].shape[2:]) conn[m] = this_conn - if multivariate_con: - # UNTIL RAGGED ARRAYS SUPPORTED - indices = tuple( - [[np.array(indices_use[0])], [np.array(indices_use[1])]]) - # create the connectivity containers out = [] for m in method: @@ -569,9 +587,9 @@ def _spectral_connectivity(data, method, kernel, foi_idx, Smoothing kernel. foi_idx : array_like, shape (n_foi, 2) Upper and lower bound indices of frequency bands. - source_idx : array_like, shape (n_pairs,) + source_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``target_idx``. - target_idx : array_like, shape (n_pairs,) + target_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``source_idx``. signals_use : list of int The unique signals on which connectivity is to be computed. @@ -608,8 +626,8 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ------- scores : dict Dictionary containing the connectivity estimates corresponding to the - metrics in ``method``. Each element is an array of shape (n_pairs, - n_freqs) or (n_pairs, n_fbands) if ``faverage`` is `True`. + metrics in ``method``. Each element is an array of shape (n_cons, + n_freqs) or (n_cons, n_fbands) if ``faverage`` is `True`. patterns : dict Dictionary containing the connectivity patterns (for reconstructing the @@ -619,7 +637,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, or (2, n_channels, 1) if ``faverage`` is `True`, where 2 corresponds to the seed and target signals (respectively). """ - n_pairs = len(source_idx) + n_cons = len(source_idx) data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': out = tfr_array_morlet( @@ -665,12 +683,12 @@ def _spectral_connectivity(data, method, kernel, foi_idx, scores = {} patterns = {} conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, - signals_use, gc_n_lags, rank, n_jobs, verbose, - n_pairs, faverage, weights, multivariate_con) + signals_use, gc_n_lags, rank, n_jobs, verbose, n_cons, + faverage, weights, multivariate_con) for i, m in enumerate(method): if multivariate_con: scores[m] = conn[0][i] - patterns[m] = conn[1][i][:, 0] if conn[1][i] is not None else None + patterns[m] = conn[1][i] if conn[1][i] is not None else None else: scores[m] = [out[i] for out in conn] patterns[m] = None @@ -699,16 +717,16 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, Smoothing kernel. foi_idx : array_like, shape (n_foi, 2) Upper and lower bound indices of frequency bands. - source_idx : array_like, shape (n_pairs,) + source_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``target_idx``. - target_idx : array_like, shape (n_pairs,) + target_idx : array_like, shape (n_cons,) or (n_cons, n_channels) Defines the signal pairs of interest together with ``source_idx``. 
signals_use : list of int The unique signals on which connectivity is to be computed. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. - rank : tuple of array + rank : tuple of array of int Ranks to project the seed and target data to. n_jobs : int Number of parallel jobs. @@ -825,18 +843,23 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, return out -def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, - foi_idx, faverage, weights, gc_n_lags, rank, n_jobs): +def _multivariate_con(w, seeds, targets, signals_use, method, kernel, foi_idx, + faverage, weights, gc_n_lags, rank, n_jobs): """Compute spectral connectivity metrics between multiple signals. Parameters ---------- w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) Time-frequency data. - x : int - Channel index. - y : int - Channel index. + seeds : array, shape of (n_cons, n_channels) + Seed channel indices. ``n_channels`` is the largest number of channels + across all connections, with missing entries padded with ``-1``. + targets : array, shape of (n_cons, n_channels) + Target channel indices. ``n_channels`` is the largest number of + channels across all connections, with missing entries padded with + ``-1``. + signals_use : list of int + The unique signals on which connectivity is to be computed. method : str Connectivity method. kernel : array_like, shape (n_sm_fres, n_sm_times) @@ -847,6 +870,13 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, Average over frequency bands. weights : array_like, shape (n_tapers, n_freqs, n_times) | None Multitaper weights. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. + rank : tuple of array, shape of (2, n_cons) + Ranks to project the seed and target data to. + n_jobs : int + Number of jobs to run in parallel. Returns ------- @@ -859,8 +889,10 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, List of connectivity patterns between seed and target signals for each connectivity method. Each element is an array of length 2 corresponding to the seed and target patterns, respectively, each with shape - (n_channels, n_freqs,) or (n_channels, n_fbands) depending on - ``faverage``. + (n_channels, n_freqs) or (n_channels, n_fbands) + depending on ``faverage``. ``n_channels`` is the largest number of + channels across all connections, with missing entries padded with + ``np.nan``. 
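Since the CSD is only computed for the channels that actually appear in ``indices``, the padded seed/target entries are remapped to positions within ``signals_use`` before reaching these helpers, and patterns for connections with fewer channels are left as ``np.nan`` in the padded output array. A minimal sketch of that remapping step (the channel numbers are assumptions chosen for illustration)::

    import numpy as np

    # padded indices for two connections
    seeds = np.array([[0, 2, -1], [0, 2, -1]])
    targets = np.array([[5, 7, 9], [9, -1, -1]])

    # unique channels in use, ignoring the -1 padding
    signals_use = np.unique(np.concatenate((seeds.ravel(), targets.ravel())))
    signals_use = signals_use[signals_use != -1]  # [0, 2, 5, 7, 9]

    remapping = {ch: i for i, ch in enumerate(signals_use)}
    remapping[-1] = -1  # padded entries stay padded

    remapped_seeds = np.array([[remapping[ch] for ch in con] for con in seeds])
    remapped_targets = np.array([[remapping[ch] for ch in con] for con in targets])
    # remapped_seeds   -> [[0, 1, -1], [0, 1, -1]]
    # remapped_targets -> [[2, 3, 4], [4, -1, -1]]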
""" csd = [] for x in signals_use: @@ -880,8 +912,7 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, 'gc_tr': _GCTREst} conn = [] for m in method: - # N_CONS = 1 UNTIL RAGGED ARRAYS SUPPORTED - call_params = {'n_signals': len(signals_use), 'n_cons': 1, + call_params = {'n_signals': len(signals_use), 'n_cons': len(seeds), 'n_freqs': csd.shape[1], 'n_times': 0, 'n_jobs': n_jobs} if m in _gc_methods: @@ -895,7 +926,7 @@ def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, scores = [] patterns = [] for con_est in conn: - con_est.compute_con(np.array([source_idx, target_idx]), rank) + con_est.compute_con((seeds, targets), rank) scores.append(con_est.con_scores[..., np.newaxis]) patterns.append(con_est.patterns) if patterns[-1] is not None: diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index c012481c..0549ee43 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -3,11 +3,15 @@ from numpy.testing import assert_array_equal from mne_connectivity import Connectivity -from mne_connectivity.utils import degree, seed_target_indices +from mne_connectivity.utils import (degree, check_indices, + check_multivariate_indices, + seed_target_indices, + seed_target_multivariate_indices) -def test_indices(): - """Test connectivity indexing methods.""" +def test_seed_target_indices(): + """Test indices generation functions.""" + # bivariate indices n_seeds_test = [1, 3, 4] n_targets_test = [2, 3, 200] rng = np.random.RandomState(42) @@ -25,6 +29,69 @@ def test_indices(): for target in targets: assert np.sum(indices[1] == target) == n_seeds + # multivariate indices + # non-ragged indices + seeds = [[0, 1]] + targets = [[2, 3], [3, 4]] + indices = seed_target_multivariate_indices(seeds, targets) + assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), + np.array([[2, 3], [3, 4]]))) + # ragged indices + seeds = [[0, 1]] + targets = [[2, 3, 4], [4]] + indices = seed_target_multivariate_indices(seeds, targets) + assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), + np.array([[2, 3, 4], [4, -1, -1]]))) + # test error catching + # non-array-like seeds/targets + with pytest.raises(TypeError, + match='`seeds` and `targets` must be array-like'): + seed_target_multivariate_indices(0, 1) + # non-nested seeds/targets + with pytest.raises(TypeError, + match='`seeds` and `targets` must contain nested'): + seed_target_multivariate_indices([0], [1]) + + +def test_check_indices(): + """Test indices checking functions.""" + # bivariate indices + # test error catching + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_tuple_indices = [[0], [1]] + check_indices(non_tuple_indices) + with pytest.raises(ValueError, + match='indices must be a tuple of length 2'): + non_len2_indices = ([0], [1], [2]) + check_indices(non_len2_indices) + with pytest.raises(ValueError, match='Index arrays indices'): + non_equal_len_indices = ([0], [1, 2]) + check_indices(non_equal_len_indices) + + # multivariate indices + # non-ragged indices + seeds = [[0, 1], [0, 1]] + targets = [[2, 3], [3, 4]] + indices = check_multivariate_indices((seeds, targets)) + assert np.all(np.array(indices) == (np.array([[0, 1], [0, 1]]), + np.array([[2, 3], [3, 4]]))) + # ragged indices + seeds = [[0, 1], [0, 1]] + targets = [[2, 3, 4], [4]] + indices = check_multivariate_indices((seeds, targets)) + assert np.all(np.array(indices) == (np.array([[0, 1, -1], [0, 1, -1]]), + 
np.array([[2, 3, 4], [4, -1, -1]]))) + # test error catching + with pytest.raises(TypeError, + match='multivariate indices must contain array-likes'): + non_nested_indices = (np.array([0, 1]), np.array([2, 3])) + check_multivariate_indices(non_nested_indices) + with pytest.raises(ValueError, + match='multivariate indices cannot contain repeated'): + repeated_indices = (np.array([[0, 1, 1]]), np.array([[2, 2, 3]])) + check_multivariate_indices(repeated_indices) + def test_degree(): """Test degree function.""" diff --git a/mne_connectivity/utils/__init__.py b/mne_connectivity/utils/__init__.py index e82f054b..0df454a4 100644 --- a/mne_connectivity/utils/__init__.py +++ b/mne_connectivity/utils/__init__.py @@ -1,3 +1,4 @@ from .docs import fill_doc -from .utils import (check_indices, degree, seed_target_indices, +from .utils import (check_indices, check_multivariate_indices, degree, + seed_target_indices, seed_target_multivariate_indices, parallel_loop, _prepare_xarray_mne_data_structures) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 5ae94acb..b8216654 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -71,6 +71,47 @@ def check_indices(indices): return indices +def check_multivariate_indices(indices): + """Check indices parameter for multivariate connectivity and pad it. + + Parameters + ---------- + indices : tuple of array-like of array-like of int + Tuple of length 2 containing index pairs. + + Returns + ------- + indices : tuple of array of array of int + The indices padded with the invalid channel index ``-1``. + """ + indices = check_indices(indices) + n_cons = len(indices[0]) + + n_chans = [] + for inds in ([*indices[0], *indices[1]]): + if not isinstance(inds, (np.ndarray, list, tuple)): + raise TypeError( + 'multivariate indices must contain array-likes of channel ' + 'indices for each seed and target') + if len(inds) != len(np.unique(inds)): + raise ValueError( + 'multivariate indices cannot contain repeated channels within ' + 'a seed or target') + n_chans.append(len(inds)) + max_n_chans = np.max(n_chans) + + # pad indices to avoid ragged arrays + padded_indices = (np.full((n_cons, max_n_chans), -1, dtype=np.int32), + np.full((n_cons, max_n_chans), -1, dtype=np.int32)) + con_i = 0 + for seed, target in zip(indices[0], indices[1]): + padded_indices[0][con_i, :len(seed)] = seed + padded_indices[1][con_i, :len(target)] = target + con_i += 1 + + return padded_indices + + def seed_target_indices(seeds, targets): """Generate indices parameter for seed based connectivity analysis. @@ -99,6 +140,60 @@ def seed_target_indices(seeds, targets): return indices +def seed_target_multivariate_indices(seeds, targets): + """Generate indices parameter for multivariate seed-based connectivity. + + Parameters + ---------- + seeds : array-like of array-like of int + Seed indices. + + targets : array-like of array-like of int + Target indices. + + Returns + ------- + indices : tuple of array of array of int + The indices padded with the invalid channel index ``-1``. 
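The intended behaviour of ``seed_target_multivariate_indices``, mirroring the tests added above (a usage sketch rather than part of the function's docstring)::

    from mne_connectivity.utils import seed_target_multivariate_indices

    # each seed group is paired with every target group
    seed_target_multivariate_indices([[0, 1]], [[2, 3], [3, 4]])
    # -> (array([[0, 1], [0, 1]]), array([[2, 3], [3, 4]]))

    # ragged groups are padded with -1 up to the largest group size
    seed_target_multivariate_indices([[0, 1]], [[2, 3, 4], [4]])
    # -> (array([[0, 1, -1], [0, 1, -1]]), array([[2, 3, 4], [4, -1, -1]]))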
+ """ + array_like = (np.ndarray, list, tuple) + + if ( + not isinstance(seeds, array_like) or + not isinstance(targets, array_like) + ): + raise TypeError('`seeds` and `targets` must be array-like') + + n_chans = [] + for inds in [*seeds, *targets]: + if not isinstance(inds, array_like): + raise TypeError( + '`seeds` and `targets` must contain nested array-likes') + n_chans.append(len(inds)) + max_n_chans = max(n_chans) + n_cons = len(seeds) * len(targets) + + # pad indices to avoid ragged arrays + padded_seeds = np.full((len(seeds), max_n_chans), -1, dtype=np.int32) + padded_targets = np.full((len(targets), max_n_chans), -1, dtype=np.int32) + for con_i, seed in enumerate(seeds): + padded_seeds[con_i, :len(seed)] = seed + for con_i, target in enumerate(targets): + padded_targets[con_i, :len(target)] = target + + # create final indices + indices = (np.zeros((n_cons, max_n_chans), dtype=np.int32), + np.zeros((n_cons, max_n_chans), dtype=np.int32)) + con_i = 0 + for seed in padded_seeds: + for target in padded_targets: + indices[0][con_i] = seed + indices[1][con_i] = target + con_i += 1 + + return indices + + def degree(connectivity, threshold_prop=0.2): """Compute the undirected degree of a connectivity matrix.