Skip to content

Commit

Permalink
WIP ENH MneExperiment: allow ICA with custom file
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed Jan 30, 2025
1 parent aca3dc8 commit 617f01d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
13 changes: 10 additions & 3 deletions eelbrain/_experiment/mne_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,8 +4486,12 @@ def make_ica_selection(
# display data
subject = self.get('subject')
pipe = self._get_ica_pipe(state)
bads = pipe.load_bad_channels(subject, self.get('recording'))
with self._temporary_state:
if pipe.session is None:
add_bads = True
else:
add_bads = pipe.load_bad_channels(subject, self.get('recording'))
with self._temporary_state, warnings.catch_warnings():
warnings.filterwarnings('ignore', 'The measurement information indicates a low-pass', RuntimeWarning)
if epoch is None:
if session is None:
session = pipe.session
Expand All @@ -4498,7 +4502,10 @@ def make_ica_selection(
elif session is not None:
raise TypeError(f"{session=} with {epoch=}")
else:
ds = self.load_epochs(ndvar=False, epoch=epoch, reject=False, raw=pipe.source.name, samplingrate=samplingrate, decim=decim, add_bads=bads)
ds = self.load_epochs(ndvar=False, epoch=epoch, reject=False, raw=pipe.source.name, samplingrate=samplingrate, decim=decim, add_bads=add_bads)
if add_bads is True:
ica = mne.preprocessing.read_ica(path)
ds['epochs'].info['bads'] = [ch for ch in ds['epochs'].ch_names if ch not in ica.ch_names]
if isinstance(ds['epochs'], Datalist): # variable-length epoch
data = np.concatenate([epoch.get_data()[0] for epoch in ds['epochs']], axis=1) # n_epochs, n_channels, n_times
raw = mne.io.RawArray(data, ds[0, 'epochs'].info)
Expand Down
24 changes: 20 additions & 4 deletions eelbrain/_experiment/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ class RawICA(CachedRawPipe):
Name of the raw pipe to use for input data.
session
Session(s) to use for estimating ICA components.
Set to ``None`` to create an ICA pipe with user-created ICA files.
method
Method for ICA decomposition (default: ``'extended-infomax'``; see
:class:`mne.preprocessing.ICA`).
Expand All @@ -608,7 +609,7 @@ class RawICA(CachedRawPipe):
:meth:`mne.preprocessing.ICA.fit`. This includes
``reject={'mag': 5e-12, 'grad': 5000e-13, 'eeg': 300e-6}`` unless
a different value for ``reject`` is specified here.
cache : bool
cache
Cache the resulting raw files (default ``False``).
...
Additional parameters for :class:`mne.preprocessing.ICA`.
Expand Down Expand Up @@ -639,15 +640,23 @@ class RawICA(CachedRawPipe):
def __init__(
self,
source: str,
session: Union[str, Sequence[str]],
session: Union[str, Sequence[str], None],
method: str = 'extended-infomax',
random_state: int = 0,
fit_kwargs: Dict[str, Any] = None,
cache: bool = False,
**kwargs,
):
CachedRawPipe.__init__(self, source, cache)
self.session = tuple_arg('session', session, allow_none=False)
if isinstance(session, str):
session = (session,)
elif session is None:
pass
else:
if not isinstance(session, tuple):
session = tuple(session)
assert all(isinstance(s, str) for s in session)
self.session = session
self.kwargs = {'method': method, 'random_state': random_state, **kwargs}
self.fit_kwargs = dict(fit_kwargs) if fit_kwargs else {}

Expand All @@ -668,6 +677,9 @@ def load_bad_channels(
recording: str,
existing: Collection[str] = None,
):
if self.session is None:
ica = self.load_ica(subject, recording)
return ica.info['bads']
visit = _visit(recording)
bad_chs = set()
for session in self.session:
Expand Down Expand Up @@ -732,6 +744,10 @@ def make_ica(
visit: str,
):
path = self._ica_path(subject, visit)
if self.session is None:
if Path(path).exists():
return path
raise RuntimeError(f"{self} is user-generated ICA")
recordings = [compound((session, visit)) for session in self.session]
bad_channels = self.load_bad_channels(subject, recordings[0])
raw = self.source.load(subject, recordings[0], bad_channels, preload=-1)
Expand Down Expand Up @@ -1093,7 +1109,7 @@ def assemble_pipeline(raw_dict, raw_dir, cache_path, root, sessions, log):
for key in list(raw):
if raw[key]._can_link(linked_raw):
pipe = raw.pop(key)._link(key, linked_raw, root, raw_dir, cache_path, log)
if isinstance(pipe, RawICA):
if isinstance(pipe, RawICA) and pipe.session is not None:
missing = set(pipe.session).difference(sessions)
if missing:
raise DefinitionError(f"RawICA {key!r} lists one or more non-exising sessions: {', '.join(missing)}")
Expand Down

0 comments on commit 617f01d

Please sign in to comment.