From 617f01df204fb664abecbf082c7c878b9aa45331 Mon Sep 17 00:00:00 2001 From: Christian Brodbeck Date: Tue, 1 Nov 2022 14:01:43 -0400 Subject: [PATCH] WIP ENH MneExperiment: allow ICA with custom file --- eelbrain/_experiment/mne_experiment.py | 13 ++++++++++--- eelbrain/_experiment/preprocessing.py | 24 ++++++++++++++++++++---- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/eelbrain/_experiment/mne_experiment.py b/eelbrain/_experiment/mne_experiment.py index d228bedcd..e3e3dd6e6 100644 --- a/eelbrain/_experiment/mne_experiment.py +++ b/eelbrain/_experiment/mne_experiment.py @@ -4486,8 +4486,12 @@ def make_ica_selection( # display data subject = self.get('subject') pipe = self._get_ica_pipe(state) - bads = pipe.load_bad_channels(subject, self.get('recording')) - with self._temporary_state: + if pipe.session is None: + add_bads = True + else: + add_bads = pipe.load_bad_channels(subject, self.get('recording')) + with self._temporary_state, warnings.catch_warnings(): + warnings.filterwarnings('ignore', 'The measurement information indicates a low-pass', RuntimeWarning) if epoch is None: if session is None: session = pipe.session @@ -4498,7 +4502,10 @@ def make_ica_selection( elif session is not None: raise TypeError(f"{session=} with {epoch=}") else: - ds = self.load_epochs(ndvar=False, epoch=epoch, reject=False, raw=pipe.source.name, samplingrate=samplingrate, decim=decim, add_bads=bads) + ds = self.load_epochs(ndvar=False, epoch=epoch, reject=False, raw=pipe.source.name, samplingrate=samplingrate, decim=decim, add_bads=add_bads) + if add_bads is True: + ica = mne.preprocessing.read_ica(path) + ds['epochs'].info['bads'] = [ch for ch in ds['epochs'].ch_names if ch not in ica.ch_names] if isinstance(ds['epochs'], Datalist): # variable-length epoch data = np.concatenate([epoch.get_data()[0] for epoch in ds['epochs']], axis=1) # n_epochs, n_channels, n_times raw = mne.io.RawArray(data, ds[0, 'epochs'].info) diff --git a/eelbrain/_experiment/preprocessing.py b/eelbrain/_experiment/preprocessing.py index c94fa64a8..72ddd897c 100644 --- a/eelbrain/_experiment/preprocessing.py +++ b/eelbrain/_experiment/preprocessing.py @@ -597,6 +597,7 @@ class RawICA(CachedRawPipe): Name of the raw pipe to use for input data. session Session(s) to use for estimating ICA components. + Set to ``None`` to create an ICA pipe with user-created ICA files. method Method for ICA decomposition (default: ``'extended-infomax'``; see :class:`mne.preprocessing.ICA`). @@ -608,7 +609,7 @@ class RawICA(CachedRawPipe): :meth:`mne.preprocessing.ICA.fit`. This includes ``reject={'mag': 5e-12, 'grad': 5000e-13, 'eeg': 300e-6}`` unless a different value for ``reject`` is specified here. - cache : bool + cache Cache the resulting raw files (default ``False``). ... Additional parameters for :class:`mne.preprocessing.ICA`. @@ -639,7 +640,7 @@ class RawICA(CachedRawPipe): def __init__( self, source: str, - session: Union[str, Sequence[str]], + session: Union[str, Sequence[str], None], method: str = 'extended-infomax', random_state: int = 0, fit_kwargs: Dict[str, Any] = None, @@ -647,7 +648,15 @@ def __init__( **kwargs, ): CachedRawPipe.__init__(self, source, cache) - self.session = tuple_arg('session', session, allow_none=False) + if isinstance(session, str): + session = (session,) + elif session is None: + pass + else: + if not isinstance(session, tuple): + session = tuple(session) + assert all(isinstance(s, str) for s in session) + self.session = session self.kwargs = {'method': method, 'random_state': random_state, **kwargs} self.fit_kwargs = dict(fit_kwargs) if fit_kwargs else {} @@ -668,6 +677,9 @@ def load_bad_channels( recording: str, existing: Collection[str] = None, ): + if self.session is None: + ica = self.load_ica(subject, recording) + return ica.info['bads'] visit = _visit(recording) bad_chs = set() for session in self.session: @@ -732,6 +744,10 @@ def make_ica( visit: str, ): path = self._ica_path(subject, visit) + if self.session is None: + if Path(path).exists(): + return path + raise RuntimeError(f"{self} is user-generated ICA") recordings = [compound((session, visit)) for session in self.session] bad_channels = self.load_bad_channels(subject, recordings[0]) raw = self.source.load(subject, recordings[0], bad_channels, preload=-1) @@ -1093,7 +1109,7 @@ def assemble_pipeline(raw_dict, raw_dir, cache_path, root, sessions, log): for key in list(raw): if raw[key]._can_link(linked_raw): pipe = raw.pop(key)._link(key, linked_raw, root, raw_dir, cache_path, log) - if isinstance(pipe, RawICA): + if isinstance(pipe, RawICA) and pipe.session is not None: missing = set(pipe.session).difference(sessions) if missing: raise DefinitionError(f"RawICA {key!r} lists one or more non-exising sessions: {', '.join(missing)}")