From 41bd87b97e5b6edee61fbdf72edebbd8265f7843 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 31 Aug 2022 12:14:58 +0300 Subject: [PATCH 01/47] FIX: compute connectivity over multiple tapers when mode='multitaper', then average FIX: add working support for computation of multiple connectivity metrics at once, as indicated by existing docstring FIX: correct calculation of PLV and coherence connectivity metrics FIX: block_size parameter now actually corresponds to the size of blocks instead of number of blocks ENH: add PLI and wPLI connectivity metrics ENH: improve docstring and typechecks in code ENH: streamline the public API with mne_connectivity.spectral_connectivity_epochs ENH: enable averaging connectivity results over frequencies and epochs --- mne_connectivity/spectral/time.py | 392 ++++++++++++++++++++---------- 1 file changed, 263 insertions(+), 129 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index f1a53bce..4ffb77eb 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -1,29 +1,32 @@ # Authors: Adam Li +# Santeri Ruuskanen # # License: BSD (3-clause) import numpy as np import xarray as xr +from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper) from mne.utils import logger -from ..base import (EpochSpectroTemporalConnectivity) +from ..base import (SpectralConnectivity, EpochSpectralConnectivity) +from .epochs import _compute_freqs, _compute_freq_mask from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, fill_doc @fill_doc -def spectral_connectivity_time(data, names=None, method='coh', indices=None, - sfreq=2 * np.pi, foi=None, sm_times=.5, +def spectral_connectivity_time(data, names=None, method='coh', average=False, + indices=None, sfreq=2 * np.pi, fmin=None, + fmax=None, fskip=0, faverage=False, sm_times=.5, sm_freqs=1, sm_kernel='hanning', mode='cwt_morlet', mt_bandwidth=None, - freqs=None, n_cycles=7, decim=1, - block_size=None, n_jobs=1, - verbose=None): + cwt_freqs=None, n_cycles=7, decim=1, + block_size=1000, n_jobs=1, verbose=None): """Compute frequency- and time-frequency-domain connectivity measures. - This method computes single-Epoch time-resolved spectral connectivity. + This method computes time-resolved connectivity measures for Epochs. The connectivity method(s) are specified using the "method" parameter. All methods are based on estimates of the cross- and power spectral @@ -31,7 +34,7 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, Parameters ---------- - data : Epochs + data : array_like, shape (n_epochs, n_signals, n_times) | Epochs The data from which to compute connectivity. %(names)s method : str | list of str @@ -43,16 +46,31 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, * 'sxy' : Cross-spectrum By default, the coherence is used. + average : bool + Average connectivity scores over Epochs. If True, output will be + an instance of ``SpectralConnectivity`` , otherwise + ``EpochSpectralConnectivity``. indices : tuple of array | None Two arrays with indices of connections for which to compute connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. If None, all connections are computed. sfreq : float The sampling frequency. - foi : array_like | None - Extract frequencies of interest. This parameters should be an array of - shapes (n_foi, 2) defining where each band of interest start and - finish. + fmin : float | tuple of float + The lower frequency of interest. Multiple bands are defined using + a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. + If None the frequency corresponding to an epoch length of 5 cycles + is used. + fmax : float | tuple of float + The upper frequency of interest. Multiple bands are dedined using + a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. + fskip : int + Omit every "(fskip + 1)-th" frequency bin to decimate in frequency + domain. + faverage : bool + Average connectivity scores for each frequency band. If True, + the output freqs will be a list with arrays of the frequencies + that were averaged. sm_times : float Amount of time to consider for the temporal smoothing in seconds. By default, 0.5 sec smoothing is used. @@ -67,9 +85,9 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, mt_bandwidth : float | None The bandwidth of the multitaper windowing function in Hz. Only used in 'multitaper' mode. - freqs : array - Array of frequencies of interest for use in time-frequency - decomposition method (specified by ``mode``). + cwt_freqs : array + Array of frequencies of interest for time-frequency decomposition. + Only used in 'cwt_morlet' mode. n_cycles : float | array of float Number of cycles for use in time-frequency decomposition method (specified by ``mode``). Fixed number or one per frequency. @@ -78,7 +96,7 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, decomposition. default 1 If int, returns tfr[…, ::decim]. If slice, returns tfr[…, decim]. block_size : int - How many connections to compute at once (higher numbers are faster + How many epochs to compute at once (higher numbers are faster but require more memory). n_jobs : int How many epochs to process in parallel. @@ -86,16 +104,15 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, Returns ------- - con : array | instance of Connectivity - Computed connectivity measure(s). Either an instance of - ``SpectralConnectivity`` or ``SpectroTemporalConnectivity``. - The shape of each connectivity dataset is either - (n_signals ** 2, n_freqs) mode: 'multitaper' or 'fourier' - (n_signals ** 2, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is None, or - (n_con, n_freqs) mode: 'multitaper' or 'fourier' - (n_con, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is specified and "n_con = len(indices[0])". + con : instance of Connectivity | list + Computed connectivity measure(s). An instance of + ``EpochSpectralConnectivity``, ``SpectralConnectivity`` + or a list of instances corresponding to connectivity measures if + several connectivity measures are specified. + The shape of each connectivity dataset is + (n_epochs, n_signals, n_signals, n_freqs) when indices is None + and (n_epochs, n_nodes, n_nodes, n_freqs) when "indices" is specified + and "n_nodes = len(indices[0])". See Also -------- @@ -113,25 +130,47 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, events = None event_id = None # extract data from Epochs object - names = data.ch_names - times = data.times # input times for Epochs input type - sfreq = data.info['sfreq'] - events = data.events - event_id = data.event_id - n_epochs, n_signals, n_times = data.get_data().shape - # Extract metadata from the Epochs data structure. - # Make Annotations persist through by adding them to the metadata. - metadata = data.metadata - if metadata is None: - annots_in_metadata = False + if isinstance(data, BaseEpochs): + names = data.ch_names + times = data.times # input times for Epochs input type + sfreq = data.info['sfreq'] + events = data.events + event_id = data.event_id + n_epochs, n_signals, n_times = data.get_data().shape + # Extract metadata from the Epochs data structure. + # Make Annotations persist through by adding them to the metadata. + metadata = data.metadata + if metadata is None: + annots_in_metadata = False + else: + annots_in_metadata = all( + name not in metadata.columns for name in [ + 'annot_onset', 'annot_duration', 'annot_description']) + if hasattr(data, 'annotations') and not annots_in_metadata: + data.add_annotations_to_metadata(overwrite=True) + metadata = data.metadata + data = data.get_data() else: - annots_in_metadata = all( - name not in metadata.columns for name in [ - 'annot_onset', 'annot_duration', 'annot_description']) - if hasattr(data, 'annotations') and not annots_in_metadata: - data.add_annotations_to_metadata(overwrite=True) - metadata = data.metadata - data = data.get_data() + data = np.asarray(data) + n_epochs, n_signals, n_times = data.shape + times = np.arange(0, n_times) + names = np.arange(0, n_signals) + metadata = None + + # check that method is a list + if isinstance(method, str): + method = [method] + # check that fmin and fmax are lists + if fmin is None: + fmin = 1 + if fmax is None: + fmax = sfreq / 2 + fmin = np.array((fmin,), dtype=float).ravel() + fmax = np.array((fmax,), dtype=float).ravel() + if len(fmin) != len(fmax): + raise ValueError('fmin and fmax must have the same length') + if np.any(fmin > fmax): + raise ValueError('fmax must be larger than fmin') # convert kernel width in time to samples if isinstance(sm_times, (int, float)): @@ -166,48 +205,71 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, n_pairs = len(source_idx) # frequency checking - if freqs is not None: + if cwt_freqs is not None: # check for single frequency - if isinstance(freqs, (int, float)): - freqs = [freqs] + if isinstance(cwt_freqs, (int, float)): + cwt_freqs = [cwt_freqs] # array conversion - freqs = np.asarray(freqs) + cwt_freqs = np.asarray(cwt_freqs) # check order for multiple frequencies - if len(freqs) >= 2: - delta_f = np.diff(freqs) + if len(cwt_freqs) >= 2: + delta_f = np.diff(cwt_freqs) increase = np.all(delta_f > 0) assert increase, "Frequencies should be in increasing order" - # frequency mean - if foi is None: - foi_idx = foi_s = foi_e = None - f_vec = freqs - else: - _f = xr.DataArray(np.arange(len(freqs)), dims=('freqs',), - coords=(freqs,)) - foi_s = _f.sel(freqs=foi[:, 0], method='nearest').data - foi_e = _f.sel(freqs=foi[:, 1], method='nearest').data - foi_idx = np.c_[foi_s, foi_e] - f_vec = freqs[foi_idx].mean(1) + # compute frequencies to analyze based on number of samples, + # sampling rate, specified wavelet frequencies and mode + freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) + + if fmin is not None and fmax is not None: + # compute the mask based on specified min/max and decimation factor + freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) + + # the frequency points where we compute connectivity + freqs = freqs[freq_mask] + + # frequency mean + if fmin is None or fmax is None: + foi_idx = None + f_vec = freqs + else: + _f = xr.DataArray(np.arange(len(freqs)), dims=('freqs',), + coords=(freqs,)) + foi_s = _f.sel(freqs=fmin, method='nearest').data + foi_e = _f.sel(freqs=fmax, method='nearest').data + foi_idx = np.c_[foi_s, foi_e] + f_vec = freqs[foi_idx].mean(1) + + if faverage: + n_freqs = len(fmin) + out_freqs = f_vec + else: + n_freqs = len(freqs) + out_freqs = freqs # build block size indices + if block_size > n_epochs: + block_size = n_epochs + if isinstance(block_size, int) and (block_size > 1): - blocks = np.array_split(np.arange(n_epochs), block_size) + n_blocks = n_epochs // block_size + n_epochs % block_size + blocks = np.array_split(np.arange(n_epochs), n_blocks) else: blocks = [np.arange(n_epochs)] - n_freqs = len(f_vec) + # compute connectivity on blocks of trials + conn = {} + for m in method: + conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) - # compute coherence on blocks of trials - conn = np.zeros((n_epochs, n_pairs, n_freqs, len(times))) logger.info('Connectivity computation...') # parameters to pass to the connectivity function call_params = dict( method=method, kernel=kernel, foi_idx=foi_idx, source_idx=source_idx, target_idx=target_idx, - mode=mode, sfreq=sfreq, freqs=freqs, n_cycles=n_cycles, - mt_bandwidth=mt_bandwidth, + mode=mode, sfreq=sfreq, freqs=freqs, faverage=faverage, + n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, decim=decim, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, verbose=verbose) @@ -216,70 +278,71 @@ def spectral_connectivity_time(data, names=None, method='coh', indices=None, conn_tr = _spectral_connectivity(data[epoch_idx, ...], **call_params) # merge results - conn[epoch_idx, ...] = np.stack(conn_tr, axis=1) + for m in method: + conn[m][epoch_idx, ...] = np.stack(conn_tr[m], + axis=1).squeeze(axis=-1) # create a Connectivity container indices = 'symmetric' - conn = EpochSpectroTemporalConnectivity( - conn, freqs=f_vec, times=times, - n_nodes=n_signals, names=names, indices=indices, method=method, - spec_method=mode, events=events, event_id=event_id, metadata=metadata) - return conn + if average: + out = [SpectralConnectivity( + conn[m].mean(axis=0), freqs=out_freqs, n_nodes=n_signals, + names=names, indices=indices, method=method, spec_method=mode, + events=events, event_id=event_id, metadata=metadata) + for m in method] + else: + out = [EpochSpectralConnectivity( + conn[m], freqs=out_freqs, n_nodes=n_signals, names=names, + indices=indices, method=method, spec_method=mode, events=events, + event_id=event_id, metadata=metadata) for m in method] + + # return the object instead of list of length one + if len(out) == 1: + return out[0] + else: + return out def _spectral_connectivity(data, method, kernel, foi_idx, source_idx, target_idx, - mode, sfreq, freqs, n_cycles, mt_bandwidth=None, - decim=1, kw_cwt={}, kw_mt={}, n_jobs=1, - verbose=False): - """EStimate time-resolved connectivity for one epoch. + mode, sfreq, freqs, faverage, n_cycles, + mt_bandwidth=None, decim=1, kw_cwt={}, kw_mt={}, + n_jobs=1, verbose=False): + """Estimate time-resolved connectivity for one epoch. See spectral_connectivity_epoch.""" n_pairs = len(source_idx) # first compute time-frequency decomposition - collapse = None if mode == 'cwt_morlet': out = tfr_array_morlet( data, sfreq, freqs, n_cycles=n_cycles, output='complex', decim=decim, n_jobs=n_jobs, **kw_cwt) + out = np.expand_dims(out, axis=2) # same dims with multitaper elif mode == 'multitaper': - # In case multiple values are provided for mt_bandwidth - # the MT decomposition is done separatedly for each - # Frequency center - if isinstance(mt_bandwidth, (list, tuple, np.ndarray)): - # Arrays freqs, n_cycles, mt_bandwidth should have the same size - assert len(freqs) == len(n_cycles) == len(mt_bandwidth) - out = [] - for f_c, n_c, mt in zip(freqs, n_cycles, mt_bandwidth): - out += [tfr_array_multitaper( - data, sfreq, [f_c], n_cycles=float(n_c), time_bandwidth=mt, - output='complex', decim=decim, n_jobs=n_jobs, **kw_mt)] - out = np.stack(out, axis=3).squeeze() - elif isinstance(mt_bandwidth, (type(None), int, float)): - out = tfr_array_multitaper( - data, sfreq, freqs, n_cycles=n_cycles, - time_bandwidth=mt_bandwidth, output='complex', decim=decim, - n_jobs=n_jobs, **kw_mt) - collapse = True - if out.ndim == 5: # newest MNE-Python - collapse = -3 - - # get the supported connectivity function - conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs}[method] - - # computes conn across trials - # TODO: This is wrong -- it averages in the complex domain (over tapers). - # What it *should* do is compute the conn for each taper, then average - # (see below). - if collapse is not None: - out = np.mean(out, axis=collapse) - this_conn = conn_func(out, kernel, foi_idx, source_idx, target_idx, - n_jobs=n_jobs, verbose=verbose, total=n_pairs) - # This is where it should go, but the regression test fails... - # if collapse is not None: - # this_conn = [c.mean(axis=collapse) for c in this_conn] + print(data.shape) + out = tfr_array_multitaper( + data, sfreq, freqs, n_cycles=n_cycles, + time_bandwidth=mt_bandwidth, output='complex', decim=decim, + n_jobs=n_jobs, **kw_mt) + else: + raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") + + # compute for each required connectivity method + this_conn = {} + conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs, 'pli': _pli, + 'wpli': _wpli} + for m in method: + c_func = conn_func[m] + # compute connectivity + this_conn[m] = c_func(out, kernel, foi_idx, source_idx, + target_idx, n_jobs=n_jobs, + verbose=verbose, total=n_pairs, + faverage=faverage) + # mean over tapers + this_conn[m] = [c.mean(axis=1) for c in this_conn[m]] + return this_conn @@ -289,24 +352,29 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### ############################################################################### -def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total): - """Pairwise coherence.""" +def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage): + """Pairwise coherence. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" # auto spectra (faster that w * w.conj()) s_auto = w.real ** 2 + w.imag ** 2 # smooth the auto spectra s_auto = _smooth_spectra(s_auto, kernel) - # define the pairwise coherence def pairwise_coh(w_x, w_y): - # computes the coherence + # compute coherence s_xy = w[:, w_y] * np.conj(w[:, w_x]) s_xy = _smooth_spectra(s_xy, kernel) s_xx = s_auto[:, w_x] s_yy = s_auto[:, w_y] - out = np.abs(s_xy) ** 2 / (s_xx * s_yy) + out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ + np.sqrt(s_xx.mean(axis=-1, keepdims=True) * + s_yy.mean(axis=-1, keepdims=True)) # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray): + if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) else: return out @@ -315,24 +383,29 @@ def pairwise_coh(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_coh, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial coherence + # compute pairwise coherence coherence return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) -def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total): - """Pairwise phase-locking value.""" +def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage): + """Pairwise phase-locking value. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" # define the pairwise plv def pairwise_plv(w_x, w_y): - # computes the plv + # compute plv s_xy = w[:, w_y] * np.conj(w[:, w_x]) # complex exponential of phase differences exp_dphi = s_xy / np.abs(s_xy) # smooth e^(-i*\delta\phi) exp_dphi = _smooth_spectra(exp_dphi, kernel) - # computes plv - out = np.abs(exp_dphi) + # mean over samples (time axis) + exp_dphi_mean = exp_dphi.mean(axis=-1, keepdims=True) + out = np.abs(exp_dphi_mean) # mean inside frequency sliding window (if needed) - if isinstance(foi_idx, np.ndarray): + if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) else: return out @@ -341,18 +414,79 @@ def pairwise_plv(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_plv, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial coherence + # compute the single trial plv + return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) + + +def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage): + """Pairwise phase-lag index. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" + # define the pairwise pli + def pairwise_pli(w_x, w_y): + # compute cross spectrum + s_xy = w[:, w_y] * np.conj(w[:, w_x]) + # smooth e^(-i*\delta\phi) + s_xy = _smooth_spectra(s_xy, kernel) + # phase lag index + out = np.abs(np.mean(np.sign(np.imag(s_xy)), + axis=-1, keepdims=True)) + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) and faverage: + return _foi_average(out, foi_idx) + else: + return out + + # define the function to compute in parallel + parallel, p_fun, n_jobs = parallel_func( + pairwise_pli, n_jobs=n_jobs, verbose=verbose, total=total) + + # compute the single trial pli + return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) + + +def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage): + """Pairwise weighted phase-lag index. + + Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, + n_times).""" + # define the pairwise wpli + def pairwise_wpli(w_x, w_y): + # compute cross spectrum + s_xy = w[:, w_y] * np.conj(w[:, w_x]) + # smooth + s_xy = _smooth_spectra(s_xy, kernel) + # magnitude of the mean of the imaginary part of the cross spectrum + s_xy_mean_abs = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) + # mean of the magnitudes of the imaginary part of the cross spectrum + s_xy_abs_mean = np.abs(s_xy.imag).mean(axis=-1, keepdims=True) + out = s_xy_mean_abs / s_xy_abs_mean + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) and faverage: + return _foi_average(out, foi_idx) + else: + return out + + # define the function to compute in parallel + parallel, p_fun, n_jobs = parallel_func( + pairwise_wpli, n_jobs=n_jobs, verbose=verbose, total=total) + + # compute the single trial wpli return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) -def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total): +def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, + faverage): """Pairwise cross-spectra.""" # define the pairwise cross-spectra def pairwise_cs(w_x, w_y): # computes the cross-spectra out = w[:, w_x] * np.conj(w[:, w_y]) out = _smooth_spectra(out, kernel) - if foi_idx is not None: + if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) else: return out From 05ae8f63d060cdc09bf2461db17aedf598699d38 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 31 Aug 2022 12:16:11 +0300 Subject: [PATCH 02/47] require MNE-Python 1.0 or newer due to breaking changes in mne.time_frequency.tfr_array_multitaper --- environment.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index 3fe4ed24..8fb68645 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: mne +name: mne-connectivity channels: - conda-forge dependencies: @@ -20,5 +20,5 @@ dependencies: - pyvista>=0.32 - pyvistaqt>=0.4 - pyqt!=5.15.3 -- mne +- mne>=1.0 - h5netcdf From 9cc283c13939282bb01833b1794a51f117a617f8 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 31 Aug 2022 12:17:20 +0300 Subject: [PATCH 03/47] update tests corresponding to API changes and use longer test signal --- .../spectral/tests/test_spectral.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 28d69128..e8167b9b 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -486,7 +486,7 @@ def test_spectral_connectivity_time_resolved(method, mode): sfreq = 50. n_signals = 3 n_epochs = 2 - n_times = 256 + n_times = 500 trans_bandwidth = 2. tmin = 0. tmax = (n_times - 1) / sfreq @@ -506,18 +506,17 @@ def test_spectral_connectivity_time_resolved(method, mode): # run connectivity estimation con = spectral_connectivity_time( - data, freqs=freqs, method=method, mode=mode) - assert con.shape == (n_epochs, n_signals * 2, n_freqs, n_times) + data, cwt_freqs=freqs, method=method, mode=mode) + assert con.shape == (n_epochs, n_signals * 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ - (n_epochs, n_signals, n_signals, n_freqs, n_times) - - # average over time - conn_data = con.get_data(output='dense').mean(axis=-1) - conn_data = conn_data.mean(axis=-1) + (n_epochs, n_signals, n_signals, len(con.freqs)) # test the simulated signal triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T + # average over frequencies + conn_data = con.get_data(output='dense').mean(axis=-1) + # the indices at which there is a correlation should be greater # then the rest of the components for epoch_idx in range(n_epochs): @@ -538,7 +537,7 @@ def test_time_resolved_spectral_conn_regression(method, mode): test_file_path_str = str(_resource_path( 'mne_connectivity.tests', f'data/test_frite_dataset_{mode}_{method}.npy')) - test_conn = np.load(test_file_path_str) + test_conn = np.load(test_file_path_str).mean(axis=-1) # paths to mne datasets - sample ECoG bids_root = mne.datasets.epilepsy_ecog.data_path() @@ -581,7 +580,7 @@ def test_time_resolved_spectral_conn_regression(method, mode): if mode == 'morlet': mode = 'cwt_morlet' conn = spectral_connectivity_time( - epochs, freqs=freqs, n_jobs=1, method=method, mode=mode) + epochs, cwt_freqs=freqs, n_jobs=1, method=method, mode=mode) # frites only stores the upper triangular parts of the raveled array row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1) From 924b4abb50cae30fbf22208dafafd69f8bafed2b Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Thu, 1 Sep 2022 10:25:23 +0300 Subject: [PATCH 04/47] update docstring: connectivity is not averaged over Epochs by default Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 4ffb77eb..05a85fcb 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -49,7 +49,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, average : bool Average connectivity scores over Epochs. If True, output will be an instance of ``SpectralConnectivity`` , otherwise - ``EpochSpectralConnectivity``. + ``EpochSpectralConnectivity``. By default False. indices : tuple of array | None Two arrays with indices of connections for which to compute connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. From c5b75f8c014f222b7bc4b8d6866ab1f62bbfb35a Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Thu, 1 Sep 2022 10:27:09 +0300 Subject: [PATCH 05/47] fix docstring typo Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 05a85fcb..8e22f8bb 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -62,7 +62,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, If None the frequency corresponding to an epoch length of 5 cycles is used. fmax : float | tuple of float - The upper frequency of interest. Multiple bands are dedined using + The upper frequency of interest. Multiple bands are defined using a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. fskip : int Omit every "(fskip + 1)-th" frequency bin to decimate in frequency From 7699983f563fe2b1416a9edfc4851137ec708f89 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Thu, 1 Sep 2022 10:32:35 +0300 Subject: [PATCH 06/47] update docstring: faverage is False by default Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 8e22f8bb..25270b5b 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -70,7 +70,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, faverage : bool Average connectivity scores for each frequency band. If True, the output freqs will be a list with arrays of the frequencies - that were averaged. + that were averaged. By default False. sm_times : float Amount of time to consider for the temporal smoothing in seconds. By default, 0.5 sec smoothing is used. From 89f683a600a8907405038401c3694f40677af6bc Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Thu, 1 Sep 2022 11:13:04 +0300 Subject: [PATCH 07/47] update docstring: lower bound of frequency range for connectivity computation may be None Co-authored-by: Britta Westner --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 25270b5b..66455afc 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -56,7 +56,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, If None, all connections are computed. sfreq : float The sampling frequency. - fmin : float | tuple of float + fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. If None the frequency corresponding to an epoch length of 5 cycles From cb2dab5a035266573fdbea4fac62fa360f38afde Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Thu, 1 Sep 2022 11:20:08 +0300 Subject: [PATCH 08/47] update docstring: fmax may be None Co-authored-by: Britta Westner --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 66455afc..0fa03f02 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -61,7 +61,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. If None the frequency corresponding to an epoch length of 5 cycles is used. - fmax : float | tuple of float + fmax : float | tuple of float | None The upper frequency of interest. Multiple bands are defined using a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. fskip : int From 0fd34c9238ca88e72b02649260609fdf0f70181e Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 2 Sep 2022 10:26:05 +0300 Subject: [PATCH 09/47] update docstring --- mne_connectivity/spectral/time.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 4ffb77eb..690b52ed 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -26,7 +26,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, block_size=1000, n_jobs=1, verbose=None): """Compute frequency- and time-frequency-domain connectivity measures. - This method computes time-resolved connectivity measures for Epochs. + This method computes time-resolved connectivity measures from epoched data. The connectivity method(s) are specified using the "method" parameter. All methods are based on estimates of the cross- and power spectral @@ -39,15 +39,17 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'plv', - 'sxy']``. These are: + 'sxy', 'pli', 'wpli']``. These are: * 'coh' : Coherence * 'plv' : Phase-Locking Value (PLV) * 'sxy' : Cross-spectrum + * 'pli' : Phase-Lag Index + * 'wpli': Weighted Phase-Lag Index - By default, the coherence is used. + By default, coherence is used. average : bool - Average connectivity scores over Epochs. If True, output will be + Average connectivity scores over epochs. If True, output will be an instance of ``SpectralConnectivity`` , otherwise ``EpochSpectralConnectivity``. indices : tuple of array | None From f4992a7f7556626380ef51f0840e37b3fb38cceb Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 2 Sep 2022 10:27:11 +0300 Subject: [PATCH 10/47] new default for fmin --- mne_connectivity/spectral/time.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 690b52ed..94311c58 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -8,7 +8,7 @@ from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper) -from mne.utils import logger +from mne.utils import (logger, warn) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) from .epochs import _compute_freqs, _compute_freq_mask @@ -162,9 +162,20 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # check that method is a list if isinstance(method, str): method = [method] - # check that fmin and fmax are lists + + # check that fmin corresponds to at least 5 cycles + dur = float(n_times) / sfreq + five_cycle_freq = 5. / dur if fmin is None: - fmin = 1 + # we use the 5 cycle freq. as default + fmin = five_cycle_freq + else: + if np.any(fmin < five_cycle_freq): + warn('fmin=%0.3f Hz corresponds to %0.3f < 5 cycles ' + 'based on the epoch length %0.3f sec, need at least %0.3f ' + 'sec epochs or fmin=%0.3f. Spectrum estimate will be ' + 'unreliable.' % (np.min(fmin), dur * np.min(fmin), dur, + 5. / np.min(fmin), five_cycle_freq)) if fmax is None: fmax = sfreq / 2 fmin = np.array((fmin,), dtype=float).ravel() From fa9bce7d6ca902488070459e62b5ed8884153b22 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Fri, 2 Sep 2022 15:54:23 +0300 Subject: [PATCH 11/47] improve docstring and warnings for spectral_connectivity_time --- mne_connectivity/spectral/time.py | 139 +++++++++++++++++++++++------- 1 file changed, 108 insertions(+), 31 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 2750ef8e..98332692 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -57,15 +57,16 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. If None, all connections are computed. sfreq : float - The sampling frequency. + The sampling frequency. Should be specified if data is not ``Epochs``. fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - If None the frequency corresponding to an epoch length of 5 cycles + If None, the frequency corresponding to an epoch length of 5 cycles is used. fmax : float | tuple of float | None The upper frequency of interest. Multiple bands are defined using a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. + If None, sfreq/2 is used. fskip : int Omit every "(fskip + 1)-th" frequency bin to decimate in frequency domain. @@ -74,34 +75,40 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, the output freqs will be a list with arrays of the frequencies that were averaged. By default False. sm_times : float - Amount of time to consider for the temporal smoothing in seconds. By - default, 0.5 sec smoothing is used. + Amount of time to consider for the temporal smoothing in seconds. + If zero, no temporal smoothing is applied. By default, + 0.5 sec smoothing is used. sm_freqs : int Number of points for frequency smoothing. By default, 1 is used which is equivalent to no smoothing. sm_kernel : {'square', 'hanning'} - Kernel type to use. Choose either 'square' or 'hanning' (default). - mode : str, optional - Spectrum estimation mode can be either: 'multitaper', or - 'cwt_morlet'. + Smoothing kernel type. Choose either 'square' or 'hanning' (default). + mode : str + Time-frequency decomposition method. Can be either: 'multitaper', or + 'cwt_morlet'. See ``tfr_array_multitaper`` and ``tfr_array_wavelet`` + for reference. mt_bandwidth : float | None - The bandwidth of the multitaper windowing function in Hz. - Only used in 'multitaper' mode. + Multitaper time bandwidth. If None, will be set to 4.0 (3 tapers). + Time x (Full) Bandwidth product. The number of good tapers (low-bias) + is chosen automatically based on this to equal + floor(time_bandwidth - 1). By default None. cwt_freqs : array Array of frequencies of interest for time-frequency decomposition. - Only used in 'cwt_morlet' mode. + Only used in 'cwt_morlet' mode. Only the frequencies within + the range specified by fmin and fmax are used. Must be specified if + `mode='cwt_morlet'`. Not used when `mode='multitaper'`. n_cycles : float | array of float - Number of cycles for use in time-frequency decomposition method + Number of wavelet cycles for use in time-frequency decomposition method (specified by ``mode``). Fixed number or one per frequency. decim : int | 1 To reduce memory usage, decimation factor after time-frequency decomposition. default 1 If int, returns tfr[…, ::decim]. If slice, returns tfr[…, decim]. block_size : int - How many epochs to compute at once (higher numbers are faster + Number of epochs to compute at once (higher numbers are faster but require more memory). n_jobs : int - How many epochs to process in parallel. + Number of connections to compute in parallel. %(verbose)s Returns @@ -120,14 +127,79 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, -------- mne_connectivity.spectral_connectivity_epochs mne_connectivity.SpectralConnectivity - mne_connectivity.SpectroTemporalConnectivity + mne_connectivity.EpochSpectralConnectivity Notes ----- + + Please note that the interpretation of the measures in this function + depends on the data and underlying assumptions and does not necessarily + reflect a causal relationship between brain regions. + + The connectivity measures are computed over time within each epoch and + optionally averaged over epochs. High connectivity values indicate that + the phase differences between signals stay consistent over time. + + The spectral densities can be estimated using a multitaper method with + digital prolate spheroidal sequence (DPSS) windows, or a continuous wavelet + transform using Morlet wavelets. The spectral estimation mode is specified + using the "mode" parameter. + + By default, the connectivity between all signals is computed (only + connections corresponding to the lower-triangular part of the + connectivity matrix). If one is only interested in the connectivity + between some signals, the "indices" parameter can be used. For example, + to compute the connectivity between the signal with index 0 and signals + "2, 3, 4" (a total of 3 connections) one can use the following:: + + indices = (np.array([0, 0, 0]), # row indices + np.array([2, 3, 4])) # col indices + + con = spectral_connectivity_time(data, method='coh', + indices=indices, ...) + + In this case con.get_data().shape = (3, n_freqs). The connectivity scores + are in the same order as defined indices. + + **Supported Connectivity Measures** + + The connectivity method(s) is specified using the "method" parameter. The + following methods are supported (note: ``E[]`` denotes average over + epochs). Multiple measures can be computed at once by using a list/tuple, + e.g., ``['coh', 'pli']`` to compute coherence and PLI. + + 'coh' : Coherence given by:: + + | E[Sxy] | + C = --------------------- + sqrt(E[Sxx] * E[Syy]) + + 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given + by:: + + PLV = |E[Sxy/|Sxy|]| + + 'sxy' : Cross spectrum Sxy + + 'pli' : Phase Lag Index (PLI) :footcite:`StamEtAl2007` given by:: + + PLI = |E[sign(Im(Sxy))]| + + 'wpli' : Weighted Phase Lag Index (WPLI) :footcite:`VinckEtAl2011` + given by:: + + |E[Im(Sxy)]| + WPLI = ------------------ + E[|Im(Sxy)|] + This function was originally implemented in ``frites`` and was ported over. .. versionadded:: 0.3 + + References + ---------- + .. footbibliography:: """ events = None event_id = None @@ -158,6 +230,10 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, times = np.arange(0, n_times) names = np.arange(0, n_signals) metadata = None + if sfreq is None: + warn("Sampling frequency (sfreq) was not specified and could not " + "be inferred from data. Using default value 2*numpy.pi. " + "Connectivity results might not be interpretable.") # check that method is a list if isinstance(method, str): @@ -169,6 +245,8 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, if fmin is None: # we use the 5 cycle freq. as default fmin = five_cycle_freq + logger.info(f'Fmin was not specified. Using fmin={fmin:.2f}, which ' + 'corresponds to at least five cycles.') else: if np.any(fmin < five_cycle_freq): warn('fmin=%0.3f Hz corresponds to %0.3f < 5 cycles ' @@ -178,6 +256,9 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, 5. / np.min(fmin), five_cycle_freq)) if fmax is None: fmax = sfreq / 2 + logger.info(f'Fmax was not specified. Using fmax={fmax:.2f}, which ' + f'corresponds to Nyquist.') + fmin = np.array((fmin,), dtype=float).ravel() fmax = np.array((fmax,), dtype=float).ravel() if len(fmin) != len(fmax): @@ -234,24 +315,19 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # sampling rate, specified wavelet frequencies and mode freqs = _compute_freqs(n_times, sfreq, cwt_freqs, mode) - if fmin is not None and fmax is not None: - # compute the mask based on specified min/max and decimation factor - freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) + # compute the mask based on specified min/max and decimation factor + freq_mask = _compute_freq_mask(freqs, fmin, fmax, fskip) - # the frequency points where we compute connectivity - freqs = freqs[freq_mask] + # the frequency points where we compute connectivity + freqs = freqs[freq_mask] # frequency mean - if fmin is None or fmax is None: - foi_idx = None - f_vec = freqs - else: - _f = xr.DataArray(np.arange(len(freqs)), dims=('freqs',), - coords=(freqs,)) - foi_s = _f.sel(freqs=fmin, method='nearest').data - foi_e = _f.sel(freqs=fmax, method='nearest').data - foi_idx = np.c_[foi_s, foi_e] - f_vec = freqs[foi_idx].mean(1) + _f = xr.DataArray(np.arange(len(freqs)), dims=('freqs',), + coords=(freqs,)) + foi_s = _f.sel(freqs=fmin, method='nearest').data + foi_e = _f.sel(freqs=fmax, method='nearest').data + foi_idx = np.c_[foi_s, foi_e] + f_vec = freqs[foi_idx].mean(1) if faverage: n_freqs = len(fmin) @@ -297,7 +373,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # create a Connectivity container indices = 'symmetric' - if average: out = [SpectralConnectivity( conn[m].mean(axis=0), freqs=out_freqs, n_nodes=n_signals, @@ -310,6 +385,8 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, indices=indices, method=method, spec_method=mode, events=events, event_id=event_id, metadata=metadata) for m in method] + logger.info('[Connectivity computation done]') + # return the object instead of list of length one if len(out) == 1: return out[0] From 9e9f983a266e7f4561491942e6c6f4bb3292ea06 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 5 Sep 2022 11:57:21 +0300 Subject: [PATCH 12/47] fix bug with indices: connectivity is now computed correctly between specified indices when set --- .../spectral/tests/test_spectral.py | 3 +- mne_connectivity/spectral/time.py | 31 +++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index e8167b9b..ce676680 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -502,12 +502,11 @@ def test_spectral_connectivity_time_resolved(method, mode): # define some frequencies for cwt freqs = np.arange(3, 20.5, 1) - n_freqs = len(freqs) # run connectivity estimation con = spectral_connectivity_time( data, cwt_freqs=freqs, method=method, mode=mode) - assert con.shape == (n_epochs, n_signals * 2, len(con.freqs)) + assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ (n_epochs, n_signals, n_signals, len(con.freqs)) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 98332692..62a4c618 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -284,18 +284,14 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, kernel = _create_kernel(sm_times, sm_freqs, kernel=sm_kernel) # get indices of pairs of (group) regions - roi = names # ch_names if indices is None: - # roi_gp and roi_idx - roi_gp, _ = roi, np.arange(len(roi)).reshape(-1, 1) - # get pairs for directed / undirected conn - source_idx, target_idx = np.triu_indices(len(roi_gp), k=0) + indices_use = np.tril_indices(n_signals, k=-1) else: indices_use = check_indices(indices) - source_idx = [x[0] for x in indices_use] - target_idx = [x[1] for x in indices_use] - roi_gp, _ = roi, np.arange(len(roi)).reshape(-1, 1) + + source_idx = indices_use[0] + target_idx = indices_use[1] n_pairs = len(source_idx) # frequency checking @@ -347,10 +343,9 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, blocks = [np.arange(n_epochs)] # compute connectivity on blocks of trials - conn = {} + conn = dict() for m in method: - conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) - + conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) logger.info('Connectivity computation...') # parameters to pass to the connectivity function @@ -371,8 +366,19 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, conn[m][epoch_idx, ...] = np.stack(conn_tr[m], axis=1).squeeze(axis=-1) + if indices is None: + conn_flat = conn + conn = dict() + for m in method: + this_conn = np.zeros((n_epochs, n_signals, n_signals) + + conn_flat[m].shape[2:], + dtype=conn_flat[m].dtype) + this_conn[:, source_idx, target_idx] = conn_flat[m][:, ...] + this_conn = this_conn.reshape((n_epochs, n_signals ** 2,) + + conn_flat[m].shape[2:]) + conn[m] = this_conn + # create a Connectivity container - indices = 'symmetric' if average: out = [SpectralConnectivity( conn[m].mean(axis=0), freqs=out_freqs, n_nodes=n_signals, @@ -411,7 +417,6 @@ def _spectral_connectivity(data, method, kernel, foi_idx, decim=decim, n_jobs=n_jobs, **kw_cwt) out = np.expand_dims(out, axis=2) # same dims with multitaper elif mode == 'multitaper': - print(data.shape) out = tfr_array_multitaper( data, sfreq, freqs, n_cycles=n_cycles, time_bandwidth=mt_bandwidth, output='complex', decim=decim, From 2e4ccd1e73c8fe15b03a9c72982ccc539d3608cf Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 5 Sep 2022 16:19:25 +0300 Subject: [PATCH 13/47] DOC: improve documentation --- mne_connectivity/spectral/epochs.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 412a8c3c..118b5cf0 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -860,6 +860,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, See Also -------- + mne_connectivity.spectral_connectivity_time mne_connectivity.SpectralConnectivity mne_connectivity.SpectroTemporalConnectivity @@ -874,7 +875,9 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, connectivity structure. Within each Epoch, it is assumed that the spectral measure is stationary. The spectral measures implemented in this function are computed across Epochs. **Thus, spectral measures computed with only - one Epoch will result in errorful values.** + one Epoch will result in errorful values and spectral measures computed + with few Epochs will be unreliable.** Please see + ``spectral_connectivity_time`` for time-resolved connectivity estimation. The spectral densities can be estimated using a multitaper method with digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier @@ -892,11 +895,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, indices = (np.array([0, 0, 0]), # row indices np.array([2, 3, 4])) # col indices - con_flat = spectral_connectivity(data, method='coh', - indices=indices, ...) + con = spectral_connectivity_epochs(data, method='coh', + indices=indices, ...) - In this case con_flat.shape = (3, n_freqs). The connectivity scores are - in the same order as defined indices. + In this case con.get_data().shape = (3, n_freqs). The connectivity scores + are in the same order as defined indices. **Supported Connectivity Measures** From 926c330ecd6e433e4c421090102c92212a8e384d Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 7 Sep 2022 17:04:47 +0300 Subject: [PATCH 14/47] change smoothing default to no smoothing --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 62a4c618..37ba4264 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -19,7 +19,7 @@ @fill_doc def spectral_connectivity_time(data, names=None, method='coh', average=False, indices=None, sfreq=2 * np.pi, fmin=None, - fmax=None, fskip=0, faverage=False, sm_times=.5, + fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', mode='cwt_morlet', mt_bandwidth=None, cwt_freqs=None, n_cycles=7, decim=1, From b881c347bb0593606c56e0944d494515037f4362 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 7 Sep 2022 17:05:23 +0300 Subject: [PATCH 15/47] DOC: updates to main docstring --- mne_connectivity/spectral/time.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 37ba4264..b8d5cedb 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -73,11 +73,10 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, faverage : bool Average connectivity scores for each frequency band. If True, the output freqs will be a list with arrays of the frequencies - that were averaged. By default False. + that were averaged. By default, False. sm_times : float Amount of time to consider for the temporal smoothing in seconds. - If zero, no temporal smoothing is applied. By default, - 0.5 sec smoothing is used. + If zero, no temporal smoothing is applied. By default, 0. sm_freqs : int Number of points for frequency smoothing. By default, 1 is used which is equivalent to no smoothing. From 73fb937309cdd58d3bdf35fe73efbea7a638f067 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 7 Sep 2022 17:05:50 +0300 Subject: [PATCH 16/47] BUG: number of blocks is now computed correctly --- mne_connectivity/spectral/time.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index b8d5cedb..6d47de42 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -335,8 +335,9 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, if block_size > n_epochs: block_size = n_epochs - if isinstance(block_size, int) and (block_size > 1): - n_blocks = n_epochs // block_size + n_epochs % block_size + if isinstance(block_size, int): + n_blocks = n_epochs // block_size + 1 if n_epochs % block_size \ + else n_epochs // block_size blocks = np.array_split(np.arange(n_epochs), n_blocks) else: blocks = [np.arange(n_epochs)] From baaca264c1d096ab77ddfd0a880ef2771d4d3634 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 12 Sep 2022 15:47:03 +0300 Subject: [PATCH 17/47] add test for time-resolved connectivity with simulated data --- .../spectral/tests/test_spectral.py | 56 +++++++++++++++++-- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index ce676680..958376b6 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -2,6 +2,7 @@ from numpy.testing import (assert_allclose, assert_array_almost_equal, assert_array_less) import pytest +import scipy import warnings import mne @@ -478,7 +479,49 @@ def test_epochs_tmin_tmax(kind): assert len(w) == 1 # just one even though there were multiple epochs -@pytest.mark.parametrize('method', ['coh', 'plv']) +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize( + 'mode', ['cwt_morlet', 'multitaper']) +@pytest.mark.parametrize('data_option', ['sync', 'random']) +def test_spectral_connectivity_time_sim(method, mode, data_option): + """Test time-resolved spectral connectivity with simulated data.""" + rng = np.random.default_rng(0) + n_epochs = 5 + n_channels = 3 + n_times = 1000 + sfreq = 250 + data = np.zeros((n_epochs, n_channels, n_times)) + if data_option == 'random': + # Data is random, there should be no consistent phase differences. + data = rng.random((n_epochs, n_channels, n_times)) + if data_option == 'sync': + # Data consists of phase-locked 10Hz sine waves with constant phase + # difference within each epoch. + wave_freq = 10 + epoch_length = n_times/sfreq + for i in range(n_epochs): + for c in range(n_channels): + phase = rng.random() * 10 + x = np.linspace(-wave_freq*epoch_length*np.pi+phase, + wave_freq*epoch_length*np.pi+phase, n_times) + data[i, c] = np.squeeze(np.sin(x)) + freq_band_low_limit = (8.) + freq_band_high_limit = (13.) + cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit+1) + con = spectral_connectivity_time(data, method=method, mode=mode, + sfreq=sfreq, fmin=freq_band_low_limit, + fmax=freq_band_high_limit, + cwt_freqs=cwt_freqs, n_jobs=1, + faverage=True, average=True, sm_times=0) + assert con.shape == (n_channels ** 2, len(con.freqs)) + con_matrix = con.get_data('dense')[..., 0] + if data_option == 'sync': + assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), + atol=0.01) + if data_option == 'random': + assert np.all(con_matrix) <= 0.5 + +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) def test_spectral_connectivity_time_resolved(method, mode): @@ -486,7 +529,7 @@ def test_spectral_connectivity_time_resolved(method, mode): sfreq = 50. n_signals = 3 n_epochs = 2 - n_times = 500 + n_times = 1000 trans_bandwidth = 2. tmin = 0. tmax = (n_times - 1) / sfreq @@ -505,7 +548,8 @@ def test_spectral_connectivity_time_resolved(method, mode): # run connectivity estimation con = spectral_connectivity_time( - data, cwt_freqs=freqs, method=method, mode=mode) + data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode, + n_cycles=5) assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs)) assert con.get_data(output='dense').shape == \ (n_epochs, n_signals, n_signals, len(con.freqs)) @@ -578,8 +622,10 @@ def test_time_resolved_spectral_conn_regression(method, mode): # mode was renamed in mne-connectivity if mode == 'morlet': mode = 'cwt_morlet' - conn = spectral_connectivity_time( - epochs, cwt_freqs=freqs, n_jobs=1, method=method, mode=mode) + sfreq = raw.info['sfreq'] + conn = spectral_connectivity_time(epochs, sfreq=sfreq, cwt_freqs=freqs, + n_jobs=1, method=method, mode=mode, + n_cycles=5) # frites only stores the upper triangular parts of the raveled array row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1) From 2a7e09e880556846d576fca01562603c76a7f863 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 22 Sep 2022 11:35:39 +0300 Subject: [PATCH 18/47] Change block_size default to 1 This change will minimize memory usage by default. --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 6d47de42..7db934e3 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -23,7 +23,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, sm_freqs=1, sm_kernel='hanning', mode='cwt_morlet', mt_bandwidth=None, cwt_freqs=None, n_cycles=7, decim=1, - block_size=1000, n_jobs=1, verbose=None): + block_size=1, n_jobs=1, verbose=None): """Compute frequency- and time-frequency-domain connectivity measures. This method computes time-resolved connectivity measures from epoched data. From 2a0be06ccb92ca2b1331b66e1506a559d110f1a0 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 22 Sep 2022 11:37:08 +0300 Subject: [PATCH 19/47] Add documentation for block_size Add a short description for the block_size parameter to allow users to better understand the memory usage of the spectral_connectivity_time function. --- mne_connectivity/spectral/time.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 7db934e3..ceee1c42 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -104,8 +104,14 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, decomposition. default 1 If int, returns tfr[…, ::decim]. If slice, returns tfr[…, decim]. block_size : int - Number of epochs to compute at once (higher numbers are faster - but require more memory). + Number of epochs to compute at once. Higher numbers are faster but + require more memory. Memory requirement in bytes is proportional to + `16*block_size*n_channels*n_tapers*n_freqs*n_times`, + where `n_tapers=mt_bandwidth-1` when `mode='multitaper'` and + `n_tapers=1` when `mode='cwt_morlet'`, and `n_freqs` is the number + of frequencies for connectivity computation, `n_freqs` is determined by + ``scipy.fft.rfftfreq`` when `mode='multitaper'` + and `n_freqs=len(cwt_freqs)` when `mode='cwt_morlet'`. n_jobs : int Number of connections to compute in parallel. %(verbose)s From 53ac7b590cad9291ae5bc040bd8401d2ff9c75cb Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 22 Sep 2022 11:39:03 +0300 Subject: [PATCH 20/47] Change for more useful variable names Improve the readability of code by shortening variable names in pairwise_pli. --- mne_connectivity/spectral/time.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index ceee1c42..6f1fd2de 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -561,10 +561,10 @@ def pairwise_wpli(w_x, w_y): # smooth s_xy = _smooth_spectra(s_xy, kernel) # magnitude of the mean of the imaginary part of the cross spectrum - s_xy_mean_abs = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) + con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) # mean of the magnitudes of the imaginary part of the cross spectrum - s_xy_abs_mean = np.abs(s_xy.imag).mean(axis=-1, keepdims=True) - out = s_xy_mean_abs / s_xy_abs_mean + con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) + out = con_num / con_den # mean inside frequency sliding window (if needed) if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) From bc61e54840db013e22d6b77e47b4026c572686fe Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 5 Oct 2022 10:30:20 +0300 Subject: [PATCH 21/47] Remove regression test Remove the regression test which tests against the spectral connectivity implementation in frites. The implementation in frites is erroneous, and therefore we should not test against it. --- .../spectral/tests/test_spectral.py | 91 ------------------- 1 file changed, 91 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 958376b6..90cb8a01 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -568,97 +568,6 @@ def test_spectral_connectivity_time_resolved(method, mode): for idx, jdx in triu_inds) -@pytest.mark.parametrize('method', ['coh', 'plv']) -@pytest.mark.parametrize( - 'mode', ['morlet', 'multitaper']) -def test_time_resolved_spectral_conn_regression(method, mode): - """Regression test against original implementation in Frites. - - To see how the test dataset was generated, see - ``benchmarks/single_epoch_conn.py``. - """ - test_file_path_str = str(_resource_path( - 'mne_connectivity.tests', - f'data/test_frite_dataset_{mode}_{method}.npy')) - test_conn = np.load(test_file_path_str).mean(axis=-1) - - # paths to mne datasets - sample ECoG - bids_root = mne.datasets.epilepsy_ecog.data_path() - - # first define the BIDS path and load in the dataset - bids_path = BIDSPath(root=bids_root, subject='pt1', session='presurgery', - task='ictal', datatype='ieeg', extension='vhdr') - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - raw = read_raw_bids(bids_path=bids_path, verbose=False) - line_freq = raw.info['line_freq'] - - # Pick only the ECoG channels, removing the ECG channels - raw.pick_types(ecog=True) - - # drop bad channels - raw.drop_channels(raw.info['bads']) - - # only pick the first three channels to lower RAM usage - raw = raw.pick_channels(raw.ch_names[:3]) - - # Load the data - raw.load_data() - - # Then we remove line frequency interference - raw.notch_filter(line_freq) - - # crop data and then Epoch - raw_copy = raw.copy() - raw = raw.crop(tmin=0, tmax=4, include_tmax=False) - epochs = make_fixed_length_epochs(raw=raw, duration=2., overlap=1.) - - ###################################################################### - # Perform basic test to match simulation data using time-resolved spec - ###################################################################### - # compare data to original run using Frites - freqs = [30, 90] - - # mode was renamed in mne-connectivity - if mode == 'morlet': - mode = 'cwt_morlet' - sfreq = raw.info['sfreq'] - conn = spectral_connectivity_time(epochs, sfreq=sfreq, cwt_freqs=freqs, - n_jobs=1, method=method, mode=mode, - n_cycles=5) - - # frites only stores the upper triangular parts of the raveled array - row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1) - conn_data = conn.get_data(output='dense')[ - :, row_triu_inds, col_triu_inds, ...] - assert_array_almost_equal(conn_data, test_conn) - - ###################################################################### - # Give varying set of frequency bands and frequencies to perform cWT - ###################################################################### - raw = raw_copy.crop(tmin=0, tmax=10, include_tmax=False) - ch_names = epochs.ch_names - epochs = make_fixed_length_epochs(raw=raw, duration=5, overlap=0.) - - # sampling rate of my data - sfreq = raw.info['sfreq'] - - # frequency bands of interest - fois = np.array([[4, 8], [8, 12], [12, 16], [16, 32]]) - - # frequencies of Continuous Morlet Wavelet Transform - freqs = np.arange(4., 32., 1) - - # compute coherence - cohs = spectral_connectivity_time( - epochs, names=None, method=method, indices=None, - sfreq=sfreq, foi=fois, sm_times=0.5, sm_freqs=1, sm_kernel='hanning', - mode=mode, mt_bandwidth=None, freqs=freqs, n_cycles=5) - assert cohs.get_data(output='dense').shape == ( - len(epochs), len(ch_names), len(ch_names), len(fois), len(epochs.times) - ) - - def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) From 7ce13bd6e25527f6d0404fa56bbb9137fde8ccfc Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 5 Oct 2022 14:13:51 +0300 Subject: [PATCH 22/47] Remove block_size parameter The block_size parameter is not useful, as testing shows that running the computation in blocks of epochs does not have a meaningful effect on the speed of computation, but significantly increases memory usage. --- mne_connectivity/spectral/time.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 6f1fd2de..f7a41e69 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -23,7 +23,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, sm_freqs=1, sm_kernel='hanning', mode='cwt_morlet', mt_bandwidth=None, cwt_freqs=None, n_cycles=7, decim=1, - block_size=1, n_jobs=1, verbose=None): + n_jobs=1, verbose=None): """Compute frequency- and time-frequency-domain connectivity measures. This method computes time-resolved connectivity measures from epoched data. @@ -103,15 +103,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, To reduce memory usage, decimation factor after time-frequency decomposition. default 1 If int, returns tfr[…, ::decim]. If slice, returns tfr[…, decim]. - block_size : int - Number of epochs to compute at once. Higher numbers are faster but - require more memory. Memory requirement in bytes is proportional to - `16*block_size*n_channels*n_tapers*n_freqs*n_times`, - where `n_tapers=mt_bandwidth-1` when `mode='multitaper'` and - `n_tapers=1` when `mode='cwt_morlet'`, and `n_freqs` is the number - of frequencies for connectivity computation, `n_freqs` is determined by - ``scipy.fft.rfftfreq`` when `mode='multitaper'` - and `n_freqs=len(cwt_freqs)` when `mode='cwt_morlet'`. n_jobs : int Number of connections to compute in parallel. %(verbose)s @@ -363,8 +354,8 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, decim=decim, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, verbose=verbose) - for epoch_idx in blocks: - # compute time-resolved spectral connectivity + for epoch_idx in np.arange(n_epochs): + epoch_idx = [epoch_idx] conn_tr = _spectral_connectivity(data[epoch_idx, ...], **call_params) # merge results From c7dd18caa29018a62af02636ec6bab559ef1ef2a Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 5 Oct 2022 14:15:48 +0300 Subject: [PATCH 23/47] Improve documentation Added a note on memory mapping in the docstring of spectral_connectivity_time. Corrected some typos and inconsistent backticks. --- mne_connectivity/spectral/time.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index f7a41e69..89986566 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -51,7 +51,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, average : bool Average connectivity scores over epochs. If True, output will be an instance of ``SpectralConnectivity`` , otherwise - ``EpochSpectralConnectivity``. By default False. + ``EpochSpectralConnectivity``. By default, False. indices : tuple of array | None Two arrays with indices of connections for which to compute connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. @@ -95,16 +95,17 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, Array of frequencies of interest for time-frequency decomposition. Only used in 'cwt_morlet' mode. Only the frequencies within the range specified by fmin and fmax are used. Must be specified if - `mode='cwt_morlet'`. Not used when `mode='multitaper'`. + ``mode='cwt_morlet'``. Not used when ``mode='multitaper'``. n_cycles : float | array of float Number of wavelet cycles for use in time-frequency decomposition method (specified by ``mode``). Fixed number or one per frequency. - decim : int | 1 + decim : int To reduce memory usage, decimation factor after time-frequency decomposition. default 1 If int, returns tfr[…, ::decim]. If slice, returns tfr[…, decim]. n_jobs : int - Number of connections to compute in parallel. + Number of connections to compute in parallel. Memory mapping must be + activated. Please see the Notes section for details. %(verbose)s Returns @@ -188,6 +189,16 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, WPLI = ------------------ E[|Im(Sxy)|] + Parallel computation can be activated by setting the ``n_jobs`` parameter. + Under the hood, this utilizes the ``joblib`` library. For effective + parallelization, you should activate memory mapping in MNE-Python by + setting ``MNE_MEMMAP_MIN_SIZE`` and ``MNE_CACHE_DIR``. For example, in your + code, run + ``` + mne.set_config('MNE_MEMMAP_MIN_SIZE', '10M') + mne.set_config('MNE_CACHE_DIR', '/dev/shm') + ``` + This function was originally implemented in ``frites`` and was ported over. @@ -328,18 +339,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, n_freqs = len(freqs) out_freqs = freqs - # build block size indices - if block_size > n_epochs: - block_size = n_epochs - - if isinstance(block_size, int): - n_blocks = n_epochs // block_size + 1 if n_epochs % block_size \ - else n_epochs // block_size - blocks = np.array_split(np.arange(n_epochs), n_blocks) - else: - blocks = [np.arange(n_epochs)] - - # compute connectivity on blocks of trials conn = dict() for m in method: conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) From 4d2c1f060085cb5da8732993c48c82565bce69b0 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 5 Oct 2022 14:22:00 +0300 Subject: [PATCH 24/47] Improve comments Removed redundant comments, clarified and fixed typos. --- mne_connectivity/spectral/time.py | 41 ++++++------------------------- 1 file changed, 7 insertions(+), 34 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 89986566..bcd4b221 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -250,7 +250,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, dur = float(n_times) / sfreq five_cycle_freq = 5. / dur if fmin is None: - # we use the 5 cycle freq. as default + # use the 5 cycle freq. as default fmin = five_cycle_freq logger.info(f'Fmin was not specified. Using fmin={fmin:.2f}, which ' 'corresponds to at least five cycles.') @@ -292,16 +292,14 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # get indices of pairs of (group) regions if indices is None: - # get pairs for directed / undirected conn indices_use = np.tril_indices(n_signals, k=-1) else: indices_use = check_indices(indices) - source_idx = indices_use[0] target_idx = indices_use[1] n_pairs = len(source_idx) - # frequency checking + # check cwt_freqs if cwt_freqs is not None: # check for single frequency if isinstance(cwt_freqs, (int, float)): @@ -324,7 +322,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # the frequency points where we compute connectivity freqs = freqs[freq_mask] - # frequency mean + # compute central frequencies _f = xr.DataArray(np.arange(len(freqs)), dims=('freqs',), coords=(freqs,)) foi_s = _f.sel(freqs=fmin, method='nearest').data @@ -356,8 +354,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, for epoch_idx in np.arange(n_epochs): epoch_idx = [epoch_idx] conn_tr = _spectral_connectivity(data[epoch_idx, ...], **call_params) - - # merge results for m in method: conn[m][epoch_idx, ...] = np.stack(conn_tr[m], axis=1).squeeze(axis=-1) @@ -403,10 +399,9 @@ def _spectral_connectivity(data, method, kernel, foi_idx, n_jobs=1, verbose=False): """Estimate time-resolved connectivity for one epoch. - See spectral_connectivity_epoch.""" + See spectral_connectivity_epochs.""" n_pairs = len(source_idx) - # first compute time-frequency decomposition if mode == 'cwt_morlet': out = tfr_array_morlet( data, sfreq, freqs, n_cycles=n_cycles, output='complex', @@ -420,13 +415,12 @@ def _spectral_connectivity(data, method, kernel, foi_idx, else: raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") - # compute for each required connectivity method + # compute for each connectivity method this_conn = {} conn_func = {'coh': _coh, 'plv': _plv, 'sxy': _cs, 'pli': _pli, 'wpli': _wpli} for m in method: c_func = conn_func[m] - # compute connectivity this_conn[m] = c_func(out, kernel, foi_idx, source_idx, target_idx, n_jobs=n_jobs, verbose=verbose, total=n_pairs, @@ -449,14 +443,13 @@ def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" - # auto spectra (faster that w * w.conj()) + # auto spectra (faster than w * w.conj()) s_auto = w.real ** 2 + w.imag ** 2 # smooth the auto spectra s_auto = _smooth_spectra(s_auto, kernel) def pairwise_coh(w_x, w_y): - # compute coherence s_xy = w[:, w_y] * np.conj(w[:, w_x]) s_xy = _smooth_spectra(s_xy, kernel) s_xx = s_auto[:, w_x] @@ -474,7 +467,6 @@ def pairwise_coh(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_coh, n_jobs=n_jobs, verbose=verbose, total=total) - # compute pairwise coherence coherence return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) @@ -484,15 +476,11 @@ def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" - # define the pairwise plv def pairwise_plv(w_x, w_y): - # compute plv s_xy = w[:, w_y] * np.conj(w[:, w_x]) - # complex exponential of phase differences exp_dphi = s_xy / np.abs(s_xy) - # smooth e^(-i*\delta\phi) exp_dphi = _smooth_spectra(exp_dphi, kernel) - # mean over samples (time axis) + # mean over time exp_dphi_mean = exp_dphi.mean(axis=-1, keepdims=True) out = np.abs(exp_dphi_mean) # mean inside frequency sliding window (if needed) @@ -505,7 +493,6 @@ def pairwise_plv(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_plv, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial plv return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) @@ -515,13 +502,9 @@ def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" - # define the pairwise pli def pairwise_pli(w_x, w_y): - # compute cross spectrum s_xy = w[:, w_y] * np.conj(w[:, w_x]) - # smooth e^(-i*\delta\phi) s_xy = _smooth_spectra(s_xy, kernel) - # phase lag index out = np.abs(np.mean(np.sign(np.imag(s_xy)), axis=-1, keepdims=True)) # mean inside frequency sliding window (if needed) @@ -534,7 +517,6 @@ def pairwise_pli(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_pli, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial pli return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) @@ -544,15 +526,10 @@ def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" - # define the pairwise wpli def pairwise_wpli(w_x, w_y): - # compute cross spectrum s_xy = w[:, w_y] * np.conj(w[:, w_x]) - # smooth s_xy = _smooth_spectra(s_xy, kernel) - # magnitude of the mean of the imaginary part of the cross spectrum con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) - # mean of the magnitudes of the imaginary part of the cross spectrum con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) out = con_num / con_den # mean inside frequency sliding window (if needed) @@ -565,16 +542,13 @@ def pairwise_wpli(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_wpli, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial wpli return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, faverage): """Pairwise cross-spectra.""" - # define the pairwise cross-spectra def pairwise_cs(w_x, w_y): - # computes the cross-spectra out = w[:, w_x] * np.conj(w[:, w_y]) out = _smooth_spectra(out, kernel) if isinstance(foi_idx, np.ndarray) and faverage: @@ -586,7 +560,6 @@ def pairwise_cs(w_x, w_y): parallel, p_fun, n_jobs = parallel_func( pairwise_cs, n_jobs=n_jobs, verbose=verbose, total=total) - # compute the single trial coherence return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) From 1b6224fa1fc2b592582ae1adde31b2d6f9e93f3c Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 5 Oct 2022 14:23:05 +0300 Subject: [PATCH 25/47] Remove unused code --- mne_connectivity/spectral/time.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index bcd4b221..7217d569 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -213,7 +213,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # extract data from Epochs object if isinstance(data, BaseEpochs): names = data.ch_names - times = data.times # input times for Epochs input type + times = data.times sfreq = data.info['sfreq'] events = data.events event_id = data.event_id @@ -283,7 +283,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # temporal decimation if isinstance(decim, int): - times = times[::decim] sm_times = int(np.round(sm_times / decim)) sm_times = max(sm_times, 1) From 283a1a1202f8c3176d2616a31ea52bf365d9f5b7 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 17 Oct 2022 16:02:04 +0300 Subject: [PATCH 26/47] Fix style issues --- .../spectral/tests/test_spectral.py | 23 ++++++++----------- mne_connectivity/spectral/time.py | 2 -- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 90cb8a01..14eaa3dd 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -2,15 +2,8 @@ from numpy.testing import (assert_allclose, assert_array_almost_equal, assert_array_less) import pytest -import scipy -import warnings - -import mne -from mne import (EpochsArray, SourceEstimate, create_info, - make_fixed_length_epochs) +from mne import (EpochsArray, SourceEstimate, create_info) from mne.filter import filter_data -from mne.utils import _resource_path -from mne_bids import BIDSPath, read_raw_bids from mne_connectivity import ( SpectralConnectivity, spectral_connectivity_epochs, @@ -498,16 +491,17 @@ def test_spectral_connectivity_time_sim(method, mode, data_option): # Data consists of phase-locked 10Hz sine waves with constant phase # difference within each epoch. wave_freq = 10 - epoch_length = n_times/sfreq + epoch_length = n_times / sfreq for i in range(n_epochs): for c in range(n_channels): phase = rng.random() * 10 - x = np.linspace(-wave_freq*epoch_length*np.pi+phase, - wave_freq*epoch_length*np.pi+phase, n_times) + x = np.linspace(-wave_freq * epoch_length * np.pi + phase, + wave_freq * epoch_length * np.pi + phase, + n_times) data[i, c] = np.squeeze(np.sin(x)) freq_band_low_limit = (8.) freq_band_high_limit = (13.) - cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit+1) + cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) con = spectral_connectivity_time(data, method=method, mode=mode, sfreq=sfreq, fmin=freq_band_low_limit, fmax=freq_band_high_limit, @@ -516,11 +510,14 @@ def test_spectral_connectivity_time_sim(method, mode, data_option): assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] if data_option == 'sync': - assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), + assert np.allclose(con_matrix, + np.tril(np.ones(con_matrix.shape), + k=-1), atol=0.01) if data_option == 'random': assert np.all(con_matrix) <= 0.5 + @pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 7217d569..c71bb6c0 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -213,7 +213,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, # extract data from Epochs object if isinstance(data, BaseEpochs): names = data.ch_names - times = data.times sfreq = data.info['sfreq'] events = data.events event_id = data.event_id @@ -234,7 +233,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, else: data = np.asarray(data) n_epochs, n_signals, n_times = data.shape - times = np.arange(0, n_times) names = np.arange(0, n_signals) metadata = None if sfreq is None: From e89ac777ce56a16b3e4892c2e4c9390a8a4424a7 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Wed, 19 Oct 2022 10:28:16 +0300 Subject: [PATCH 27/47] Add comment Co-authored-by: Adam Li --- mne_connectivity/spectral/tests/test_spectral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 14eaa3dd..7fc7386d 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -499,6 +499,7 @@ def test_spectral_connectivity_time_sim(method, mode, data_option): wave_freq * epoch_length * np.pi + phase, n_times) data[i, c] = np.squeeze(np.sin(x)) + # the frequency band should contain the frequency at which there is a hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) From 19603bd77252b32ccf62abfde2c6afe9afeca860 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Wed, 19 Oct 2022 10:36:40 +0300 Subject: [PATCH 28/47] Improve comment Co-authored-by: Adam Li --- mne_connectivity/spectral/tests/test_spectral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 7fc7386d..6fa51c23 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -477,7 +477,7 @@ def test_epochs_tmin_tmax(kind): 'mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) def test_spectral_connectivity_time_sim(method, mode, data_option): - """Test time-resolved spectral connectivity with simulated data.""" + """Test time-resolved spectral connectivity with simulated phase-locked data.""" rng = np.random.default_rng(0) n_epochs = 5 n_channels = 3 From e5da3aec18c4f25f7d01969c9d64d55490e1acf8 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Wed, 19 Oct 2022 10:39:24 +0300 Subject: [PATCH 29/47] Rename test function Co-authored-by: Adam Li --- mne_connectivity/spectral/tests/test_spectral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 6fa51c23..2c8e6190 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -476,7 +476,7 @@ def test_epochs_tmin_tmax(kind): @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) -def test_spectral_connectivity_time_sim(method, mode, data_option): +def test_spectral_connectivity_time_phaselocked(method, mode, data_option): """Test time-resolved spectral connectivity with simulated phase-locked data.""" rng = np.random.default_rng(0) n_epochs = 5 From 39f0ee657286978b0239eb1839d2a58143beb3f8 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Wed, 19 Oct 2022 10:40:45 +0300 Subject: [PATCH 30/47] Update docstring Add clear reference to MNE-Python functions. Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index c71bb6c0..f79ff68c 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -84,7 +84,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, Smoothing kernel type. Choose either 'square' or 'hanning' (default). mode : str Time-frequency decomposition method. Can be either: 'multitaper', or - 'cwt_morlet'. See ``tfr_array_multitaper`` and ``tfr_array_wavelet`` + 'cwt_morlet'. See `mne.time_frequency.tfr_array_multitaper` and `mne.time_frequency.tfr_array_wavelet` for reference. mt_bandwidth : float | None Multitaper time bandwidth. If None, will be set to 4.0 (3 tapers). From 4b6311cefdaf3f5ccec8e370b4bd485a2920cab0 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 19 Oct 2022 11:10:13 +0300 Subject: [PATCH 31/47] Add comments Add some comments to clarify the new test case for time-resolved spectral connectivity. --- mne_connectivity/spectral/tests/test_spectral.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 14eaa3dd..dbd7c5a4 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -510,11 +510,16 @@ def test_spectral_connectivity_time_sim(method, mode, data_option): assert con.shape == (n_channels ** 2, len(con.freqs)) con_matrix = con.get_data('dense')[..., 0] if data_option == 'sync': + # signals are perfectly phase-locked, connectivity matrix should be + # a lower triangular matrix of ones assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), atol=0.01) if data_option == 'random': + # signals are random, all connectivity values should be small + # 0.5 is picked rather arbitrarily such that the obsolete wrong + # implementation fails assert np.all(con_matrix) <= 0.5 From f6684c02ee68f78ebb8ef5f9cdbaf93226ec14a9 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Wed, 9 Nov 2022 14:45:21 +0200 Subject: [PATCH 32/47] DOC: Fix typos Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index f79ff68c..0645ef2d 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -101,7 +101,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, (specified by ``mode``). Fixed number or one per frequency. decim : int To reduce memory usage, decimation factor after time-frequency - decomposition. default 1 If int, returns tfr[…, ::decim]. If slice, + decomposition. Default to 1. If int, returns tfr[…, ::decim]. If slice, returns tfr[…, decim]. n_jobs : int Number of connections to compute in parallel. Memory mapping must be From 39fef9313b0d47916562852ca8514bf4adccfaeb Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Wed, 9 Nov 2022 14:46:21 +0200 Subject: [PATCH 33/47] DOC: Improve doc formulation Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 0645ef2d..2ddef235 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -135,7 +135,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, The connectivity measures are computed over time within each epoch and optionally averaged over epochs. High connectivity values indicate that - the phase differences between signals stay consistent over time. + the phase coupling (interpreted as estimated connectivity) differences between signals stay consistent over time. The spectral densities can be estimated using a multitaper method with digital prolate spheroidal sequence (DPSS) windows, or a continuous wavelet From 08b5c79f6180a8f841a2d59662d3dcc7bbb65089 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen <66060772+ruuskas@users.noreply.github.com> Date: Wed, 9 Nov 2022 14:47:37 +0200 Subject: [PATCH 34/47] DOC: Add note on memory mapping Co-authored-by: Adam Li --- mne_connectivity/spectral/time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 2ddef235..2ad251a0 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -199,6 +199,7 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, mne.set_config('MNE_CACHE_DIR', '/dev/shm') ``` +When ``MNE_MEMMAP_MIN_SIZE=None``, the underlying joblib implementation results in pickling and unpickling the whole array each time a pair of indices is accessed, which is slow, compared to memory mapping the array. This function was originally implemented in ``frites`` and was ported over. From ccb0a2db89cefcb1e29d14efbf7a09baa3406d29 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 9 Nov 2022 14:58:56 +0200 Subject: [PATCH 35/47] Remove unused names parameter --- mne_connectivity/spectral/time.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 2ad251a0..850e870c 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -17,7 +17,7 @@ @fill_doc -def spectral_connectivity_time(data, names=None, method='coh', average=False, +def spectral_connectivity_time(data, method='coh', average=False, indices=None, sfreq=2 * np.pi, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', @@ -36,7 +36,6 @@ def spectral_connectivity_time(data, names=None, method='coh', average=False, ---------- data : array_like, shape (n_epochs, n_signals, n_times) | Epochs The data from which to compute connectivity. - %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. These are: From fe727f7035305f39f61fa8fb74f0d7b11b2e0892 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 9 Nov 2022 15:03:05 +0200 Subject: [PATCH 36/47] Require sfreq with array input --- mne_connectivity/spectral/time.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 850e870c..be8b531b 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -18,7 +18,7 @@ @fill_doc def spectral_connectivity_time(data, method='coh', average=False, - indices=None, sfreq=2 * np.pi, fmin=None, + indices=None, sfreq=None, fmin=None, fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', mode='cwt_morlet', mt_bandwidth=None, @@ -56,7 +56,7 @@ def spectral_connectivity_time(data, method='coh', average=False, connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. If None, all connections are computed. sfreq : float - The sampling frequency. Should be specified if data is not ``Epochs``. + The sampling frequency. Required if data is not ``Epochs``. fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. @@ -236,9 +236,8 @@ def spectral_connectivity_time(data, method='coh', average=False, names = np.arange(0, n_signals) metadata = None if sfreq is None: - warn("Sampling frequency (sfreq) was not specified and could not " - "be inferred from data. Using default value 2*numpy.pi. " - "Connectivity results might not be interpretable.") + raise ValueError('Sampling frequency (sfreq) is required with ' + 'array input.') # check that method is a list if isinstance(method, str): From 3b966eceb2750cfbe9dbb5281e4ee1bd52d9d52e Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 9 Nov 2022 16:55:05 +0200 Subject: [PATCH 37/47] DOC: Improve documentation Revised documentation of spectral_connectivity_time. --- mne_connectivity/spectral/time.py | 37 ++++++++++++++++++------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index be8b531b..a8bb9c61 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -83,21 +83,23 @@ def spectral_connectivity_time(data, method='coh', average=False, Smoothing kernel type. Choose either 'square' or 'hanning' (default). mode : str Time-frequency decomposition method. Can be either: 'multitaper', or - 'cwt_morlet'. See `mne.time_frequency.tfr_array_multitaper` and `mne.time_frequency.tfr_array_wavelet` - for reference. + 'cwt_morlet'. See `mne.time_frequency.tfr_array_multitaper` and + `mne.time_frequency.tfr_array_wavelet` for reference. mt_bandwidth : float | None - Multitaper time bandwidth. If None, will be set to 4.0 (3 tapers). - Time x (Full) Bandwidth product. The number of good tapers (low-bias) - is chosen automatically based on this to equal - floor(time_bandwidth - 1). By default None. + Product between the temporal window length (in seconds) and the full + frequency bandwidth (in Hz). This product can be seen as the surface + of the window on the time/frequency plane and controls the frequency + bandwidth (thus the frequency resolution) and the number of good + tapers. See `mne.time_frequency.tfr_array_multitaper` documentation. cwt_freqs : array Array of frequencies of interest for time-frequency decomposition. Only used in 'cwt_morlet' mode. Only the frequencies within - the range specified by fmin and fmax are used. Must be specified if + the range specified by fmin and fmax are used. Required if ``mode='cwt_morlet'``. Not used when ``mode='multitaper'``. n_cycles : float | array of float - Number of wavelet cycles for use in time-frequency decomposition method - (specified by ``mode``). Fixed number or one per frequency. + Number of cycles in the wavelet, either a fixed number or one per + frequency. The number of cycles n_cycles and the frequencies of + interest freqs define the temporal window length. decim : int To reduce memory usage, decimation factor after time-frequency decomposition. Default to 1. If int, returns tfr[…, ::decim]. If slice, @@ -134,7 +136,8 @@ def spectral_connectivity_time(data, method='coh', average=False, The connectivity measures are computed over time within each epoch and optionally averaged over epochs. High connectivity values indicate that - the phase coupling (interpreted as estimated connectivity) differences between signals stay consistent over time. + the phase coupling (interpreted as estimated connectivity) differences + between signals stay consistent over time. The spectral densities can be estimated using a multitaper method with digital prolate spheroidal sequence (DPSS) windows, or a continuous wavelet @@ -191,16 +194,20 @@ def spectral_connectivity_time(data, method='coh', average=False, Parallel computation can be activated by setting the ``n_jobs`` parameter. Under the hood, this utilizes the ``joblib`` library. For effective parallelization, you should activate memory mapping in MNE-Python by - setting ``MNE_MEMMAP_MIN_SIZE`` and ``MNE_CACHE_DIR``. For example, in your - code, run + setting ``MNE_MEMMAP_MIN_SIZE`` and ``MNE_CACHE_DIR``. Activating memory + mapping will make ``joblib`` store arrays greater than the minimum size on + disc, and forego direct RAM access for more efficient processing. + For example, in your code, run ``` mne.set_config('MNE_MEMMAP_MIN_SIZE', '10M') mne.set_config('MNE_CACHE_DIR', '/dev/shm') ``` -When ``MNE_MEMMAP_MIN_SIZE=None``, the underlying joblib implementation results in pickling and unpickling the whole array each time a pair of indices is accessed, which is slow, compared to memory mapping the array. - This function was originally implemented in ``frites`` and was - ported over. + When ``MNE_MEMMAP_MIN_SIZE=None``, the underlying joblib implementation + results in pickling and unpickling the whole array each time a pair of + indices is accessed, which is slow, compared to memory mapping the array. + + This function is based on ``conn_spec`` implementation in Frites. .. versionadded:: 0.3 From f082d6f05c364e8e0393d9c8a4ce05ba7f036a8b Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 9 Nov 2022 17:34:25 +0200 Subject: [PATCH 38/47] Add test for cwt_freqs --- .../spectral/tests/test_spectral.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 2694510e..4cbcf3f5 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -524,6 +524,46 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): assert np.all(con_matrix) <= 0.5 +@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) +@pytest.mark.parametrize( + 'cwt_freqs', [[8., 10.], [8, 10], 10., 10]) +def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs): + """Test time-resolved spectral connectivity with int and float values for + cwt_freqs.""" + rng = np.random.default_rng(0) + n_epochs = 5 + n_channels = 3 + n_times = 1000 + sfreq = 250 + data = np.zeros((n_epochs, n_channels, n_times)) + + # Data consists of phase-locked 10Hz sine waves with constant phase + # difference within each epoch. + wave_freq = 10 + epoch_length = n_times / sfreq + for i in range(n_epochs): + for c in range(n_channels): + phase = rng.random() * 10 + x = np.linspace(-wave_freq * epoch_length * np.pi + phase, + wave_freq * epoch_length * np.pi + phase, + n_times) + data[i, c] = np.squeeze(np.sin(x)) + # the frequency band should contain the frequency at which there is a + # hypothesized "connection" + con = spectral_connectivity_time(data, method=method, mode='cwt_morlet', + sfreq=sfreq, fmin=np.min(cwt_freqs), + fmax=np.max(cwt_freqs), + cwt_freqs=cwt_freqs, n_jobs=1, + faverage=True, average=True, sm_times=0) + assert con.shape == (n_channels ** 2, len(con.freqs)) + con_matrix = con.get_data('dense')[..., 0] + + # signals are perfectly phase-locked, connectivity matrix should be + # a lower triangular matrix of ones + assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1), + atol=0.01) + + @pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli']) @pytest.mark.parametrize( 'mode', ['cwt_morlet', 'multitaper']) From 84d073b37599bc18b1de574ab0637e2510c73e55 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Wed, 9 Nov 2022 17:37:41 +0200 Subject: [PATCH 39/47] BUG: Fix spectral_connectivity time Spectral connectivity computation failed if cwt_freqs was only a single number or an array with a single entry due to invalid array slicing. Fixed by incrementing the upper bound of the slice by one when computing the average in a frequency band. --- mne_connectivity/spectral/time.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a8bb9c61..0d5b61ba 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -593,5 +593,6 @@ def _foi_average(conn, foi_idx): # compute average conn_f = np.zeros(sh, dtype=conn.dtype) for n_f, (f_s, f_e) in enumerate(foi_idx): + f_e += 1 if f_s == f_e else f_e conn_f[..., n_f, :] = conn[..., f_s:f_e, :].mean(-2) return conn_f From 3e7f20810ee5fa7ff34d975955ab20ce86cac271 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Mon, 14 Nov 2022 11:34:25 +0200 Subject: [PATCH 40/47] Compute weighted average over CSD Compute a weighted average of the tapered cross spectra when using the multitaper mode. Weighting is derived from the concentration ratios between the DPSS windows. --- mne_connectivity/spectral/time.py | 71 ++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 0d5b61ba..f287ac8b 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -7,7 +7,8 @@ import xarray as xr from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper) +from mne.time_frequency import (tfr_array_morlet, tfr_array_multitaper, + dpss_windows) from mne.utils import (logger, warn) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) @@ -410,11 +411,25 @@ def _spectral_connectivity(data, method, kernel, foi_idx, data, sfreq, freqs, n_cycles=n_cycles, output='complex', decim=decim, n_jobs=n_jobs, **kw_cwt) out = np.expand_dims(out, axis=2) # same dims with multitaper + weights = None elif mode == 'multitaper': out = tfr_array_multitaper( data, sfreq, freqs, n_cycles=n_cycles, time_bandwidth=mt_bandwidth, output='complex', decim=decim, n_jobs=n_jobs, **kw_mt) + if isinstance(n_cycles, (int, float)): + n_cycles = [n_cycles] * len(freqs) + mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 + n_tapers = int(np.floor(mt_bandwidth - 1)) + weights = np.zeros((n_tapers, len(freqs), out.shape[-1])) + for i, (f, n_c) in enumerate(zip(freqs, n_cycles)): + window_length = np.arange(0., n_c / float(f), 1.0 / sfreq).shape[0] + half_nbw = mt_bandwidth / 2. + n_tapers = int(np.floor(mt_bandwidth - 1)) + _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, + sym=False) + weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) + # weights have shape (n_tapers, n_freqs, n_times) else: raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") @@ -427,9 +442,7 @@ def _spectral_connectivity(data, method, kernel, foi_idx, this_conn[m] = c_func(out, kernel, foi_idx, source_idx, target_idx, n_jobs=n_jobs, verbose=verbose, total=n_pairs, - faverage=faverage) - # mean over tapers - this_conn[m] = [c.mean(axis=1) for c in this_conn[m]] + faverage=faverage, weights=weights) return this_conn @@ -441,22 +454,29 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### def _coh(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage): + faverage, weights): """Pairwise coherence. Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" - # auto spectra (faster than w * w.conj()) - s_auto = w.real ** 2 + w.imag ** 2 - # smooth the auto spectra - s_auto = _smooth_spectra(s_auto, kernel) + if weights is not None: + psd = weights * w + psd = psd * np.conj(psd) + psd = psd.real.sum(axis=2) + psd = psd * 2 / (weights * weights.conj()).real.sum(axis=0) + else: + psd = w.real ** 2 + w.imag ** 2 + psd = np.squeeze(psd, axis=2) + + # smooth the psd + psd = _smooth_spectra(psd, kernel) def pairwise_coh(w_x, w_y): - s_xy = w[:, w_y] * np.conj(w[:, w_x]) + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) s_xy = _smooth_spectra(s_xy, kernel) - s_xx = s_auto[:, w_x] - s_yy = s_auto[:, w_y] + s_xx = psd[:, w_x] + s_yy = psd[:, w_y] out = np.abs(s_xy.mean(axis=-1, keepdims=True)) / \ np.sqrt(s_xx.mean(axis=-1, keepdims=True) * s_yy.mean(axis=-1, keepdims=True)) @@ -474,13 +494,13 @@ def pairwise_coh(w_x, w_y): def _plv(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage): + faverage, weights): """Pairwise phase-locking value. Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" def pairwise_plv(w_x, w_y): - s_xy = w[:, w_y] * np.conj(w[:, w_x]) + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) exp_dphi = s_xy / np.abs(s_xy) exp_dphi = _smooth_spectra(exp_dphi, kernel) # mean over time @@ -500,13 +520,13 @@ def pairwise_plv(w_x, w_y): def _pli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage): + faverage, weights): """Pairwise phase-lag index. Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" def pairwise_pli(w_x, w_y): - s_xy = w[:, w_y] * np.conj(w[:, w_x]) + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) s_xy = _smooth_spectra(s_xy, kernel) out = np.abs(np.mean(np.sign(np.imag(s_xy)), axis=-1, keepdims=True)) @@ -524,13 +544,13 @@ def pairwise_pli(w_x, w_y): def _wpli(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage): + faverage, weights): """Pairwise weighted phase-lag index. Input signal w is of shape (n_epochs, n_chans, n_tapers, n_freqs, n_times).""" def pairwise_wpli(w_x, w_y): - s_xy = w[:, w_y] * np.conj(w[:, w_x]) + s_xy = _compute_csd(w[:, w_y], w[:, w_x], weights) s_xy = _smooth_spectra(s_xy, kernel) con_num = np.abs(s_xy.imag.mean(axis=-1, keepdims=True)) con_den = np.mean(np.abs(s_xy.imag), axis=-1, keepdims=True) @@ -549,10 +569,10 @@ def pairwise_wpli(w_x, w_y): def _cs(w, kernel, foi_idx, source_idx, target_idx, n_jobs, verbose, total, - faverage): + faverage, weights): """Pairwise cross-spectra.""" def pairwise_cs(w_x, w_y): - out = w[:, w_x] * np.conj(w[:, w_y]) + out = _compute_csd(w[:, w_y], w[:, w_x], weights) out = _smooth_spectra(out, kernel) if isinstance(foi_idx, np.ndarray) and faverage: return _foi_average(out, foi_idx) @@ -566,6 +586,17 @@ def pairwise_cs(w_x, w_y): return parallel(p_fun(s, t) for s, t in zip(source_idx, target_idx)) +def _compute_csd(x, y, weights): + """Compute cross spectral density of signals x and y.""" + if weights is not None: + s_xy = np.sum(weights * x * np.conj(weights * y), axis=-3) + s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=-3) + else: + s_xy = x * np.conj(y) + s_xy = np.squeeze(s_xy, axis=-3) + return s_xy + + def _foi_average(conn, foi_idx): """Average inside frequency bands. From 71b61ad2c1b6853e7a6f75922a0e41b680995b06 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Tue, 15 Nov 2022 14:11:51 +0200 Subject: [PATCH 41/47] Fix style --- mne_connectivity/spectral/tests/test_spectral.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 4cbcf3f5..8b4c71a8 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -477,7 +477,8 @@ def test_epochs_tmin_tmax(kind): 'mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) def test_spectral_connectivity_time_phaselocked(method, mode, data_option): - """Test time-resolved spectral connectivity with simulated phase-locked data.""" + """Test time-resolved spectral connectivity with simulated phase-locked + data.""" rng = np.random.default_rng(0) n_epochs = 5 n_channels = 3 @@ -499,7 +500,8 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): wave_freq * epoch_length * np.pi + phase, n_times) data[i, c] = np.squeeze(np.sin(x)) - # the frequency band should contain the frequency at which there is a hypothesized "connection" + # the frequency band should contain the frequency at which there is a + # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) From 9e073c6407ea73c7503a44da313463b49cc4970c Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Tue, 15 Nov 2022 15:35:54 +0200 Subject: [PATCH 42/47] Update the docstring of spectral_connectivity_time Made the docstring more stylish, removed unnecessary things and added better compliance with MNE-Python style guidelines. --- mne_connectivity/spectral/time.py | 96 ++++++++++++++++--------------- 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index f287ac8b..d88909be 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -29,7 +29,7 @@ def spectral_connectivity_time(data, method='coh', average=False, This method computes time-resolved connectivity measures from epoched data. - The connectivity method(s) are specified using the "method" parameter. + The connectivity method(s) are specified using the ``method`` parameter. All methods are based on estimates of the cross- and power spectral densities (CSD/PSD) Sxy and Sxx, Syy. @@ -46,65 +46,64 @@ def spectral_connectivity_time(data, method='coh', average=False, * 'sxy' : Cross-spectrum * 'pli' : Phase-Lag Index * 'wpli': Weighted Phase-Lag Index - - By default, coherence is used. average : bool Average connectivity scores over epochs. If True, output will be - an instance of ``SpectralConnectivity`` , otherwise - ``EpochSpectralConnectivity``. By default, False. - indices : tuple of array | None + an instance of :class:`SpectralConnectivity` , otherwise + :class:`EpochSpectralConnectivity`. + indices : tuple of array_like | None Two arrays with indices of connections for which to compute connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. - If None, all connections are computed. + If `None`, all connections are computed. sfreq : float - The sampling frequency. Required if data is not ``Epochs``. + The sampling frequency. Required if data is not + :class:`Epochs `. fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using - a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - If None, the frequency corresponding to an epoch length of 5 cycles - is used. + a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower + bounds. If `None`, the frequency corresponding to an epoch length of + 5 cycles is used. fmax : float | tuple of float | None The upper frequency of interest. Multiple bands are defined using - a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. - If None, sfreq/2 is used. + a tuple, e.g. ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper + bounds. If `None`, ``sfreq/2`` is used. fskip : int - Omit every "(fskip + 1)-th" frequency bin to decimate in frequency + Omit every ``(fskip + 1)``th frequency bin to decimate in frequency domain. faverage : bool - Average connectivity scores for each frequency band. If True, - the output freqs will be a list with arrays of the frequencies - that were averaged. By default, False. + Average connectivity scores for each frequency band. If `True`, + the output ``freqs`` will be an array of the median frequencies of each + band. sm_times : float Amount of time to consider for the temporal smoothing in seconds. - If zero, no temporal smoothing is applied. By default, 0. + If zero, no temporal smoothing is applied. sm_freqs : int Number of points for frequency smoothing. By default, 1 is used which is equivalent to no smoothing. sm_kernel : {'square', 'hanning'} - Smoothing kernel type. Choose either 'square' or 'hanning' (default). + Smoothing kernel type. Choose either 'square' or 'hanning'. mode : str Time-frequency decomposition method. Can be either: 'multitaper', or 'cwt_morlet'. See `mne.time_frequency.tfr_array_multitaper` and - `mne.time_frequency.tfr_array_wavelet` for reference. + `mne.time_frequency.tfr_array_morlet` for reference. mt_bandwidth : float | None Product between the temporal window length (in seconds) and the full - frequency bandwidth (in Hz). This product can be seen as the surface - of the window on the time/frequency plane and controls the frequency - bandwidth (thus the frequency resolution) and the number of good - tapers. See `mne.time_frequency.tfr_array_multitaper` documentation. - cwt_freqs : array + frequency bandwidth (in Hz). This product can be seen as the surface + of the window on the time/frequency plane and controls the frequency + bandwidth (thus the frequency resolution) and the number of good + tapers. See `mne.time_frequency.tfr_array_multitaper` documentation. + cwt_freqs : array_like Array of frequencies of interest for time-frequency decomposition. Only used in 'cwt_morlet' mode. Only the frequencies within - the range specified by fmin and fmax are used. Required if + the range specified by ``fmin`` and ``fmax`` are used. Required if ``mode='cwt_morlet'``. Not used when ``mode='multitaper'``. - n_cycles : float | array of float + n_cycles : float | array_like of float Number of cycles in the wavelet, either a fixed number or one per - frequency. The number of cycles n_cycles and the frequencies of - interest freqs define the temporal window length. + frequency. The number of cycles ``n_cycles`` and the frequencies of + interest ``cwt_freqs`` define the temporal window length. For details, + see `mne.time_frequency.tfr_array_morlet` documentation. decim : int To reduce memory usage, decimation factor after time-frequency - decomposition. Default to 1. If int, returns tfr[…, ::decim]. If slice, - returns tfr[…, decim]. + decomposition. Returns ``tfr[…, ::decim]``. n_jobs : int Number of connections to compute in parallel. Memory mapping must be activated. Please see the Notes section for details. @@ -114,13 +113,13 @@ def spectral_connectivity_time(data, method='coh', average=False, ------- con : instance of Connectivity | list Computed connectivity measure(s). An instance of - ``EpochSpectralConnectivity``, ``SpectralConnectivity`` + `EpochSpectralConnectivity`, `SpectralConnectivity` or a list of instances corresponding to connectivity measures if several connectivity measures are specified. The shape of each connectivity dataset is - (n_epochs, n_signals, n_signals, n_freqs) when indices is None - and (n_epochs, n_nodes, n_nodes, n_freqs) when "indices" is specified - and "n_nodes = len(indices[0])". + (n_epochs, n_signals, n_signals, n_freqs) when ``indices`` is `None` + and (n_epochs, n_nodes, n_nodes, n_freqs) when ``indices`` is specified + and ``n_nodes = len(indices[0])``. See Also -------- @@ -143,14 +142,19 @@ def spectral_connectivity_time(data, method='coh', average=False, The spectral densities can be estimated using a multitaper method with digital prolate spheroidal sequence (DPSS) windows, or a continuous wavelet transform using Morlet wavelets. The spectral estimation mode is specified - using the "mode" parameter. + using the ``mode`` parameter. + + When using the multitaper spectral estimation method, the + cross-spectral density is computed separately for each taper and aggregated + using a weighted average, where the weights correspond to the concentration + ratios between the DPSS windows. By default, the connectivity between all signals is computed (only connections corresponding to the lower-triangular part of the connectivity matrix). If one is only interested in the connectivity - between some signals, the "indices" parameter can be used. For example, + between some signals, the ``indices`` parameter can be used. For example, to compute the connectivity between the signal with index 0 and signals - "2, 3, 4" (a total of 3 connections) one can use the following:: + 2, 3, 4 (a total of 3 connections), one can use the following:: indices = (np.array([0, 0, 0]), # row indices np.array([2, 3, 4])) # col indices @@ -158,12 +162,12 @@ def spectral_connectivity_time(data, method='coh', average=False, con = spectral_connectivity_time(data, method='coh', indices=indices, ...) - In this case con.get_data().shape = (3, n_freqs). The connectivity scores - are in the same order as defined indices. + In this case ``con.get_data().shape = (3, n_freqs)``. The connectivity + scores are in the same order as defined indices. **Supported Connectivity Measures** - The connectivity method(s) is specified using the "method" parameter. The + The connectivity method(s) is specified using the ``method`` parameter. The following methods are supported (note: ``E[]`` denotes average over epochs). Multiple measures can be computed at once by using a list/tuple, e.g., ``['coh', 'pli']`` to compute coherence and PLI. @@ -199,16 +203,16 @@ def spectral_connectivity_time(data, method='coh', average=False, mapping will make ``joblib`` store arrays greater than the minimum size on disc, and forego direct RAM access for more efficient processing. For example, in your code, run - ``` - mne.set_config('MNE_MEMMAP_MIN_SIZE', '10M') - mne.set_config('MNE_CACHE_DIR', '/dev/shm') - ``` + + mne.set_config('MNE_MEMMAP_MIN_SIZE', '10M') + mne.set_config('MNE_CACHE_DIR', '/dev/shm') When ``MNE_MEMMAP_MIN_SIZE=None``, the underlying joblib implementation results in pickling and unpickling the whole array each time a pair of indices is accessed, which is slow, compared to memory mapping the array. - This function is based on ``conn_spec`` implementation in Frites. + This function is based on the ``frites.conn.conn_spec`` implementation in + Frites. .. versionadded:: 0.3 From dcfbc8bdf545530cf13ea71b65429c3abe331b8d Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Tue, 15 Nov 2022 15:37:03 +0200 Subject: [PATCH 43/47] Remove unnecessary defaults The _spectral_connectivity function doesn't need defaults as these are already spelled out in the main spectral_connectivity_time function signature. --- mne_connectivity/spectral/time.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index d88909be..71f12d64 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -403,8 +403,8 @@ def spectral_connectivity_time(data, method='coh', average=False, def _spectral_connectivity(data, method, kernel, foi_idx, source_idx, target_idx, mode, sfreq, freqs, faverage, n_cycles, - mt_bandwidth=None, decim=1, kw_cwt={}, kw_mt={}, - n_jobs=1, verbose=False): + mt_bandwidth, decim, kw_cwt, kw_mt, + n_jobs, verbose): """Estimate time-resolved connectivity for one epoch. See spectral_connectivity_epochs.""" From fb9869f28811fe22d106ddfaaa44ce95fdfd4841 Mon Sep 17 00:00:00 2001 From: Santeri Ruuskanen Date: Thu, 17 Nov 2022 14:48:30 +0200 Subject: [PATCH 44/47] Add entries in whats_new.rst and authors.inc --- doc/authors.inc | 1 + doc/whats_new.rst | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/doc/authors.inc b/doc/authors.inc index a4caeb5c..096e5d55 100644 --- a/doc/authors.inc +++ b/doc/authors.inc @@ -7,3 +7,4 @@ .. _Szonja Weigl: https://github.com/weiglszonja .. _Kenji Marshall: https://github.com/kenjimarshall .. _Sezan Mert: https://github.com/SezanMert +.. _Santeri Ruuskanen: https://github.com/ruuskas diff --git a/doc/whats_new.rst b/doc/whats_new.rst index e360570c..2e8031fd 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -19,27 +19,35 @@ Here we list a changelog of MNE-connectivity. Version 0.5 (Unreleased) ------------------------ -... +This version has major changes in :func:`mne_connectivity.spectral_connectivity_time`. Several bugs are fixed, and the +function now computes static connectivity over time, as opposed to static connectivity over trials computed by :func:`mne_connectivity.spectral_connectivity_epochs`. Enhancements ~~~~~~~~~~~~ -- +- Add the ``PLI`` and ``wPLI`` methods in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`). Bug ~~~ -- +- When using the ``multitaper`` mode in :func:`mne_connectivity.spectral_connectivity_time`, average CSD over tapers instead of the complex signal by `Santeri Ruuskanen`_ (:gh:`104`). +- Average over time when computing connectivity measures in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Fix support for multiple connectivity methods in calls to :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`). +- Fix bug with the ``indices`` parameter in :func:`mne_connectivity.spectral_connectivity_time`, the behavior is now as expected by `Santeri Ruuskanen`_ (:gh:`104`). +- Fix bug with parallel computation in :func:`mne_connectivity.spectral_connectivity_time`, add instructions for memory mapping in doc by `Santeri Ruuskanen`_ (:gh:`104`). API ~~~ -- +- Streamline the API of :func:`mne_connectivity.spectral_connectivity_time` with :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`). Authors ~~~~~~~ -* +* `Santeri Ruuskanen`_ :doc:`Find out what was new in previous releases ` From eec01130d560aa83e39c7d33340615ccaba88095 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 17 Nov 2022 10:14:45 -0500 Subject: [PATCH 45/47] FIX: Test --- mne_connectivity/conftest.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/conftest.py b/mne_connectivity/conftest.py index 33487f5a..d1366cfb 100644 --- a/mne_connectivity/conftest.py +++ b/mne_connectivity/conftest.py @@ -151,6 +151,6 @@ def _check_skip_backend(name): if not has_imageio_ffmpeg(): pytest.skip("Test skipped, requires imageio-ffmpeg") if name == 'pyvistaqt' and not _check_qt_version(): - pytest.skip("Test skipped, requires PyQt5.") + pytest.skip("Test skipped, requires Python Qt bindings.") if name == 'pyvistaqt' and not has_pyvistaqt(): pytest.skip("Test skipped, requires pyvistaqt") diff --git a/requirements.txt b/requirements.txt index 0f5821b0..067ea5c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ h5netcdf tqdm matplotlib qtpy -PySide6!=6.3.0 +PySide6!=6.3.0,!=6.4.0,!=6.4.0.1 sip pyvista>=0.30 pyvistaqt>=0.4 From b2dda41850222d6867fc3041c4f8bc018fe4f91a Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 17 Nov 2022 10:39:47 -0500 Subject: [PATCH 46/47] FIX: Doc build --- doc/conf.py | 3 ++- mne_connectivity/spectral/time.py | 20 ++++++++++---------- requirements_doc.txt | 2 +- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 7532bfb3..1fc9fe0a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -197,7 +197,8 @@ 'use_edit_page_button': False, 'navigation_with_keys': False, 'show_toc_level': 1, - 'navbar_end': ['version-switcher', 'navbar-icon-links'], + 'navbar_end': ['theme-switcher', 'version-switcher', 'navbar-icon-links'], + 'secondary_sidebar_items': ['page-toc'], } # Custom sidebar templates, maps document names to template names. html_sidebars = { diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 71f12d64..1c16052a 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -38,8 +38,8 @@ def spectral_connectivity_time(data, method='coh', average=False, data : array_like, shape (n_epochs, n_signals, n_times) | Epochs The data from which to compute connectivity. method : str | list of str - Connectivity measure(s) to compute. These can be ``['coh', 'plv', - 'sxy', 'pli', 'wpli']``. These are: + Connectivity measure(s) to compute. These can be + ``['coh', 'plv', 'sxy', 'pli', 'wpli']``. These are: * 'coh' : Coherence * 'plv' : Phase-Locking Value (PLV) @@ -48,7 +48,7 @@ def spectral_connectivity_time(data, method='coh', average=False, * 'wpli': Weighted Phase-Lag Index average : bool Average connectivity scores over epochs. If True, output will be - an instance of :class:`SpectralConnectivity` , otherwise + an instance of :class:`SpectralConnectivity`, otherwise :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None Two arrays with indices of connections for which to compute @@ -67,7 +67,7 @@ def spectral_connectivity_time(data, method='coh', average=False, a tuple, e.g. ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper bounds. If `None`, ``sfreq/2`` is used. fskip : int - Omit every ``(fskip + 1)``th frequency bin to decimate in frequency + Omit every ``(fskip + 1)``-th frequency bin to decimate in frequency domain. faverage : bool Average connectivity scores for each frequency band. If `True`, @@ -83,14 +83,15 @@ def spectral_connectivity_time(data, method='coh', average=False, Smoothing kernel type. Choose either 'square' or 'hanning'. mode : str Time-frequency decomposition method. Can be either: 'multitaper', or - 'cwt_morlet'. See `mne.time_frequency.tfr_array_multitaper` and - `mne.time_frequency.tfr_array_morlet` for reference. + 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for reference. mt_bandwidth : float | None Product between the temporal window length (in seconds) and the full frequency bandwidth (in Hz). This product can be seen as the surface of the window on the time/frequency plane and controls the frequency bandwidth (thus the frequency resolution) and the number of good - tapers. See `mne.time_frequency.tfr_array_multitaper` documentation. + tapers. See :func:`mne.time_frequency.tfr_array_multitaper` + documentation. cwt_freqs : array_like Array of frequencies of interest for time-frequency decomposition. Only used in 'cwt_morlet' mode. Only the frequencies within @@ -100,7 +101,7 @@ def spectral_connectivity_time(data, method='coh', average=False, Number of cycles in the wavelet, either a fixed number or one per frequency. The number of cycles ``n_cycles`` and the frequencies of interest ``cwt_freqs`` define the temporal window length. For details, - see `mne.time_frequency.tfr_array_morlet` documentation. + see :func:`mne.time_frequency.tfr_array_morlet` documentation. decim : int To reduce memory usage, decimation factor after time-frequency decomposition. Returns ``tfr[…, ::decim]``. @@ -113,7 +114,7 @@ def spectral_connectivity_time(data, method='coh', average=False, ------- con : instance of Connectivity | list Computed connectivity measure(s). An instance of - `EpochSpectralConnectivity`, `SpectralConnectivity` + :class:`EpochSpectralConnectivity`, :class:`SpectralConnectivity` or a list of instances corresponding to connectivity measures if several connectivity measures are specified. The shape of each connectivity dataset is @@ -129,7 +130,6 @@ def spectral_connectivity_time(data, method='coh', average=False, Notes ----- - Please note that the interpretation of the measures in this function depends on the data and underlying assumptions and does not necessarily reflect a causal relationship between brain regions. diff --git a/requirements_doc.txt b/requirements_doc.txt index 49330528..51d4fa53 100644 --- a/requirements_doc.txt +++ b/requirements_doc.txt @@ -6,7 +6,7 @@ sphinx-copybutton numpydoc nibabel nilearn -pydata-sphinx-theme +https://github.com/pydata/pydata-sphinx-theme/archive/cef3e724e15852fc2a84bee256c457c9497834b8.zip typing-extensions sphinx-autodoc-typehints sphinxcontrib-bibtex From 099f194f4a4b8d5b64f0bfcc5e17895814490d05 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Thu, 17 Nov 2022 11:12:58 -0500 Subject: [PATCH 47/47] FIX: Not pre --- tools/circleci_dependencies.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index e656f5af..78586360 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -2,7 +2,7 @@ echo "Installing setuptools and sphinx" python -m pip install --progress-bar off --upgrade "pip!=20.3.0" setuptools wheel -python -m pip install --upgrade --progress-bar off --pre sphinx +python -m pip install --upgrade --progress-bar off sphinx echo "Installing doc build dependencies" python -m pip uninstall -y pydata-sphinx-theme