Skip to content

Commit

Permalink
[ENH] Multiple improvements to spectral_connectivity_time: ciPLV, and…
Browse files Browse the repository at this point in the history
… efficient computation of multiple metrics (mne-tools#115)

* Add ciPLV: Add the corrected imaginary Phase-Locking-Value into the list of
available connectivity metrics.

* Speed up computation: All connectivity measures are now computed with only a single
computation of pairwise cross spectrum.

* Add the option to specify freqs in all modes: In some scenarios, users might want to specify the frequencies for
time-frequency decomposition also when using multitapering. These
changes allow users to specify the 'freqs' parameter to override the
automatically determined frequencies.

* BUG: Average over CSD instead of connectivity

* Add option to use part of signal as padding: This adds the option to use the edges of the signal at each epoch as
padding. The purpose of this is to avoid edge effects generated by the
time-frequency transformation methods.

* Fix test bug, use 'freqs' instead of 'cwt_freqs'

* Fix bug with dpss windows: Sym is not a parameter of dpss_windows. (But is one of the underlying
scipy.signal.dpss)

* Only show progress bar if verbosity level is DEBUG: This change will skip the rendering of the connectivity computation progress bar if the logging level is not DEBUG. This is in line with
MNE-Python, where progress bars are not shown at INFO or higher logging
levels. Rendering the progress bar regardless of logging levels has the
potential to cause unnecessary clutter in users' log files.

* Require freqs in all tfr modes

The user is required to specify the wavelet central frequencies in both
multitaper and cwt_morlet tfr mode. The reasoning is that the underlying
 tfr implementations are very similar. This is in contrast to
 spectral_connectivity_epochs, where multitaper assumes that the
 spectrum is stationary and therefore no wavelets are used.

* Require mne>=1.3

Signed-off-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
ruuskas and adam2392 committed Jan 10, 2023
1 parent 3684890 commit 627c661
Show file tree
Hide file tree
Showing 4 changed files with 350 additions and 207 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ Enhancements
- Improve the documentation of :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Add the option to average connectivity across epochs and frequencies in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`104`).
- Select multitaper frequencies automatically in :func:`mne_connectivity.spectral_connectivity_time` similarly to :func:`mne_connectivity.spectral_connectivity_epochs` by `Santeri Ruuskanen`_ (:gh:`104`).
- Add the ``ciPLV`` method in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`).
- Add the option to use the edges of each epoch as padding in :func:`mne_connectivity.spectral_connectivity_time` by `Santeri Ruuskanen`_ (:gh:`115`).

Bug
~~~
Expand Down
86 changes: 72 additions & 14 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def test_epochs_tmin_tmax(kind):
assert len(w) == 1 # just one even though there were multiple epochs


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
@pytest.mark.parametrize('data_option', ['sync', 'random'])
Expand Down Expand Up @@ -504,11 +504,11 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
# hypothesized "connection"
freq_band_low_limit = (8.)
freq_band_high_limit = (13.)
cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
con = spectral_connectivity_time(data, method=method, mode=mode,
freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
con = spectral_connectivity_time(data, freqs, method=method, mode=mode,
sfreq=sfreq, fmin=freq_band_low_limit,
fmax=freq_band_high_limit,
cwt_freqs=cwt_freqs, n_jobs=1,
n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]
Expand All @@ -526,12 +526,13 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
assert np.all(con_matrix) <= 0.5


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
@pytest.mark.parametrize(
'cwt_freqs', [[8., 10.], [8, 10], 10., 10])
def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
'freqs', [[8., 10.], [8, 10], 10., 10])
@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper'])
def test_spectral_connectivity_time_freqs(method, freqs, mode):
"""Test time-resolved spectral connectivity with int and float values for
cwt_freqs."""
freqs."""
rng = np.random.default_rng(0)
n_epochs = 5
n_channels = 3
Expand All @@ -552,10 +553,10 @@ def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
data[i, c] = np.squeeze(np.sin(x))
# the frequency band should contain the frequency at which there is a
# hypothesized "connection"
con = spectral_connectivity_time(data, method=method, mode='cwt_morlet',
sfreq=sfreq, fmin=np.min(cwt_freqs),
fmax=np.max(cwt_freqs),
cwt_freqs=cwt_freqs, n_jobs=1,
con = spectral_connectivity_time(data, freqs, method=method,
mode=mode, sfreq=sfreq,
fmin=np.min(freqs),
fmax=np.max(freqs), n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]
Expand Down Expand Up @@ -588,12 +589,12 @@ def test_spectral_connectivity_time_resolved(method, mode):
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
data = EpochsArray(data, info)

# define some frequencies for cwt
# define some frequencies for tfr
freqs = np.arange(3, 20.5, 1)

# run connectivity estimation
con = spectral_connectivity_time(
data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode,
data, freqs, sfreq=sfreq, method=method, mode=mode,
n_cycles=5)
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
assert con.get_data(output='dense').shape == \
Expand All @@ -613,6 +614,63 @@ def test_spectral_connectivity_time_resolved(method, mode):
for idx, jdx in triu_inds)


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
@pytest.mark.parametrize('padding', [0, 1, 5])
def test_spectral_connectivity_time_padding(method, mode, padding):
"""Test time-resolved spectral connectivity with padding."""
sfreq = 50.
n_signals = 3
n_epochs = 2
n_times = 300
trans_bandwidth = 2.
tmin = 0.
tmax = (n_times - 1) / sfreq
# 5Hz..15Hz
fstart, fend = 5.0, 15.0
data, _ = create_test_dataset(
sfreq, n_signals=n_signals, n_epochs=n_epochs, n_times=n_times,
tmin=tmin, tmax=tmax,
fstart=fstart, fend=fend, trans_bandwidth=trans_bandwidth)
ch_names = np.arange(n_signals).astype(str).tolist()
info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
data = EpochsArray(data, info)

# define some frequencies for tfr
freqs = np.arange(3, 20.5, 1)

# run connectivity estimation
if padding == 5:
with pytest.raises(ValueError, match='Padding cannot be larger than '
'half of data length'):
con = spectral_connectivity_time(
data, freqs, sfreq=sfreq, method=method, mode=mode,
n_cycles=5, padding=padding)
return
else:
con = spectral_connectivity_time(
data, freqs, sfreq=sfreq, method=method, mode=mode,
n_cycles=5, padding=padding)

assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
assert con.get_data(output='dense').shape == \
(n_epochs, n_signals, n_signals, len(con.freqs))

# test the simulated signal
triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T

# average over frequencies
conn_data = con.get_data(output='dense').mean(axis=-1)

# the indices at which there is a correlation should be greater
# then the rest of the components
for epoch_idx in range(n_epochs):
high_conn_val = conn_data[epoch_idx, 0, 1]
assert all(high_conn_val >= conn_data[epoch_idx, idx, jdx]
for idx, jdx in triu_inds)


def test_save(tmp_path):
"""Test saving results of spectral connectivity."""
rng = np.random.RandomState(0)
Expand Down
Loading

0 comments on commit 627c661

Please sign in to comment.