-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GSOC] Add EpochsTFR
support to spectral connectivity functions
#232
Open
tsbinns
wants to merge
18
commits into
mne-tools:main
Choose a base branch
from
tsbinns:specconn_tfr_support
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
a9dc973
Add TFR support spec_conn_epochs
tsbinns 535db69
Update epochs docstring
tsbinns 14428f6
Update rank check comments
tsbinns 1b1456d
Add TFR support spec_conn_time
tsbinns 6fd0861
Switch tests to custom MNE branch
tsbinns fc55d70
Merge branch 'main' into specconn_tfr_support
tsbinns a38f8fd
Fix failing tfr_error test
tsbinns 9c04f81
Fix time_tfr tolerances
tsbinns c861449
Fix misleading error message
tsbinns 7013e19
Fix spec_conn_time docstring error
tsbinns 160bc64
Revert "Switch tests to custom MNE branch"
tsbinns 23dc1f2
Merge branch 'main' into specconn_tfr_support
tsbinns 8e79c9f
Apply suggestions from code review
tsbinns 23fd89c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 641fa1d
Name expected conn values
tsbinns 232c0e9
Update Welch-Fourier variation message
tsbinns 4898673
Merge branch 'main' into specconn_tfr_support
tsbinns 269d209
Merge branch 'main' into specconn_tfr_support
tsbinns File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,19 +12,20 @@ | |
from mne.epochs import BaseEpochs | ||
from mne.parallel import parallel_func | ||
from mne.source_estimate import _BaseSourceEstimate | ||
from mne.time_frequency import ( | ||
EpochsSpectrum, | ||
EpochsSpectrumArray, | ||
EpochsTFR, | ||
) | ||
from mne.time_frequency.multitaper import ( | ||
_compute_mt_params, | ||
_csd_from_mt, | ||
_mt_spectra, | ||
_psd_from_mt, | ||
_psd_from_mt_adaptive, | ||
) | ||
from mne.time_frequency.spectrum import ( | ||
BaseSpectrum, | ||
EpochsSpectrum, | ||
EpochsSpectrumArray, | ||
) | ||
from mne.time_frequency.tfr import cwt, morlet | ||
from mne.time_frequency.spectrum import BaseSpectrum | ||
from mne.time_frequency.tfr import BaseTFR, cwt, morlet | ||
from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn | ||
|
||
from ..base import SpectralConnectivity, SpectroTemporalConnectivity | ||
|
@@ -161,17 +162,18 @@ def _prepare_connectivity( | |
"""Check and precompute dimensions of results data.""" | ||
first_epoch = epoch_block[0] | ||
|
||
# Sort times and freqs | ||
if spectrum_computed: | ||
# Sort times | ||
if spectrum_computed and times_in is None: # is a Spectrum object | ||
n_signals = first_epoch[0].shape[0] | ||
times = None | ||
n_times = None | ||
times_in = None | ||
n_times_in = None | ||
n_times = 0 | ||
n_times_in = 0 | ||
tmin_idx = None | ||
tmax_idx = None | ||
warn_times = False | ||
else: | ||
else: # data has a time dimension (timeseries or TFR object) | ||
if spectrum_computed: # is a TFR object | ||
first_epoch = (first_epoch[0][:, 0],) # just take first freq | ||
( | ||
n_signals, | ||
times, | ||
|
@@ -184,6 +186,9 @@ def _prepare_connectivity( | |
) = _check_times( | ||
data=first_epoch, sfreq=sfreq, times=times_in, tmin=tmin, tmax=tmax | ||
) | ||
|
||
# Sort freqs | ||
if not spectrum_computed: # is an (ordinary) timeseries | ||
# check that fmin corresponds to at least 5 cycles | ||
fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times) | ||
# compute frequencies to analyze based on number of samples, sampling rate, | ||
|
@@ -511,14 +516,19 @@ def _epoch_spectral_connectivity( | |
|
||
# compute tapered spectra | ||
if spectrum_computed: # use existing spectral info | ||
# XXX: Will need to distinguish time-resolved spectra here if support added | ||
# Select signals & freqs of interest (flexible indexing for optional tapers dim) | ||
x_t = np.array(data)[:, sig_idx][..., freq_mask] # split dims to avoid np.ix_ | ||
if weights is None: # also assumes no tapers dim | ||
x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim | ||
weights = np.ones((1, 1, 1)) | ||
# Select entries of interest (flexible indexing for optional tapers dim) | ||
if tmin_idx is not None and tmax_idx is not None: | ||
x_t = np.asarray(data)[:, sig_idx][..., freq_mask, tmin_idx:tmax_idx] | ||
else: | ||
x_t = np.asarray(data)[:, sig_idx][..., freq_mask] | ||
if weights is None: # assumes no tapers dim | ||
x_t = np.expand_dims(x_t, axis=2) # CSD construction expects tapers dim | ||
weights = np.ones((1, 1, 1)) | ||
if accumulate_psd: | ||
this_psd = _psd_from_mt(x_t, weights) | ||
if weights is not None: # only None if mode == 'cwt_morlet' | ||
this_psd = _psd_from_mt(x_t, weights) | ||
else: | ||
this_psd = (x_t * x_t.conj()).real | ||
else: # compute spectral info from scratch | ||
x_t, this_psd, weights = _compute_spectra( | ||
data=data, | ||
|
@@ -727,13 +737,14 @@ def spectral_connectivity_epochs( | |
|
||
Parameters | ||
---------- | ||
data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum | ||
data : array-like, shape=(n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsTFR | ||
The data from which to compute connectivity. Can be epoched timeseries data as | ||
an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients | ||
for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` object. If | ||
timeseries data, the spectral information will be computed according to the | ||
spectral estimation mode (see the ``mode`` parameter). If an | ||
:class:`~mne.time_frequency.EpochsSpectrum` object, this spectral information | ||
for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` or | ||
:class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral | ||
information will be computed according to the spectral estimation mode (see the | ||
``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or | ||
:class:`~mne.time_frequency.EpochsTFR` object, existing spectral information | ||
will be used and the ``mode`` parameter will be ignored. | ||
|
||
Note that it is also possible to combine multiple timeseries signals by | ||
|
@@ -748,8 +759,9 @@ def spectral_connectivity_epochs( | |
|
||
.. versionchanged:: 0.8 | ||
Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsSpectrum` | ||
or :class:`~mne.time_frequency.EpochsSpectrumArray` object can also be passed | ||
in as data. Storing Fourier coefficients requires ``mne >= 1.8``. | ||
or :class:`~mne.time_frequency.EpochsTFR` object can also be passed in as | ||
data. Storing Fourier coefficients in | ||
:class:`~mne.time_frequency.EpochsSpectrum` objects requires ``mne >= 1.8``. | ||
%(names)s | ||
method : str | list of str | ||
Connectivity measure(s) to compute. These can be ``['coh', 'cohy', | ||
|
@@ -789,7 +801,8 @@ def spectral_connectivity_epochs( | |
mode : str | ||
Spectrum estimation mode can be either: 'multitaper', 'fourier', or | ||
'cwt_morlet'. Ignored if ``data`` is an | ||
:class:`~mne.time_frequency.EpochsSpectrum` object. | ||
:class:`~mne.time_frequency.EpochsSpectrum` or | ||
:class:`~mne.time_frequency.EpochsTFR` object. | ||
fmin : float | tuple of float | ||
The lower frequency of interest. Multiple bands are defined using | ||
a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. | ||
|
@@ -1105,7 +1118,7 @@ def spectral_connectivity_epochs( | |
weights = None | ||
metadata = None | ||
spectrum_computed = False | ||
if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsSpectrumArray): | ||
if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsTFR): | ||
names = data.ch_names | ||
sfreq = data.info["sfreq"] | ||
|
||
|
@@ -1126,28 +1139,45 @@ def spectral_connectivity_epochs( | |
data.add_annotations_to_metadata(overwrite=True) | ||
metadata = data.metadata | ||
|
||
if isinstance(data, EpochsSpectrum | EpochsSpectrumArray): | ||
if isinstance(data, EpochsSpectrum | EpochsTFR): | ||
# XXX: Will need to be updated if new Spectrum methods are added | ||
if not np.iscomplexobj(data.get_data()): | ||
raise TypeError( | ||
"if `data` is an EpochsSpectrum object, it must contain " | ||
"complex-valued Fourier coefficients, such as that returned from " | ||
"Epochs.compute_psd(output='complex')" | ||
"if `data` is an EpochsSpectrum or EpochsTFR object, it must " | ||
"contain complex-valued Fourier coefficients, such as that " | ||
"returned from Epochs.compute_psd/tfr() with `output='complex'`" | ||
) | ||
if "segment" in data._dims: | ||
raise ValueError( | ||
"`data` cannot contain Fourier coefficients for individual segments" | ||
) | ||
if isinstance(data, EpochsSpectrum): # mode can be read mode from Spectrum | ||
mode = data.method | ||
mode = "fourier" if mode == "welch" else mode | ||
else: # spectral method is "unknown", so take mode from data dimensions | ||
# Currently, actual mode doesn't matter as long as we handle tapers and | ||
# their weights in the same way as for multitaper spectra | ||
mode = "multitaper" if "taper" in data._dims else "fourier" | ||
mode = data.method | ||
if isinstance(data, EpochsSpectrum | EpochsSpectrumArray): | ||
if isinstance(data, EpochsSpectrum): # read mode from object | ||
mode = "fourier" if mode == "welch" else mode | ||
else: # infer mode from dimensions | ||
# Currently, actual mode doesn't matter as long as we handle tapers | ||
# and their weights in the same way as for multitaper spectra | ||
mode = "multitaper" if "taper" in data._dims else "fourier" | ||
weights = data.weights | ||
else: | ||
if isinstance(data, EpochsTFR): # read mode from object | ||
if mode != "morlet": # FIXME: Add support for other TFR methods | ||
raise ValueError( | ||
"if `data` is an EpochsTFR object, the spectral method " | ||
"must be 'morlet'" | ||
) | ||
else: | ||
if "taper" in data._dims: # FIXME: Add support for multitaper TFR | ||
raise ValueError( | ||
"if `data` is an EpochsTFRArray object, it cannot contain " | ||
"Fourier coefficients for individual tapers" | ||
) | ||
Comment on lines
+1164
to
+1175
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto for this if/else; need to check for & handle the |
||
mode = "cwt_morlet" # currently only supported mode here | ||
times_in = data.times | ||
weights = None # no weights stored in TFR objects | ||
spectrum_computed = True | ||
freqs = data.freqs | ||
weights = data.weights | ||
else: | ||
times_in = data.times # input times for Epochs input type | ||
elif sfreq is None: | ||
|
@@ -1235,7 +1265,7 @@ def spectral_connectivity_epochs( | |
spectral_params = dict( | ||
eigvals=None, window_fun=None, wavelets=None, weights=weights | ||
) | ||
n_times_spectrum = 0 | ||
n_times_spectrum = n_times # 0 if no times | ||
n_tapers = None if weights is None else weights.size | ||
|
||
# unique signals for which we actually need to compute PSD etc. | ||
|
@@ -1289,7 +1319,7 @@ def spectral_connectivity_epochs( | |
logger.info(f" the following metrics will be computed: {metrics_str}") | ||
|
||
# check dimensions and time scale | ||
if not spectrum_computed: # XXX: Can we assume upstream checks sufficient? | ||
if not spectrum_computed: | ||
for this_epoch in epoch_block: | ||
_, _, _, warn_times = _get_and_verify_data_sizes( | ||
this_epoch, | ||
|
@@ -1469,7 +1499,9 @@ def spectral_connectivity_epochs( | |
freqs=freqs, | ||
method=_method, | ||
n_nodes=n_nodes, | ||
spec_method=mode if not isinstance(data, BaseSpectrum) else data.method, | ||
spec_method=( | ||
mode if not isinstance(data, BaseSpectrum | BaseTFR) else data.method | ||
), | ||
indices=indices, | ||
n_epochs_used=n_epochs, | ||
freqs_used=freqs_used, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
really surprised this works (which probably means it's not tested?), because as noted above
EpochsSpectrumArray
is a subclass ofEpochsSpectrum
:I'd think you'd need to first check for
EpochsSpectrumArray
and handle non-Array class in theelse
clause.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good spot! I'll look through the tests again and find where it's slipping through.