Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] [BUG] [ENH] [WIP] Bug fixes and enhancements for time-resolved spectral connectivity estimation #104

Merged
merged 51 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
41bd87b
FIX: compute connectivity over multiple tapers when mode='multitaper'…
ruuskas Aug 31, 2022
05ae8f6
require MNE-Python 1.0 or newer due to breaking changes in mne.time_f…
ruuskas Aug 31, 2022
9cc283c
update tests corresponding to API changes and use longer test signal
ruuskas Aug 31, 2022
924b4ab
update docstring: connectivity is not averaged over Epochs by default
ruuskas Sep 1, 2022
c5b75f8
fix docstring typo
ruuskas Sep 1, 2022
7699983
update docstring: faverage is False by default
ruuskas Sep 1, 2022
89f683a
update docstring: lower bound of frequency range for connectivity com…
ruuskas Sep 1, 2022
cb2dab5
update docstring: fmax may be None
ruuskas Sep 1, 2022
0fd34c9
update docstring
ruuskas Sep 2, 2022
f4992a7
new default for fmin
ruuskas Sep 2, 2022
cc1d69d
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas Sep 2, 2022
fa9bce7
improve docstring and warnings for spectral_connectivity_time
ruuskas Sep 2, 2022
9e9f983
fix bug with indices: connectivity is now computed correctly between …
ruuskas Sep 5, 2022
2e4ccd1
DOC: improve documentation
ruuskas Sep 5, 2022
926c330
change smoothing default to no smoothing
ruuskas Sep 7, 2022
b881c34
DOC: updates to main docstring
ruuskas Sep 7, 2022
73fb937
BUG: number of blocks is now computed correctly
ruuskas Sep 7, 2022
baaca26
add test for time-resolved connectivity with simulated data
ruuskas Sep 12, 2022
2a7e09e
Change block_size default to 1
ruuskas Sep 22, 2022
2a0be06
Add documentation for block_size
ruuskas Sep 22, 2022
53ac7b5
Change for more useful variable names
ruuskas Sep 22, 2022
bc61e54
Remove regression test
ruuskas Oct 5, 2022
7ce13bd
Remove block_size parameter
ruuskas Oct 5, 2022
c7dd18c
Improve documentation
ruuskas Oct 5, 2022
4d2c1f0
Improve comments
ruuskas Oct 5, 2022
1b6224f
Remove unused code
ruuskas Oct 5, 2022
054512b
Merge branch 'main' into spectral_time
adam2392 Oct 16, 2022
283a1a1
Fix style issues
ruuskas Oct 17, 2022
083aa1b
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas Oct 17, 2022
e89ac77
Add comment
ruuskas Oct 19, 2022
19603bd
Improve comment
ruuskas Oct 19, 2022
e5da3ae
Rename test function
ruuskas Oct 19, 2022
39f0ee6
Update docstring
ruuskas Oct 19, 2022
4b6311c
Add comments
ruuskas Oct 19, 2022
7c64633
Merge branch 'spectral_time' of github.com:ruuskas/mne-connectivity i…
ruuskas Oct 19, 2022
f6684c0
DOC: Fix typos
ruuskas Nov 9, 2022
39fef93
DOC: Improve doc formulation
ruuskas Nov 9, 2022
08b5c79
DOC: Add note on memory mapping
ruuskas Nov 9, 2022
ccb0a2d
Remove unused names parameter
ruuskas Nov 9, 2022
fe727f7
Require sfreq with array input
ruuskas Nov 9, 2022
3b966ec
DOC: Improve documentation
ruuskas Nov 9, 2022
f082d6f
Add test for cwt_freqs
ruuskas Nov 9, 2022
84d073b
BUG: Fix spectral_connectivity time
ruuskas Nov 9, 2022
3e7f208
Compute weighted average over CSD
ruuskas Nov 14, 2022
71b61ad
Fix style
ruuskas Nov 15, 2022
9e073c6
Update the docstring of spectral_connectivity_time
ruuskas Nov 15, 2022
dcfbc8b
Remove unnecessary defaults
ruuskas Nov 15, 2022
fb9869f
Add entries in whats_new.rst and authors.inc
ruuskas Nov 17, 2022
eec0113
FIX: Test
larsoner Nov 17, 2022
b2dda41
FIX: Doc build
larsoner Nov 17, 2022
099f194
FIX: Not pre
larsoner Nov 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: mne
name: mne-connectivity
channels:
- conda-forge
dependencies:
Expand All @@ -20,5 +20,5 @@ dependencies:
- pyvista>=0.32
- pyvistaqt>=0.4
- pyqt!=5.15.3
- mne
- mne>=1.0
- h5netcdf
13 changes: 8 additions & 5 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,

See Also
--------
mne_connectivity.spectral_connectivity_time
mne_connectivity.SpectralConnectivity
mne_connectivity.SpectroTemporalConnectivity

Expand All @@ -873,7 +874,9 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
connectivity structure. Within each Epoch, it is assumed that the spectral
measure is stationary. The spectral measures implemented in this function
are computed across Epochs. **Thus, spectral measures computed with only
one Epoch will result in errorful values.**
one Epoch will result in errorful values and spectral measures computed
with few Epochs will be unreliable.** Please see
``spectral_connectivity_time`` for time-resolved connectivity estimation.

The spectral densities can be estimated using a multitaper method with
digital prolate spheroidal sequence (DPSS) windows, a discrete Fourier
Expand All @@ -891,11 +894,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
indices = (np.array([0, 0, 0]), # row indices
np.array([2, 3, 4])) # col indices

con_flat = spectral_connectivity(data, method='coh',
indices=indices, ...)
con = spectral_connectivity_epochs(data, method='coh',
indices=indices, ...)

In this case con_flat.shape = (3, n_freqs). The connectivity scores are
in the same order as defined indices.
In this case con.get_data().shape = (3, n_freqs). The connectivity scores
are in the same order as defined indices.

**Supported Connectivity Measures**

Expand Down
210 changes: 104 additions & 106 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,8 @@
from numpy.testing import (assert_allclose, assert_array_almost_equal,
assert_array_less)
import pytest
import warnings

import mne
from mne import (EpochsArray, SourceEstimate, create_info,
make_fixed_length_epochs)
from mne import (EpochsArray, SourceEstimate, create_info)
from mne.filter import filter_data
from mne.utils import _resource_path
from mne_bids import BIDSPath, read_raw_bids

from mne_connectivity import (
SpectralConnectivity, spectral_connectivity_epochs,
Expand Down Expand Up @@ -478,15 +472,109 @@ def test_epochs_tmin_tmax(kind):
assert len(w) == 1 # just one even though there were multiple epochs


@pytest.mark.parametrize('method', ['coh', 'plv'])
@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
@pytest.mark.parametrize('data_option', ['sync', 'random'])
def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
"""Test time-resolved spectral connectivity with simulated phase-locked
data."""
rng = np.random.default_rng(0)
n_epochs = 5
n_channels = 3
n_times = 1000
sfreq = 250
data = np.zeros((n_epochs, n_channels, n_times))
if data_option == 'random':
# Data is random, there should be no consistent phase differences.
data = rng.random((n_epochs, n_channels, n_times))
if data_option == 'sync':
# Data consists of phase-locked 10Hz sine waves with constant phase
# difference within each epoch.
wave_freq = 10
epoch_length = n_times / sfreq
for i in range(n_epochs):
for c in range(n_channels):
phase = rng.random() * 10
x = np.linspace(-wave_freq * epoch_length * np.pi + phase,
wave_freq * epoch_length * np.pi + phase,
n_times)
data[i, c] = np.squeeze(np.sin(x))
# the frequency band should contain the frequency at which there is a
# hypothesized "connection"
freq_band_low_limit = (8.)
ruuskas marked this conversation as resolved.
Show resolved Hide resolved
freq_band_high_limit = (13.)
cwt_freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1)
con = spectral_connectivity_time(data, method=method, mode=mode,
sfreq=sfreq, fmin=freq_band_low_limit,
fmax=freq_band_high_limit,
cwt_freqs=cwt_freqs, n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]
if data_option == 'sync':
# signals are perfectly phase-locked, connectivity matrix should be
# a lower triangular matrix of ones
assert np.allclose(con_matrix,
np.tril(np.ones(con_matrix.shape),
k=-1),
atol=0.01)
if data_option == 'random':
# signals are random, all connectivity values should be small
# 0.5 is picked rather arbitrarily such that the obsolete wrong
# implementation fails
assert np.all(con_matrix) <= 0.5


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'cwt_freqs', [[8., 10.], [8, 10], 10., 10])
def test_spectral_connectivity_time_cwt_freqs(method, cwt_freqs):
"""Test time-resolved spectral connectivity with int and float values for
cwt_freqs."""
rng = np.random.default_rng(0)
n_epochs = 5
n_channels = 3
n_times = 1000
sfreq = 250
data = np.zeros((n_epochs, n_channels, n_times))

# Data consists of phase-locked 10Hz sine waves with constant phase
# difference within each epoch.
wave_freq = 10
epoch_length = n_times / sfreq
for i in range(n_epochs):
for c in range(n_channels):
phase = rng.random() * 10
x = np.linspace(-wave_freq * epoch_length * np.pi + phase,
wave_freq * epoch_length * np.pi + phase,
n_times)
data[i, c] = np.squeeze(np.sin(x))
# the frequency band should contain the frequency at which there is a
# hypothesized "connection"
con = spectral_connectivity_time(data, method=method, mode='cwt_morlet',
sfreq=sfreq, fmin=np.min(cwt_freqs),
fmax=np.max(cwt_freqs),
cwt_freqs=cwt_freqs, n_jobs=1,
faverage=True, average=True, sm_times=0)
assert con.shape == (n_channels ** 2, len(con.freqs))
con_matrix = con.get_data('dense')[..., 0]

# signals are perfectly phase-locked, connectivity matrix should be
# a lower triangular matrix of ones
assert np.allclose(con_matrix, np.tril(np.ones(con_matrix.shape), k=-1),
atol=0.01)


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli'])
@pytest.mark.parametrize(
'mode', ['cwt_morlet', 'multitaper'])
def test_spectral_connectivity_time_resolved(method, mode):
"""Test time-resolved spectral connectivity."""
sfreq = 50.
n_signals = 3
n_epochs = 2
n_times = 256
n_times = 1000
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we increase the n_times here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was an issue with the length of the wavelets being longer than the signal at some point during testing. Now it appears that the earlier value 256 would work just fine.

trans_bandwidth = 2.
tmin = 0.
tmax = (n_times - 1) / sfreq
Expand All @@ -502,22 +590,21 @@ def test_spectral_connectivity_time_resolved(method, mode):

# define some frequencies for cwt
freqs = np.arange(3, 20.5, 1)
n_freqs = len(freqs)

# run connectivity estimation
con = spectral_connectivity_time(
data, freqs=freqs, method=method, mode=mode)
assert con.shape == (n_epochs, n_signals * 2, n_freqs, n_times)
data, sfreq=sfreq, cwt_freqs=freqs, method=method, mode=mode,
n_cycles=5)
assert con.shape == (n_epochs, n_signals ** 2, len(con.freqs))
assert con.get_data(output='dense').shape == \
(n_epochs, n_signals, n_signals, n_freqs, n_times)

# average over time
conn_data = con.get_data(output='dense').mean(axis=-1)
conn_data = conn_data.mean(axis=-1)
(n_epochs, n_signals, n_signals, len(con.freqs))

# test the simulated signal
triu_inds = np.vstack(np.triu_indices(n_signals, k=1)).T

# average over frequencies
conn_data = con.get_data(output='dense').mean(axis=-1)

# the indices at which there is a correlation should be greater
# then the rest of the components
for epoch_idx in range(n_epochs):
Expand All @@ -526,95 +613,6 @@ def test_spectral_connectivity_time_resolved(method, mode):
for idx, jdx in triu_inds)


@pytest.mark.parametrize('method', ['coh', 'plv'])
@pytest.mark.parametrize(
'mode', ['morlet', 'multitaper'])
def test_time_resolved_spectral_conn_regression(method, mode):
"""Regression test against original implementation in Frites.

To see how the test dataset was generated, see
``benchmarks/single_epoch_conn.py``.
"""
test_file_path_str = str(_resource_path(
'mne_connectivity.tests',
f'data/test_frite_dataset_{mode}_{method}.npy'))
test_conn = np.load(test_file_path_str)

# paths to mne datasets - sample ECoG
bids_root = mne.datasets.epilepsy_ecog.data_path()

# first define the BIDS path and load in the dataset
bids_path = BIDSPath(root=bids_root, subject='pt1', session='presurgery',
task='ictal', datatype='ieeg', extension='.vhdr')
with warnings.catch_warnings():
warnings.simplefilter("ignore")
raw = read_raw_bids(bids_path=bids_path, verbose=False)
line_freq = raw.info['line_freq']

# Pick only the ECoG channels, removing the ECG channels
raw.pick_types(ecog=True)

# drop bad channels
raw.drop_channels(raw.info['bads'])

# only pick the first three channels to lower RAM usage
raw = raw.pick_channels(raw.ch_names[:3])

# Load the data
raw.load_data()

# Then we remove line frequency interference
raw.notch_filter(line_freq)

# crop data and then Epoch
raw_copy = raw.copy()
raw = raw.crop(tmin=0, tmax=4, include_tmax=False)
epochs = make_fixed_length_epochs(raw=raw, duration=2., overlap=1.)

######################################################################
# Perform basic test to match simulation data using time-resolved spec
######################################################################
# compare data to original run using Frites
freqs = [30, 90]

# mode was renamed in mne-connectivity
if mode == 'morlet':
mode = 'cwt_morlet'
conn = spectral_connectivity_time(
epochs, freqs=freqs, n_jobs=1, method=method, mode=mode)

# frites only stores the upper triangular parts of the raveled array
row_triu_inds, col_triu_inds = np.triu_indices(len(raw.ch_names), k=1)
conn_data = conn.get_data(output='dense')[
:, row_triu_inds, col_triu_inds, ...]
assert_array_almost_equal(conn_data, test_conn)

######################################################################
# Give varying set of frequency bands and frequencies to perform cWT
######################################################################
raw = raw_copy.crop(tmin=0, tmax=10, include_tmax=False)
ch_names = epochs.ch_names
epochs = make_fixed_length_epochs(raw=raw, duration=5, overlap=0.)

# sampling rate of my data
sfreq = raw.info['sfreq']

# frequency bands of interest
fois = np.array([[4, 8], [8, 12], [12, 16], [16, 32]])

# frequencies of Continuous Morlet Wavelet Transform
freqs = np.arange(4., 32., 1)

# compute coherence
cohs = spectral_connectivity_time(
epochs, names=None, method=method, indices=None,
sfreq=sfreq, foi=fois, sm_times=0.5, sm_freqs=1, sm_kernel='hanning',
mode=mode, mt_bandwidth=None, freqs=freqs, n_cycles=5)
assert cohs.get_data(output='dense').shape == (
len(epochs), len(ch_names), len(ch_names), len(fois), len(epochs.times)
)


def test_save(tmp_path):
"""Test saving results of spectral connectivity."""
rng = np.random.RandomState(0)
Expand Down
Loading