Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSOC] Add EpochsTFR support to spectral connectivity functions #232

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 75 additions & 43 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.source_estimate import _BaseSourceEstimate
from mne.time_frequency import (
EpochsSpectrum,
EpochsSpectrumArray,
EpochsTFR,
)
from mne.time_frequency.multitaper import (
_compute_mt_params,
_csd_from_mt,
_mt_spectra,
_psd_from_mt,
_psd_from_mt_adaptive,
)
from mne.time_frequency.spectrum import (
BaseSpectrum,
EpochsSpectrum,
EpochsSpectrumArray,
)
from mne.time_frequency.tfr import cwt, morlet
from mne.time_frequency.spectrum import BaseSpectrum
from mne.time_frequency.tfr import BaseTFR, cwt, morlet
from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn

from ..base import SpectralConnectivity, SpectroTemporalConnectivity
Expand Down Expand Up @@ -161,17 +162,18 @@ def _prepare_connectivity(
"""Check and precompute dimensions of results data."""
first_epoch = epoch_block[0]

# Sort times and freqs
if spectrum_computed:
# Sort times
if spectrum_computed and times_in is None: # is a Spectrum object
n_signals = first_epoch[0].shape[0]
times = None
n_times = None
times_in = None
n_times_in = None
n_times = 0
n_times_in = 0
tmin_idx = None
tmax_idx = None
warn_times = False
else:
else: # data has a time dimension (timeseries or TFR object)
if spectrum_computed: # is a TFR object
first_epoch = (first_epoch[0][:, 0],) # just take first freq
(
n_signals,
times,
Expand All @@ -184,6 +186,9 @@ def _prepare_connectivity(
) = _check_times(
data=first_epoch, sfreq=sfreq, times=times_in, tmin=tmin, tmax=tmax
)

# Sort freqs
if not spectrum_computed: # is an (ordinary) timeseries
# check that fmin corresponds to at least 5 cycles
fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times)
# compute frequencies to analyze based on number of samples, sampling rate,
Expand Down Expand Up @@ -511,14 +516,19 @@ def _epoch_spectral_connectivity(

# compute tapered spectra
if spectrum_computed: # use existing spectral info
# XXX: Will need to distinguish time-resolved spectra here if support added
# Select signals & freqs of interest (flexible indexing for optional tapers dim)
x_t = np.array(data)[:, sig_idx][..., freq_mask] # split dims to avoid np.ix_
if weights is None: # also assumes no tapers dim
x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim
weights = np.ones((1, 1, 1))
# Select entries of interest (flexible indexing for optional tapers dim)
if tmin_idx is not None and tmax_idx is not None:
x_t = np.asarray(data)[:, sig_idx][..., freq_mask, tmin_idx:tmax_idx]
else:
x_t = np.asarray(data)[:, sig_idx][..., freq_mask]
if weights is None: # assumes no tapers dim
x_t = np.expand_dims(x_t, axis=2) # CSD construction expects tapers dim
weights = np.ones((1, 1, 1))
if accumulate_psd:
this_psd = _psd_from_mt(x_t, weights)
if weights is not None: # only None if mode == 'cwt_morlet'
this_psd = _psd_from_mt(x_t, weights)
else:
this_psd = (x_t * x_t.conj()).real
else: # compute spectral info from scratch
x_t, this_psd, weights = _compute_spectra(
data=data,
Expand Down Expand Up @@ -727,13 +737,14 @@ def spectral_connectivity_epochs(

Parameters
----------
data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum
data : array-like, shape=(n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsTFR
The data from which to compute connectivity. Can be epoched timeseries data as
an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients
for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` object. If
timeseries data, the spectral information will be computed according to the
spectral estimation mode (see the ``mode`` parameter). If an
:class:`~mne.time_frequency.EpochsSpectrum` object, this spectral information
for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` or
:class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral
information will be computed according to the spectral estimation mode (see the
``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or
:class:`~mne.time_frequency.EpochsTFR` object, existing spectral information
will be used and the ``mode`` parameter will be ignored.

Note that it is also possible to combine multiple timeseries signals by
Expand All @@ -748,8 +759,9 @@ def spectral_connectivity_epochs(

.. versionchanged:: 0.8
Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsSpectrum`
or :class:`~mne.time_frequency.EpochsSpectrumArray` object can also be passed
in as data. Storing Fourier coefficients requires ``mne >= 1.8``.
or :class:`~mne.time_frequency.EpochsTFR` object can also be passed in as
data. Storing Fourier coefficients in
:class:`~mne.time_frequency.EpochsSpectrum` objects requires ``mne >= 1.8``.
%(names)s
method : str | list of str
Connectivity measure(s) to compute. These can be ``['coh', 'cohy',
Expand Down Expand Up @@ -789,7 +801,8 @@ def spectral_connectivity_epochs(
mode : str
Spectrum estimation mode can be either: 'multitaper', 'fourier', or
'cwt_morlet'. Ignored if ``data`` is an
:class:`~mne.time_frequency.EpochsSpectrum` object.
:class:`~mne.time_frequency.EpochsSpectrum` or
:class:`~mne.time_frequency.EpochsTFR` object.
fmin : float | tuple of float
The lower frequency of interest. Multiple bands are defined using
a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq.
Expand Down Expand Up @@ -1105,7 +1118,7 @@ def spectral_connectivity_epochs(
weights = None
metadata = None
spectrum_computed = False
if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsSpectrumArray):
if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsTFR):
names = data.ch_names
sfreq = data.info["sfreq"]

Expand All @@ -1126,28 +1139,45 @@ def spectral_connectivity_epochs(
data.add_annotations_to_metadata(overwrite=True)
metadata = data.metadata

if isinstance(data, EpochsSpectrum | EpochsSpectrumArray):
if isinstance(data, EpochsSpectrum | EpochsTFR):
# XXX: Will need to be updated if new Spectrum methods are added
if not np.iscomplexobj(data.get_data()):
raise TypeError(
"if `data` is an EpochsSpectrum object, it must contain "
"complex-valued Fourier coefficients, such as that returned from "
"Epochs.compute_psd(output='complex')"
"if `data` is an EpochsSpectrum or EpochsTFR object, it must "
"contain complex-valued Fourier coefficients, such as that "
"returned from Epochs.compute_psd/tfr() with `output='complex'`"
)
if "segment" in data._dims:
raise ValueError(
"`data` cannot contain Fourier coefficients for individual segments"
)
if isinstance(data, EpochsSpectrum): # mode can be read mode from Spectrum
mode = data.method
mode = "fourier" if mode == "welch" else mode
else: # spectral method is "unknown", so take mode from data dimensions
# Currently, actual mode doesn't matter as long as we handle tapers and
# their weights in the same way as for multitaper spectra
mode = "multitaper" if "taper" in data._dims else "fourier"
mode = data.method
if isinstance(data, EpochsSpectrum | EpochsSpectrumArray):
if isinstance(data, EpochsSpectrum): # read mode from object
mode = "fourier" if mode == "welch" else mode
else: # infer mode from dimensions
# Currently, actual mode doesn't matter as long as we handle tapers
# and their weights in the same way as for multitaper spectra
mode = "multitaper" if "taper" in data._dims else "fourier"
Comment on lines +1155 to +1161
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really surprised this works (which probably means it's not tested?), because as noted above EpochsSpectrumArray is a subclass of EpochsSpectrum:

In [1]: import numpy as np
In [2]: import mne
In [3]: data = np.zeros((2,3,5))
In [4]: info = mne.create_info(3, ch_types='mag', sfreq=1000)
In [5]: freqs = np.arange(10, 20, 2)
In [6]: foo = mne.time_frequency.EpochsSpectrumArray(data, info, freqs)
In [7]: isinstance(foo, mne.time_frequency.EpochsSpectrum)
Out[7]: True

I'd think you'd need to first check for EpochsSpectrumArray and handle non-Array class in the else clause.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot! I'll look through the tests again and find where it's slipping through.

weights = data.weights
else:
if isinstance(data, EpochsTFR): # read mode from object
if mode != "morlet": # FIXME: Add support for other TFR methods
raise ValueError(
"if `data` is an EpochsTFR object, the spectral method "
"must be 'morlet'"
)
else:
if "taper" in data._dims: # FIXME: Add support for multitaper TFR
raise ValueError(
"if `data` is an EpochsTFRArray object, it cannot contain "
"Fourier coefficients for individual tapers"
)
Comment on lines +1164 to +1175
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto for this if/else; need to check for & handle the *Array class in the if, not in the else

mode = "cwt_morlet" # currently only supported mode here
times_in = data.times
weights = None # no weights stored in TFR objects
spectrum_computed = True
freqs = data.freqs
weights = data.weights
else:
times_in = data.times # input times for Epochs input type
elif sfreq is None:
Expand Down Expand Up @@ -1235,7 +1265,7 @@ def spectral_connectivity_epochs(
spectral_params = dict(
eigvals=None, window_fun=None, wavelets=None, weights=weights
)
n_times_spectrum = 0
n_times_spectrum = n_times # 0 if no times
n_tapers = None if weights is None else weights.size

# unique signals for which we actually need to compute PSD etc.
Expand Down Expand Up @@ -1289,7 +1319,7 @@ def spectral_connectivity_epochs(
logger.info(f" the following metrics will be computed: {metrics_str}")

# check dimensions and time scale
if not spectrum_computed: # XXX: Can we assume upstream checks sufficient?
if not spectrum_computed:
for this_epoch in epoch_block:
_, _, _, warn_times = _get_and_verify_data_sizes(
this_epoch,
Expand Down Expand Up @@ -1469,7 +1499,9 @@ def spectral_connectivity_epochs(
freqs=freqs,
method=_method,
n_nodes=n_nodes,
spec_method=mode if not isinstance(data, BaseSpectrum) else data.method,
spec_method=(
mode if not isinstance(data, BaseSpectrum | BaseTFR) else data.method
),
indices=indices,
n_epochs_used=n_epochs,
freqs_used=freqs_used,
Expand Down
12 changes: 11 additions & 1 deletion mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
import numpy as np
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray
from mne.time_frequency import (
EpochsSpectrum,
EpochsSpectrumArray,
EpochsTFR,
)
from mne.time_frequency.multitaper import _psd_from_mt
from mne.utils import ProgressBar, _validate_type, logger

Expand All @@ -40,6 +44,12 @@ def _check_rank_input(rank, data, indices):
data_arr = _psd_from_mt(data_arr, data.weights)
else:
data_arr = (data_arr * data_arr.conj()).real
elif isinstance(data, EpochsTFR):
# TFR objs will drop bad channels, so specify picking all channels
data_arr = data.get_data(picks=np.arange(data.info["nchan"]))
# Convert to power and aggregate over time before computing rank
# XXX: need to change when other types of TFR are supported
data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1)
else:
data_arr = data

Expand Down
Loading