From 627c6615f684a0b5ce68b809012ddfe528738182 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Tue, 10 Jan 2023 16:08:24 +0200 Subject: [PATCH] [ENH] Multiple improvements to spectral_connectivity_time: ciPLV, and efficient computation of multiple metrics (#115) * Add ciPLV: Add the corrected imaginary Phase-Locking-Value into the list of available connectivity metrics. * Speed up computation: All connectivity measures are now computed with only a single computation of pairwise cross spectrum. * Add the option to specify freqs in all modes: In some scenarios, users might want to specify the frequencies for time-frequency decomposition also when using multitapering. These changes allow users to specify the 'freqs' parameter to override the automatically determined frequencies. * BUG: Average over CSD instead of connectivity * Add option to use part of signal as padding: This adds the option to use the edges of the signal at each epoch as padding. The purpose of this is to avoid edge effects generated by the time-frequency transformation methods. * Fix test bug, use 'freqs' instead of 'cwt_freqs' * Fix bug with dpss windows: Sym is not a parameter of dpss_windows. (But is one of the underlying scipy.signal.dpss) * Only show progress bar if verbosity level is DEBUG: This change will skip the rendering of the connectivity computation progress bar if the logging level is not DEBUG. This is in line with MNE-Python, where progress bars are not shown at INFO or higher logging levels. Rendering the progress bar regardless of logging levels has the potential to cause unnecessary clutter in users' log files. * Require freqs in all tfr modes The user is required to specify the wavelet central frequencies in both multitaper and cwt_morlet tfr mode. The reasoning is that the underlying tfr implementations are very similar. 
This is in contrast to spectral_connectivity_epochs, where multitaper assumes that the spectrum is stationary and therefore no wavelets are used. * Require mne>=1.3 Signed-off-by: Adam Li Co-authored-by: Adam Li --- doc/whats_new.rst | 2 + .../spectral/tests/test_spectral.py | 86 +++- mne_connectivity/spectral/time.py | 467 +++++++++++------- requirements.txt | 2 +- 4 files changed, 350 insertions(+), 207 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 503ab0af..758426d2 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -29,6 +29,8 @@ Enhancements - Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). - Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). - Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`). +- Add the ``ciPLV`` method in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`). +- Add the option to use the edges of each epoch as padding in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`). 
Bug ~~~ diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 8b4c71a8..3286bba8 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -472,7 +472,7 @@ def test_epochs_tmin_tmax(kind): assert len(w) == 1 # just one even though there were multiple epochs -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) @@ -504,11 +504,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) - cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) - con = spectral_connectivity_time(data, method=method, mode=mode, + freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) + con = spectral_connectivity_time(data, freqs, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, - cwt_freqs=cwt_freqs, n_jobs=1, + n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] @@ -526,12 +526,13 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): assert np.all(con_matrix) <= 0.5 -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( - 'cwt_freqs', [[8., 10.], [8, 10], 10., 10]) -def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): + 'freqs', [[8., 10.], [8, 10], 10., 10]) +@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) +def test_spectral_connectivity_time_freqs(method, freqs, mode): """Test time-resolved spectral connectivity with int 
and float values for - cwt_freqs.""" + freqs.""" rng = np.random.default_rng(0) n_epochs = 5 n_channels = 3 @@ -552,10 +553,10 @@ def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): data[i, c] = np.squeeze(np.sin(x)) # the frequency band should contain the frequency at which there is a # hypothesized "connection" - con = spectral_connectivity_time(data, method=method, mode='cwt_morlet', - sfreq=sfreq, fmin=np.min(cwt_freqs), - fmax=np.max(cwt_freqs), - cwt_freqs=cwt_freqs, n_jobs=1, + con = spectral_connectivity_time(data, freqs, method=method, + mode=mode, sfreq=sfreq, + fmin=np.min(freqs), + fmax=np.max(freqs), n_jobs=1, faverage=True, average=True, sm_times=0) assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] @@ -588,12 +589,12 @@ def test_spectral_connectivity_time_resolved(method, mode): info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') data = EpochsArray(data, info) - # define some frequencies for cwt + # define some frequencies for tfr freqs = np.arange(3, 20.5, 1) # run connectivity estimation con = spectral_connectivity_time( - data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode, + data, freqs, sfreq=sfreq, method=method, mode=mode, n_cycles=5) assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ @@ -613,6 +614,63 @@ def test_spectral_connectivity_time_resolved(method, mode): for idx, jdx in triu_inds) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize( + 'mode', ['cwt_morlet', 'multitaper']) +@pytest.mark.parametrize('padding', [0, 1, 5]) +def test_spectral_connectivity_time_padding(method, mode, padding): + """Test time-resolved spectral connectivity with padding.""" + sfreq = 50. + n_signals = 3 + n_epochs = 2 + n_times = 300 + trans_bandwidth = 2. + tmin = 0. 
+ tmax = (n_times - 1) / sfreq + # 5Hz..15Hz + fstart, fend = 5.0, 15.0 + data, _ = create_test_dataset( + sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times, + tmin=tmin, tmax=tmax, + fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth) + ch_names = np.arange(n_signals).astype(str).tolist() + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg') + data = EpochsArray(data, info) + + # define some frequencies for tfr + freqs = np.arange(3, 20.5, 1) + + # run connectivity estimation + if padding == 5: + with pytest.raises(ValueError, match='Padding cannot be larger than ' + 'half of data length'): + con = spectral_connectivity_time( + data, freqs, sfreq=sfreq, method=method, mode=mode, + n_cycles=5, padding=padding) + return + else: + con = spectral_connectivity_time( + data, freqs, sfreq=sfreq, method=method, mode=mode, + n_cycles=5, padding=padding) + + assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) + assert con.get_data(output='dense').shape == \ + (n_epochs, n_signals, n_signals, len(con.freqs)) + + # test the simulated signal + triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T + + # average over frequencies + conn_data = con.get_data(output='dense').mean(axis=-1) + + # the indices at which there is a correlation should be greater + # than the rest of the components + for epoch_idx in range(n_epochs): + high_conn_val = conn_data[epoch_idx, 0, 1] + assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx] + for idx, jdx in triu_inds) + + def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d5262667..8658814b 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -9,46 +9,49 @@ from mne.parallel import parallel_func from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper, dpss_windows) -from mne.utils import (logger, 
warn, verbose) +from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import _compute_freqs, _compute_freq_mask +from .epochs import _compute_freq_mask from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, fill_doc @verbose @fill_doc -def spectral_connectivity_time(data, method='coh', average=False, +def spectral_connectivity_time(data, freqs, method='coh', average=False, indices=None, sfreq=None, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, - sm_freqs=1, sm_kernel='hanning', + sm_freqs=1, sm_kernel='hanning', padding=0, mode='cwt_morlet', mt_bandwidth=None, - cwt_freqs=None, n_cycles=7, decim=1, - n_jobs=1, verbose=None): - """Compute frequency- and time-frequency-domain connectivity measures. + n_cycles=7, decim=1, n_jobs=1, verbose=None): + """Compute time-frequency-domain connectivity measures. - This method computes time-resolved connectivity measures from epoched data. + This function computes spectral connectivity over time from epoched data. + The data may consist of a single epoch. The connectivity method(s) are specified using the ``method`` parameter. - All methods are based on estimates of the cross- and power spectral - densities (CSD/PSD) Sxy and Sxx, Syy. + All methods are based on time-resolved estimates of the cross- and + power spectral densities (CSD/PSD) Sxy and Sxx, Syy. Parameters ---------- data : array_like, shape (n_epochs, n_signals, n_times) | Epochs The data from which to compute connectivity. + freqs : array_like + Array of frequencies of interest for time-frequency decomposition. + Only the frequencies within the range specified by ``fmin`` and + ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be - ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. 
These are: - - * 'coh' : Coherence - * 'plv' : Phase-Locking Value (PLV) - * 'sxy' : Cross-spectrum - * 'pli' : Phase-Lag Index - * 'wpli': Weighted Phase-Lag Index + ``['coh', 'plv', 'ciplv', 'pli', 'wpli']``. These are: + * 'coh' : Coherence + * 'plv' : Phase-Locking Value (PLV) + * 'ciplv' : Corrected imaginary Phase-Locking Value + * 'pli' : Phase-Lag Index + * 'wpli' : Weighted Phase-Lag Index average : bool - Average connectivity scores over epochs. If True, output will be + Average connectivity scores over epochs. If ``True``, output will be an instance of :class:`SpectralConnectivity`, otherwise :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None @@ -61,12 +64,11 @@ def spectral_connectivity_time(data, method='coh', average=False, fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower - bounds. If `None`, the frequency corresponding to an epoch length of - 5 cycles is used. + bounds. If `None`, the lowest frequency in ``freqs`` is used. fmax : float | tuple of float | None The upper frequency of interest. Multiple bands are defined using a tuple, e.g. ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper - bounds. If `None`, ``sfreq/2`` is used. + bounds. If `None`, the highest frequency in ``freqs`` is used. fskip : int Omit every ``(fskip + 1)``-th frequency bin to decimate in frequency domain. @@ -82,6 +84,9 @@ def spectral_connectivity_time(data, method='coh', average=False, is equivalent to no smoothing. sm_kernel : {'square', 'hanning'} Smoothing kernel type. Choose either 'square' or 'hanning'. + padding : float + Amount of time to consider as padding at the beginning and end of each + epoch in seconds. See Notes for more information. mode : str Time-frequency decomposition method. Can be either: 'multitaper', or 'cwt_morlet'. 
See :func:`mne.time_frequency.tfr_array_multitaper` and @@ -93,11 +98,6 @@ def spectral_connectivity_time(data, method='coh', average=False, bandwidth (thus the frequency resolution) and the number of good tapers. See :func:`mne.time_frequency.tfr_array_multitaper` documentation. - cwt_freqs : array_like - Array of frequencies of interest for time-frequency decomposition. - Only used in 'cwt_morlet' mode. Only the frequencies within - the range specified by ``fmin`` and ``fmax`` are used. Required if - ``mode='cwt_morlet'``. Not used when ``mode='multitaper'``. n_cycles : float | array_like of float Number of cycles in the wavelet, either a fixed number or one per frequency. The number of cycles ``n_cycles`` and the frequencies of @@ -150,6 +150,14 @@ def spectral_connectivity_time(data, method='coh', average=False, using a weighted average, where the weights correspond to the concentration ratios between the DPSS windows. + Spectral estimation using multitaper or Morlet wavelets introduces edge + effects that depend on the length of the wavelet. To remove edge effects, + the parameter ``padding`` can be used to prune the edges of the signal. + Please see the documentation of + :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for details on wavelet length + (i.e., time window length). + By default, the connectivity between all signals is computed (only connections corresponding to the lower-triangular part of the connectivity matrix). 
If one is only interested in the connectivity @@ -184,7 +192,12 @@ def spectral_connectivity_time(data, method='coh', average=False, PLV = |E[Sxy/|Sxy|]| - 'sxy' : Cross spectrum Sxy + 'ciplv' : Corrected imaginary PLV (ciPLV) :footcite:`BrunaEtAl2018` + given by:: + + |E[Im(Sxy/|Sxy|)]| + ciPLV = ------------------------------------ + sqrt(1 - |E[real(Sxy/|Sxy|)]| ** 2) 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: @@ -256,25 +269,13 @@ def spectral_connectivity_time(data, method='coh', average=False, if isinstance(method, str): method = [method] - # check that fmin corresponds to at least 5 cycles - dur = float(n_times) / sfreq - five_cycle_freq = 5. / dur + # defaults for fmin and fmax if fmin is None: - # use the 5 cycle freq. as default - fmin = five_cycle_freq - logger.info(f'Fmin was not specified. Using fmin={fmin:.2f}, which ' - 'corresponds to at least five cycles.') - else: - if np.any(fmin < five_cycle_freq): - warn('fmin=%0.3f Hz corresponds to %0.3f < 5 cycles ' - 'based on the epoch length %0.3f sec, need at least %0.3f ' - 'sec epochs or fmin=%0.3f. Spectrum estimate will be ' - 'unreliable.' % (np.min(fmin), dur * np.min(fmin), dur, - 5. / np.min(fmin), five_cycle_freq)) + fmin = np.min(freqs) + logger.info('Fmin was not specified. Using fmin=min(freqs)') if fmax is None: - fmax = sfreq / 2 - logger.info(f'Fmax was not specified. Using fmax={fmax:.2f}, which ' - f'corresponds to Nyquist.') + fmax = np.max(freqs) + logger.info('Fmax was not specified. 
Using fmax=max(freqs).') fmin = np.array((fmin,), dtype=float).ravel() fmax = np.array((fmax,), dtype=float).ravel() @@ -308,24 +309,30 @@ def spectral_connectivity_time(data, method='coh', average=False, target_idx = indices_use[1] n_pairs = len(source_idx) - # check cwt_freqs - if cwt_freqs is not None: - # check for single frequency - if isinstance(cwt_freqs, (int, float)): - cwt_freqs = [cwt_freqs] - # array conversion - cwt_freqs = np.asarray(cwt_freqs) - # check order for multiple frequencies - if len(cwt_freqs) >= 2: - delta_f = np.diff(cwt_freqs) - increase = np.all(delta_f > 0) - assert increase, "Frequencies should be in increasing order" - - # compute frequencies to analyze based on number of samples, - # sampling rate, specified wavelet frequencies and mode - freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) - - # compute the mask based on specified min/max and decimation factor + # check freqs + if isinstance(freqs, (int, float)): + freqs = [freqs] + # array conversion + freqs = np.asarray(freqs) + # check order for multiple frequencies + if len(freqs) >= 2: + delta_f = np.diff(freqs) + increase = np.all(delta_f > 0) + assert increase, "Frequencies should be in increasing order" + + # check that freqs corresponds to at least n_cycles cycles + dur = float(n_times) / sfreq + cycle_freq = n_cycles / dur + if np.any(freqs < cycle_freq): + raise ValueError('At least one value in n_cycles corresponds to a' + 'wavelet longer than the signal. 
Use less cycles, ' + 'higher frequencies, or longer epochs.') + # check for Nyquist + if np.any(freqs > sfreq / 2): + raise ValueError(f'Frequencies {freqs[freqs > sfreq / 2]} Hz are ' + f'larger than Nyquist = {sfreq / 2:.2f} Hz') + + # compute frequency mask based on specified min/max and decimation factor freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) # the frequency points where we compute connectivity @@ -357,15 +364,14 @@ def spectral_connectivity_time(data, method='coh', average=False, source_idx=source_idx, target_idx=target_idx, mode=mode, sfreq=sfreq, freqs=freqs, faverage=faverage, n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, - decim=decim, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, + decim=decim, padding=padding, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, verbose=verbose) for epoch_idx in np.arange(n_epochs): - epoch_idx = [epoch_idx] - conn_tr = _spectral_connectivity(data[epoch_idx, ...], **call_params) + logger.info(f' Processing epoch {epoch_idx+1} / {n_epochs} ...') + conn_tr = _spectral_connectivity(data[epoch_idx], **call_params) for m in method: - conn[m][epoch_idx, ...] = np.stack(conn_tr[m], - axis=1).squeeze(axis=-1) + conn[m][epoch_idx] = np.stack(conn_tr[m], axis=0) if indices is None: conn_flat = conn @@ -374,7 +380,7 @@ def spectral_connectivity_time(data, method='coh', average=False, this_conn = np.zeros((n_epochs, n_signals, n_signals) + conn_flat[m].shape[2:], dtype=conn_flat[m].dtype) - this_conn[:, source_idx, target_idx] = conn_flat[m][:, ...] 
+ this_conn[:, source_idx, target_idx] = conn_flat[m] this_conn = this_conn.reshape((n_epochs, n_signals ** 2,) + conn_flat[m].shape[2:]) conn[m] = this_conn @@ -404,13 +410,55 @@ def spectral_connectivity_time(data, method='coh', average=False, def _spectral_connectivity(data, method, kernel, foi_idx, source_idx, target_idx, mode, sfreq, freqs, faverage, n_cycles, - mt_bandwidth, decim, kw_cwt, kw_mt, + mt_bandwidth, decim, padding, kw_cwt, kw_mt, n_jobs, verbose): """Estimate time-resolved connectivity for one epoch. - See spectral_connectivity_epochs.""" - n_pairs = len(source_idx) + Parameters + ---------- + data : array_like, shape (n_channels, n_times) + Time-series data. + method : list of str + List of connectivity metrics to compute. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + source_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``target_idx``. + target_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``source_idx``. + mode : str + Time-frequency transformation method. + sfreq : float + Sampling frequency. + freqs : array_like + Array of frequencies of interest for time-frequency decomposition. + Only the frequencies within the range specified by ``fmin`` and + ``fmax`` are used. + faverage : bool + Average over frequency bands. + n_cycles : float | array_like of float + Number of cycles in the wavelet, either a fixed number or one per + frequency. + mt_bandwidth : float | None + Multitaper time-bandwidth. + decim : int + Decimation factor after time-frequency + decomposition. + padding : float + Amount of time to consider as padding at the beginning and end of each + epoch in seconds. + Returns + ------- + this_conn : list of array + List of connectivity estimates corresponding to the metrics in + ``method``. 
Each element is an array of shape (n_pairs, n_freqs) or + (n_pairs, n_fbands) if ``faverage`` is `True`. + """ + n_pairs = len(source_idx) + data = np.expand_dims(data, axis=0) if mode == 'cwt_morlet': out = tfr_array_morlet( data, sfreq, freqs, n_cycles=n_cycles, output='complex', @@ -438,16 +486,25 @@ def _spectral_connectivity(data, method, kernel, foi_idx, else: raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") + out = np.squeeze(out, axis=0) + + if padding: + if padding < 0: + raise ValueError(f'Padding cannot be negative, got {padding}.') + if padding >= data.shape[-1] / sfreq / 2: + raise ValueError(f'Padding cannot be larger than half of data ' + f'length, got {padding}.') + pad_idx = int(np.floor(padding * sfreq / decim)) + out = out[..., pad_idx:-pad_idx] + weights = weights[..., pad_idx:-pad_idx] if weights is not None \ + else None + # compute for each connectivity method this_conn = {} - conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs, 'pli': _pli, - 'wpli': _wpli} - for m in method: - c_func = conn_func[m] - this_conn[m] = c_func(out, kernel, foi_idx, source_idx, - target_idx, n_jobs=n_jobs, - verbose=verbose, total=n_pairs, - faverage=faverage, weights=weights) + conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, + n_jobs, verbose, n_pairs, faverage, weights) + for i, m in enumerate(method): + this_conn[m] = [out[i] for out in conn] return this_conn @@ -458,141 +515,167 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### ############################################################################### -def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise coherence. +def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, + verbose, total, faverage, weights): + """Compute spectral connectivity in parallel. 
- Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" + Parameters + ---------- + w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) + Time-frequency data (complex signal). + method : list of str + List of connectivity metrics to compute. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + source_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``target_idx``. + target_idx : array_like, shape (n_pairs,) + Defines the signal pairs of interest together with ``source_idx``. + n_jobs : int + Number of parallel jobs. + total : int + Number of pairs of signals. + faverage : bool + Average over frequency bands. + weights : array_like, shape (n_tapers, n_freqs, n_times) + Multitaper weights. - if weights is not None: - psd = weights * w - psd = psd * np.conj(psd) - psd = psd.real.sum(axis=2) - psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) - else: - psd = w.real ** 2 + w.imag ** 2 - psd = np.squeeze(psd, axis=2) - - # smooth the psd - psd = _smooth_spectra(psd, kernel) - - def pairwise_coh(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - s_xx = psd[:, w_x] - s_yy = psd[:, w_y] - out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ - np.sqrt(s_xx.mean(axis=-1, keepdims=True) * - s_yy.mean(axis=-1, keepdims=True)) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) + Returns + ------- + out : array_like, shape (n_pairs, n_methods, n_freqs_out) + Connectivity estimates for each signal pair, method, and frequency or + frequency band. 
+ """ + if 'coh' in method: + # psd + if weights is not None: + psd = weights * w + psd = psd * np.conj(psd) + psd = psd.real.sum(axis=1) + psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) else: - return out - - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_coh, n_jobs=n_jobs, verbose=verbose, total=total) + psd = w.real ** 2 + w.imag ** 2 + psd = np.squeeze(psd, axis=1) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) + # smooth + psd = _smooth_spectra(psd, kernel) + else: + psd = None + # only show progress if verbosity level is DEBUG + if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10: + total = None -def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise phase-locking value. + # define the function to compute in parallel + parallel, my_pairwise_con, n_jobs = parallel_func( + _pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_plv(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - exp_dphi = s_xy / np.abs(s_xy) - exp_dphi = _smooth_spectra(exp_dphi, kernel) - # mean over time - exp_dphi_mean = exp_dphi.mean(axis=-1, keepdims=True) - out = np.abs(exp_dphi_mean) - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out + return parallel( + my_pairwise_con(w, psd, s, t, method, kernel, + foi_idx, faverage, weights) + for s, t in zip(source_idx, target_idx)) - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_plv, n_jobs=n_jobs, verbose=verbose, total=total) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, + faverage, weights): + """Compute spectral connectivity 
metrics between two signals. + Parameters + ---------- + w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) + Time-frequency data. + psd : array_like, shape (n_chans, n_freqs, n_times) + Power spectrum between signals ``x`` and ``y``. + x : int + Channel index. + y : int + Channel index. + method : str + Connectivity method. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + faverage : bool + Average over frequency bands. + weights : array_like, shape (n_tapers, n_freqs, n_times) | None + Multitaper weights. -def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise phase-lag index. + Returns + ------- + out : list + List of connectivity estimates between signals ``x`` and ``y`` + corresponding to the methods in ``method``. Each element is an array + with shape (n_freqs,) or (n_fbands) depending on ``faverage``. 
+ """ + w_x, w_y = w[x], w[y] + if weights is not None: + s_xy = np.sum(weights * w_x * np.conj(weights * w_y), axis=0) + s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0) + else: + s_xy = w_x * np.conj(w_y) + s_xy = np.squeeze(s_xy, axis=0) + s_xy = _smooth_spectra(s_xy, kernel) + out = [] + conn_func = {'plv': _plv, 'ciplv': _ciplv, 'pli': _pli, 'wpli': _wpli, + 'coh': _coh} + for m in method: + if m == 'coh': + s_xx = psd[x] + s_yy = psd[y] + out.append(conn_func[m](s_xx, s_yy, s_xy)) + else: + out.append(conn_func[m](s_xy)) - Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_pli(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - out = np.abs(np.mean(np.sign(np.imag(s_xy)), - axis=-1, keepdims=True)) + for i, _ in enumerate(out): # mean inside frequency sliding window (if needed) if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out + out[i] = _foi_average(out[i], foi_idx) + # squeeze time dimension + out[i] = out[i].squeeze(axis=-1) - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_pli, n_jobs=n_jobs, verbose=verbose, total=total) + return out - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _plv(s_xy): + s_xy = s_xy / np.abs(s_xy) + plv = np.abs(s_xy.mean(axis=-1, keepdims=True)) + return plv -def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise weighted phase-lag index. 
- Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, - n_times).""" - def pairwise_wpli(w_x, w_y): - s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) - s_xy = _smooth_spectra(s_xy, kernel) - con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) - con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) - out = con_num / con_den - # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out +def _ciplv(s_xy): + s_xy = s_xy / np.abs(s_xy) + rplv = np.abs(np.mean(np.real(s_xy), axis=-1, keepdims=True)) + iplv = np.abs(np.mean(np.imag(s_xy), axis=-1, keepdims=True)) + ciplv = iplv / (np.sqrt(1 - rplv ** 2)) + return ciplv - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_wpli, n_jobs=n_jobs, verbose=verbose, total=total) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _pli(s_xy): + pli = np.abs(np.mean(np.sign(np.imag(s_xy)), + axis=-1, keepdims=True)) + return pli -def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage, weights): - """Pairwise cross-spectra.""" - def pairwise_cs(w_x, w_y): - out = _compute_csd(w[:, w_y], w[:, w_x], weights) - out = _smooth_spectra(out, kernel) - if isinstance(foi_idx, np.ndarray) and faverage: - return _foi_average(out, foi_idx) - else: - return out +def _wpli(s_xy): + con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) + con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) + wpli = con_num / con_den + return wpli - # define the function to compute in parallel - parallel, p_fun, n_jobs = parallel_func( - pairwise_cs, n_jobs=n_jobs, verbose=verbose, total=total) - return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _coh(s_xx, s_yy, s_xy): + con_num = np.abs(s_xy.mean(axis=-1, keepdims=True)) + con_den = np.sqrt(s_xx.mean(axis=-1, keepdims=True) * + s_yy.mean(axis=-1, 
keepdims=True)) + coh = con_num / con_den + return coh def _compute_csd(x, y, weights): - """Compute cross spectral density of signals x and y.""" + """Compute cross spectral density between signals x and y.""" if weights is not None: s_xy = np.sum(weights * x * np.conj(weights * y), axis=-3) s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=-3) @@ -609,15 +692,15 @@ def _foi_average(conn, foi_idx): Parameters ---------- - conn : np.ndarray - Array of shape (..., n_freqs, n_times) - foi_idx : array_like - Array of indices describing frequency bounds of shape (n_foi, 2) + conn : array_like, shape (..., n_freqs, n_times) + Connectivity estimate array. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower frequency bounds of each frequency band. Returns ------- - conn_f : np.ndarray - Array of shape (..., n_foi, n_times) + conn_f : np.ndarray, shape (..., n_fbands, n_times) + Connectivity estimate array, averaged within frequency bands. """ # get the number of foi n_foi = foi_idx.shape[0] diff --git a/requirements.txt b/requirements.txt index 067ea5c7..a564429a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy scipy -mne>=1.1 +mne>=1.3 xarray netCDF4 h5netcdf