Skip to content

Commit

Permalink
ENH: make apply_function aware of channel index (#12206)
Browse files Browse the repository at this point in the history
Co-authored-by: Mathieu Scheltienne <mathieu.scheltienne@gmail.com>
  • Loading branch information
dominikwelke and mscheltienne committed Feb 6, 2024
1 parent e6b49ea commit 87df00d
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 24 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12206.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug in :meth:`mne.Epochs.apply_function` where data was passed incorrectly to the applied function during parallel processing, by `Dominik Welke`_.
3 changes: 3 additions & 0 deletions doc/changes/devel/12206.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Custom functions applied via :meth:`mne.io.Raw.apply_function`, :meth:`mne.Epochs.apply_function` or :meth:`mne.Evoked.apply_function` can now use ``ch_idx`` or ``ch_name`` to get access to the currently processed channel during channel-wise processing.

:meth:`mne.Evoked.apply_function` can now also operate on the full data array instead of only channel-wise, analogous to :meth:`mne.io.Raw.apply_function` and :meth:`mne.Epochs.apply_function`, by `Dominik Welke`_.
45 changes: 38 additions & 7 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections import Counter
from copy import deepcopy
from functools import partial
from inspect import getfullargspec

import numpy as np
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -1972,22 +1973,52 @@ def apply_function(
if dtype is not None and dtype != self._data.dtype:
self._data = self._data.astype(dtype)

args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs
if channel_wise is False:
if ("ch_idx" in args) or ("ch_name" in args):
raise ValueError(
"apply_function cannot access ch_idx or ch_name "
"when channel_wise=False"
)
if "ch_idx" in args:
logger.info("apply_function requested to access ch_idx")
if "ch_name" in args:
logger.info("apply_function requested to access ch_name")

if channel_wise:
parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
if n_jobs == 1:
_fun = partial(_check_fun, fun, **kwargs)
_fun = partial(_check_fun, fun)
# modify data inplace to save memory
for idx in picks:
self._data[:, idx, :] = np.apply_along_axis(
_fun, -1, data_in[:, idx, :]
for ch_idx in picks:
if "ch_idx" in args:
kwargs.update(ch_idx=ch_idx)
if "ch_name" in args:
kwargs.update(ch_name=self.info["ch_names"][ch_idx])
self._data[:, ch_idx, :] = np.apply_along_axis(
_fun, -1, data_in[:, ch_idx, :], **kwargs
)
else:
# use parallel function
_fun = partial(np.apply_along_axis, fun, -1)
data_picks_new = parallel(
p_fun(fun, data_in[:, p, :], **kwargs) for p in picks
p_fun(
_fun,
data_in[:, ch_idx, :],
**kwargs,
**{
k: v
for k, v in [
("ch_name", self.info["ch_names"][ch_idx]),
("ch_idx", ch_idx),
]
if k in args
},
)
for ch_idx in picks
)
for pp, p in enumerate(picks):
self._data[:, p, :] = data_picks_new[pp]
for run_idx, ch_idx in enumerate(picks):
self._data[:, ch_idx, :] = data_picks_new[run_idx]
else:
self._data = _check_fun(fun, data_in, **kwargs)

Expand Down
70 changes: 58 additions & 12 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Copyright the MNE-Python contributors.

from copy import deepcopy
from inspect import getfullargspec
from typing import Union

import numpy as np
Expand Down Expand Up @@ -258,7 +259,15 @@ def get_data(self, picks=None, units=None, tmin=None, tmax=None):

@verbose
def apply_function(
self, fun, picks=None, dtype=None, n_jobs=None, verbose=None, **kwargs
self,
fun,
picks=None,
dtype=None,
n_jobs=None,
channel_wise=True,
*,
verbose=None,
**kwargs,
):
"""Apply a function to a subset of channels.
Expand All @@ -271,6 +280,9 @@ def apply_function(
%(dtype_applyfun)s
%(n_jobs)s Ignored if ``channel_wise=False`` as the workload
is split across channels.
%(channel_wise_applyfun)s
.. versionadded:: 1.6
%(verbose)s
%(kwargs_fun)s
Expand All @@ -289,21 +301,55 @@ def apply_function(
if dtype is not None and dtype != self._data.dtype:
self._data = self._data.astype(dtype)

args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs
if channel_wise is False:
if ("ch_idx" in args) or ("ch_name" in args):
raise ValueError(
"apply_function cannot access ch_idx or ch_name "
"when channel_wise=False"
)
if "ch_idx" in args:
logger.info("apply_function requested to access ch_idx")
if "ch_name" in args:
logger.info("apply_function requested to access ch_name")

# check the dimension of the incoming evoked data
_check_option("evoked.ndim", self._data.ndim, [2])

parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
if n_jobs == 1:
# modify data inplace to save memory
for idx in picks:
self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs)
if channel_wise:
parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
if n_jobs == 1:
# modify data inplace to save memory
for ch_idx in picks:
if "ch_idx" in args:
kwargs.update(ch_idx=ch_idx)
if "ch_name" in args:
kwargs.update(ch_name=self.info["ch_names"][ch_idx])
self._data[ch_idx, :] = _check_fun(
fun, data_in[ch_idx, :], **kwargs
)
else:
# use parallel function
data_picks_new = parallel(
p_fun(
fun,
data_in[ch_idx, :],
**kwargs,
**{
k: v
for k, v in [
("ch_name", self.info["ch_names"][ch_idx]),
("ch_idx", ch_idx),
]
if k in args
},
)
for ch_idx in picks
)
for run_idx, ch_idx in enumerate(picks):
self._data[ch_idx, :] = data_picks_new[run_idx]
else:
# use parallel function
data_picks_new = parallel(
p_fun(fun, data_in[p, :], **kwargs) for p in picks
)
for pp, p in enumerate(picks):
self._data[p, :] = data_picks_new[pp]
self._data[picks, :] = _check_fun(fun, data_in[picks, :], **kwargs)

return self

Expand Down
42 changes: 37 additions & 5 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import timedelta
from inspect import getfullargspec

import numpy as np

Expand Down Expand Up @@ -1087,19 +1088,50 @@ def apply_function(
if dtype is not None and dtype != self._data.dtype:
self._data = self._data.astype(dtype)

args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs
if channel_wise is False:
if ("ch_idx" in args) or ("ch_name" in args):
raise ValueError(
"apply_function cannot access ch_idx or ch_name "
"when channel_wise=False"
)
if "ch_idx" in args:
logger.info("apply_function requested to access ch_idx")
if "ch_name" in args:
logger.info("apply_function requested to access ch_name")

if channel_wise:
parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
if n_jobs == 1:
# modify data inplace to save memory
for idx in picks:
self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs)
for ch_idx in picks:
if "ch_idx" in args:
kwargs.update(ch_idx=ch_idx)
if "ch_name" in args:
kwargs.update(ch_name=self.info["ch_names"][ch_idx])
self._data[ch_idx, :] = _check_fun(
fun, data_in[ch_idx, :], **kwargs
)
else:
# use parallel function
data_picks_new = parallel(
p_fun(fun, data_in[p], **kwargs) for p in picks
p_fun(
fun,
data_in[ch_idx],
**kwargs,
**{
k: v
for k, v in [
("ch_name", self.info["ch_names"][ch_idx]),
("ch_idx", ch_idx),
]
if k in args
},
)
for ch_idx in picks
)
for pp, p in enumerate(picks):
self._data[p, :] = data_picks_new[pp]
for run_idx, ch_idx in enumerate(picks):
self._data[ch_idx, :] = data_picks_new[run_idx]
else:
self._data[picks, :] = _check_fun(fun, data_in[picks, :], **kwargs)

Expand Down
29 changes: 29 additions & 0 deletions mne/io/tests/test_apply_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,32 @@ def test_apply_function_verbose():
assert out is raw
raw.apply_function(printer, verbose=True)
assert sio.getvalue().count("\n") == n_chan


def test_apply_function_ch_access():
    """Check that Raw.apply_function can hand ch_idx / ch_name to the function."""

    # The keyword names ``ch_idx`` / ``ch_name`` are what apply_function looks
    # for in the function signature, so they must be spelled exactly like this.
    def _bad_ch_idx(x, ch_idx):
        # every row of the test data is filled with its own row index
        assert x[0] == ch_idx
        return x

    def _bad_ch_name(x, ch_name):
        assert isinstance(ch_name, str)
        # presumably create_info names the channels "0", "1", ... -- the check
        # relies on the name being parseable as the row's fill value
        assert x[0] == float(ch_name)
        return x

    n_chan, n_times = 2, 10
    row_values = np.arange(n_chan).reshape(-1, 1)
    raw = RawArray(
        np.full((n_chan, n_times), row_values),
        create_info(n_chan, 1.0, "mag"),
    )

    # exercise both code paths: default n_jobs first, then the parallel one
    for fun in (_bad_ch_idx, _bad_ch_name):
        for extra in ({}, {"n_jobs": 2}):
            raw.apply_function(fun, **extra)

    # channel information is meaningless when the whole array is processed
    with pytest.raises(ValueError, match="cannot access.*when channel_wise=False"):
        raw.apply_function(_bad_ch_idx, channel_wise=False)
33 changes: 33 additions & 0 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4764,6 +4764,39 @@ def fun(data):
assert_array_equal(out.get_data(non_picks), epochs.get_data(non_picks))


def test_apply_function_epo_ch_access():
    """Check that Epochs.apply_function can hand ch_idx / ch_name to the function."""

    # The keyword names ``ch_idx`` / ``ch_name`` are what apply_function looks
    # for in the function signature, so they must be spelled exactly like this.
    def _bad_ch_idx(x, ch_idx):
        # the function must see a single 1D time series per call
        assert x.shape == (46,)
        assert x[0] == ch_idx
        return x

    def _bad_ch_name(x, ch_name):
        assert x.shape == (46,)
        assert isinstance(ch_name, str)
        # presumably create_info names the channels "0", "1", ... -- the check
        # relies on the name being parseable as the row's fill value
        assert x[0] == float(ch_name)
        return x

    n_chan, n_times = 2, 100
    row_values = np.arange(n_chan).reshape(-1, 1)
    raw = RawArray(
        np.full((n_chan, n_times), row_values),
        create_info(n_chan, 1.0, "mag"),
    )
    events = np.array([[0, 0, 33], [50, 0, 33]])
    ep = Epochs(raw, events, tmin=0, tmax=45, baseline=None, preload=True)

    # exercise both code paths: default n_jobs first, then the parallel one
    for fun in (_bad_ch_idx, _bad_ch_name):
        for extra in ({}, {"n_jobs": 2}):
            ep.apply_function(fun, **extra)

    # channel information is meaningless when the whole array is processed
    with pytest.raises(ValueError, match="cannot access.*when channel_wise=False"):
        ep.apply_function(_bad_ch_idx, channel_wise=False)


@testing.requires_testing_data
def test_add_channels_picks():
"""Check that add_channels properly deals with picks."""
Expand Down
30 changes: 30 additions & 0 deletions mne/tests/test_evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,33 @@ def fun(data, multiplier):
applied = evoked.apply_function(fun, n_jobs=None, multiplier=mult)
assert np.shape(applied.data) == np.shape(evoked_data)
assert np.equal(applied.data, evoked_data * mult).all()


def test_apply_function_evk_ch_access():
    """Check that Evoked.apply_function can hand ch_idx / ch_name to the function."""

    # The keyword names ``ch_idx`` / ``ch_name`` are what apply_function looks
    # for in the function signature, so they must be spelled exactly like this.
    def _bad_ch_idx(x, ch_idx):
        # every row of the test data is filled with its own row index
        assert x[0] == ch_idx
        return x

    def _bad_ch_name(x, ch_name):
        assert isinstance(ch_name, str)
        # presumably create_info names the channels "0", "1", ... -- the check
        # relies on the name being parseable as the row's fill value
        assert x[0] == float(ch_name)
        return x

    # fake evoked data: row i is constant and equal to i
    n_chan, n_times = 2, 100
    row_values = np.arange(n_chan).reshape(-1, 1)
    evoked = EvokedArray(
        np.full((n_chan, n_times), row_values),
        create_info(n_chan, 1000.0, "eeg"),
    )

    # exercise both code paths: default n_jobs first, then the parallel one
    for fun in (_bad_ch_idx, _bad_ch_name):
        for extra in ({}, {"n_jobs": 2}):
            evoked.apply_function(fun, **extra)

    # channel information is meaningless when the whole array is processed
    with pytest.raises(ValueError, match="cannot access.*when channel_wise=False"):
        evoked.apply_function(_bad_ch_idx, channel_wise=False)
7 changes: 7 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,6 +1586,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
fun has to be a timeseries (:class:`numpy.ndarray`). The function must
operate on an array of shape ``(n_times,)`` {}.
The function must return an :class:`~numpy.ndarray` shaped like its input.
.. note::
If ``channel_wise=True``, one can optionally access the index and/or the
name of the currently processed channel within the applied function.
This can enable tailored computations for different channels.
To use this feature, add ``ch_idx`` and/or ``ch_name`` as
additional argument(s) to your function definition.
"""
docdict["fun_applyfun"] = applyfun_fun_base.format(
" if ``channel_wise=True`` and ``(len(picks), n_times)`` otherwise"
Expand Down

0 comments on commit 87df00d

Please sign in to comment.