Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add utility function to enable autoreject cleaning on Epoch level #62

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Post-processing on connectivity
seed_target_indices
check_indices
select_order
map_epoch_annotations_to_epoch

Visualization functions
=======================
Expand Down
1 change: 1 addition & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Enhancements
- Adding symmetric orthogonalization via :func:`mne_connectivity.symmetric_orth`, by `Eric Larson`_ (:gh:`36`)
- Improved RAM usage for :func:`mne_connectivity.vector_auto_regression` by leveraging code from ``statsmodels``, by `Adam Li`_ (:gh:`46`)
- Added :func:`mne_connectivity.select_order` for helping to select VAR order using information criterion, by `Adam Li`_ (:gh:`46`)
- Adds a utility function :func:`mne_connectivity.utils.map_epoch_annotations_to_epoch` to map arbitrary Epoch windows to another arbitrary Epoch window, by `Adam Li`_ (:gh:`62`)

Bug
~~~
Expand Down
41 changes: 40 additions & 1 deletion mne_connectivity/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
import pytest
from numpy.testing import assert_array_equal

from mne.io import RawArray
from mne.epochs import Epochs, make_fixed_length_epochs
from mne.io.meas_info import create_info

from mne_connectivity import Connectivity
from mne_connectivity.utils import degree, seed_target_indices
from mne_connectivity.utils import (
degree, seed_target_indices, map_epoch_annotations_to_epoch)


def test_indices():
Expand Down Expand Up @@ -64,3 +69,37 @@ def test_degree():
conn = Connectivity(data=np.zeros((4,)), n_nodes=2)
deg = degree(conn)
assert_array_equal(deg, [0, 0])


def test_mapping_epochs_to_epochs():
"""Test map_epoch_annotations_to_epoch function."""
n_times = 1000
sfreq = 100
data = np.random.random((2, n_times))
info = create_info(ch_names=['A1', 'A2'], sfreq=sfreq,
ch_types='mag')
raw = RawArray(data, info)

# create two different sets of Epochs
# the first one is just a contiguous chunks of 1 seconds
epoch_one = make_fixed_length_epochs(raw, duration=1, overlap=0)

events = np.zeros((2, 3), dtype=int)
events[:, 0] = [100, 900]
epoch_two = Epochs(raw, events, tmin=-0.5, tmax=0.5)

# map Epochs from two to one
all_cases = map_epoch_annotations_to_epoch(epoch_one, epoch_two)
assert all_cases.shape == (2, 10)

# only 1-3 Epochs of epoch_one should overlap with the epoch_two's
# 1st Epoch
assert all(all_cases[0, :2])
assert all(all_cases[1, -2:])

# map Epochs from one to two
all_cases = map_epoch_annotations_to_epoch(epoch_two, epoch_one)
assert all_cases.shape == (10, 2)

assert all(all_cases[:2, 0])
assert all(all_cases[-2:, 1])
2 changes: 1 addition & 1 deletion mne_connectivity/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .docs import fill_doc
from .utils import (check_indices, degree, seed_target_indices,
parallel_loop)
parallel_loop, map_epoch_annotations_to_epoch)
86 changes: 86 additions & 0 deletions mne_connectivity/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,97 @@
# Authors: Martin Luessi <mluessi@nmr.mgh.harvard.edu>
# Adam Li <adam2392@gmail.com>
#
# License: BSD (3-clause)
import numpy as np

from mne import BaseEpochs
from mne.utils import logger


def map_epoch_annotations_to_epoch(dest_epoch, src_epoch):
"""Map Annotations that occur in one Epoch to another Epoch.

Two different Epochs might occur at different time points.
This function will map Annotations that occur in one Epoch
setting to another Epoch taking into account their onset
samples and window lengths.

Parameters
----------
dest_epoch : instance of Epochs | events array
The reference Epochs that you want to match to.
src_epoch : instance of Epochs | events array
The source Epochs that contain Epochs you want to
see if it overlaps at any point with ``dest_epoch``.

Returns
-------
all_cases : np.ndarray of shape (n_src_epochs, n_dest_epochs)
This is an array indicating the overlap of any source epoch
relative to the destination epoch. An overlap is indicated
by a ``True``, whereas if a source Epoch does not overlap
with a destination Epoch, then the element will be ``False``.

Notes
-----
This is a useful utility function to enable mapping Autoreject
``RejectLog`` that occurs over a set of defined Epochs to
another Epoched data structure, such as a ``Epoch*`` connectivity
class, which computes connectivity over Epochs.
"""
if isinstance(dest_epoch, BaseEpochs):
dest_events = dest_epoch.events
dest_times = dest_epoch.times
dest_sfreq = dest_epoch._raw_sfreq
else:
dest_events = dest_epoch
if isinstance(src_epoch, BaseEpochs):
src_events = src_epoch.events
src_times = src_epoch.times
src_sfreq = src_epoch._raw_sfreq
else:
src_events = src_epoch

# get the sample points of the source Epochs we want
# to map over to the destination sample points
src_onset_sample = src_events[:, 0]
src_epoch_tzeros = src_onset_sample / src_sfreq
dest_onset_sample = dest_events[:, 0]
dest_epoch_tzeros = dest_onset_sample / dest_sfreq

# get start and stop points of every single source Epoch
src_epoch_starts, src_epoch_stops = np.atleast_2d(
src_epoch_tzeros) + np.atleast_2d(src_times[[0, -1]]).T

# get start and stop points of every single destination Epoch
dest_epoch_starts, dest_epoch_stops = np.atleast_2d(
dest_epoch_tzeros) + np.atleast_2d(dest_times[[0, -1]]).T

# get destination Epochs that start within the source Epoch
src_straddles_dest_start = np.logical_and(
np.atleast_2d(dest_epoch_starts) >= np.atleast_2d(src_epoch_starts).T,
np.atleast_2d(dest_epoch_starts) < np.atleast_2d(src_epoch_stops).T)

# get epochs that end within the annotations
src_straddles_dest_end = np.logical_and(
np.atleast_2d(dest_epoch_stops) > np.atleast_2d(src_epoch_starts).T,
np.atleast_2d(dest_epoch_stops) <= np.atleast_2d(src_epoch_stops).T)

# get epochs that are fully contained within annotations
src_fully_within_dest = np.logical_and(
np.atleast_2d(dest_epoch_starts) <= np.atleast_2d(src_epoch_starts).T,
np.atleast_2d(dest_epoch_stops) >= np.atleast_2d(src_epoch_stops).T)

# combine all cases to get array of shape (n_src_epochs, n_dest_epochs).
# Nonzero entries indicate overlap between the corresponding
# annotation (row index) and epoch (column index).
all_cases = (src_straddles_dest_start +
src_straddles_dest_end +
src_fully_within_dest)

return all_cases


def parallel_loop(func, n_jobs=1, verbose=1):
"""run loops in parallel, if joblib is available.

Expand Down