Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework and harmonize the channel selection 'pick' API and expose public functions #12341

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/_includes/memory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Similarly, epochs can also be be read from disk on-demand. For example::
import mne
events = mne.find_events(raw)
event_id, tmin, tmax = 1, -0.2, 0.5
# TODO: https://github.com/mne-tools/mne-python/issues/11913
picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False, eog=True)
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
baseline=(None, 0), reject=dict(eeg=80e-6, eog=150e-6),
Expand Down
6 changes: 0 additions & 6 deletions doc/api/sensor_space.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ Sensor Space Data
equalize_channels
grand_average
match_channel_orders
pick_channels
pick_channels_cov
pick_channels_forward
pick_channels_regexp
pick_types
pick_types_forward
pick_info
read_epochs
read_reject_parameters
Expand Down
4 changes: 2 additions & 2 deletions doc/changes/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ Bugs

- Fix bug in :func:`mne.io.read_raw_brainvision` when BrainVision data are acquired with the Brain Products "V-Amp" amplifier and disabled lowpass filter is marked with value ``0`` (:gh:`10517` by :newcontrib:`Alessandro Tonin`)

- Fix bug in :func:`mne.pick_types` and related methods where ``csd=True`` was not passed handled properly (:gh:`10470` by :newcontrib:`Matthias Dold`)
- Fix bug in ``mne.pick_types`` and related methods where ``csd=True`` was not passed handled properly (:gh:`10470` by :newcontrib:`Matthias Dold`)

- Fix bug where plots produced using the ``'qt'`` / ``mne_qt_browser`` backend could not be added using :meth:`mne.Report.add_figure` (:gh:`10485` by `Eric Larson`_)

Expand Down Expand Up @@ -145,7 +145,7 @@ Bugs

- Fix bug in coregistration GUI that prevented it from starting up if only a high-resolution head model was available (:gh:`10543` by `Richard Höchenberger`_)

- Fix bug with :class:`mne.Epochs.add_reference_channels` where attributes were not updated properly so subsequent `~mne.Epochs.pick_types` calls were broken (:gh:`10912` by `Eric Larson`_)
- Fix bug with :class:`mne.Epochs.add_reference_channels` where attributes were not updated properly so subsequent ``mne.Epochs.pick_types`` calls were broken (:gh:`10912` by `Eric Larson`_)
-
- Fix bug in the :class:`mne.viz.Brain` tool bar that prevented the buttons to call the corresponding feature (:gh:`10560` by `Guillaume Favelier`_)

Expand Down
88 changes: 88 additions & 0 deletions mne/_fiff/_pick.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from mne.utils import _validate_type

if TYPE_CHECKING:
from re import Pattern
from typing import Optional

import numpy as np
from numpy.typing import DTypeLike, NDArray

from .. import Info

ScalarIntType: tuple[DTypeLike, ...] = (np.int8, np.int16, np.int32, np.int64)


# fmt: off
def pick_ch_names_to_idx(
ch_names: list[str] | tuple[str] | set[str],
picks: Optional[list[str | int] | tuple[str | int] | set[str | int] | NDArray[+ScalarIntType] | str | int | Pattern | slice], # noqa: E501
exclude: list[str | int] | tuple[str | int] | set[str | int] | NDArray[+ScalarIntType] | str | int | Pattern | slice, # noqa: E501
) -> NDArray[np.int32]:
"""Pick on a list-like of channels with validation.

Replaces:
- pick_channels
- pick_channel_regexp
"""
_validate_type(ch_names, (list, tuple, set), "ch_names")
ch_names = list(ch_names) if isinstance(ch_names, (set, tuple)) else ch_names
exclude = _ensure_int_array_pick_exclude_with_ch_names(ch_names, exclude, "exclude")
if picks is None or picks == "all":
picks = np.arange(len(ch_names), dtype=np.int32)
else:
picks = _ensure_int_array_pick_exclude_with_ch_names(ch_names, picks, "picks")
return np.setdiff1d(picks, exclude, assume_unique=True).astype(np.int32)


def _ensure_int_array_pick_exclude_with_ch_names(
ch_names: list[str],
var: list[str | int] | tuple[str | int] | set[str | int] | NDArray[+ScalarIntType] | str | int | Pattern | slice, # noqa: E501
var_name: str
) -> NDArray[np.int32]:
pass


def pick_info_to_idx(
info: Info,
picks: Optional[list[str | int] | tuple[str | int] | set[str | int] | NDArray[+ScalarIntType] | str | int | Pattern | slice], # noqa: E501
exclude: list[str | int] | tuple[str | int] | set[str | int] | NDArray[+ScalarIntType] | str | int | Pattern | slice, # noqa: E501
) -> NDArray[np.int32]:
"""Pick on an info with validation.

Replaces:
- pick_channels
- pick_channels_regexp
- pick_types
"""
_validate_type(info, Info, "info")
if exclude == "bads":
exclude = np.array([info["ch_names"].index(ch) for ch in info["bads"]], dtype=np.int32) # noqa: E501
else:
exclude = _ensure_int_array_pick_exclude_with_info(info, exclude, "exclude")
if picks is None or picks == "all":
picks = np.arange(len(info["ch_names"]), dtype=np.int32)
elif picks == "bads":
exclude = np.array([info["ch_names"].index(ch) for ch in info["bads"]], dtype=np.int32) # noqa: E501
elif picks == "data":
return _pick_data_to_idx(info, exclude)
else:
picks = _ensure_int_array_pick_exclude_with_info(info, picks, "picks")
return np.setdiff1d(picks, exclude, assume_unique=True).astype(np.int32)


def _pick_data_to_idx(info: Info, exclude: NDArray[np.int32]):
"""Pick all data channels without validation."""
pass


def _ensure_int_array_pick_exclude_with_info(
info: Info,
var: list[str | int] | tuple[str | int] | set[str | int] | NDArray[+ScalarIntType] | str | int | Pattern | slice, # noqa: E501
var_name: str
) -> NDArray[np.int32]:
pass
# fmt: on
7 changes: 7 additions & 0 deletions mne/_fiff/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_ensure_int,
_validate_type,
fill_doc,
legacy,
logger,
verbose,
warn,
Expand Down Expand Up @@ -258,6 +259,7 @@ def channel_type(info, idx):
return first_kind


@legacy
@verbose
def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None):
"""Pick channels by names.
Expand Down Expand Up @@ -339,6 +341,7 @@ def pick_channels(ch_names, include, exclude=[], ordered=None, *, verbose=None):
return np.array(sel, int)


@legacy
def pick_channels_regexp(ch_names, regexp):
"""Pick channels using regular expression.

Expand Down Expand Up @@ -456,6 +459,7 @@ def _check_info_exclude(info, exclude):
return exclude


@legacy
@fill_doc
def pick_types(
info,
Expand Down Expand Up @@ -705,6 +709,7 @@ def _has_kit_refs(info, picks):
return False


@legacy
@verbose
def pick_channels_forward(
orig, include=[], exclude=[], ordered=None, copy=True, *, verbose=None
Expand Down Expand Up @@ -790,6 +795,7 @@ def pick_channels_forward(
return fwd


@legacy
def pick_types_forward(
orig,
meg=False,
Expand Down Expand Up @@ -892,6 +898,7 @@ def channel_indices_by_type(info, picks=None):
return idx_by_type


@legacy
@verbose
def pick_channels_cov(
orig, include=[], exclude="bads", ordered=None, copy=True, *, verbose=None
Expand Down
11 changes: 11 additions & 0 deletions mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
copy_function_doc_to_method_doc,
eigh,
fill_doc,
legacy,
logger,
verbose,
warn,
Expand Down Expand Up @@ -452,6 +453,7 @@ def plot_topomap(
time_format="",
)

@legacy
@verbose
def pick_channels(self, ch_names, ordered=None, *, verbose=None):
"""Pick channels from this covariance matrix.
Expand All @@ -478,6 +480,15 @@ def pick_channels(self, ch_names, ordered=None, *, verbose=None):
self, ch_names, exclude=[], ordered=ordered, copy=False
)

def pick(self, picks, exclude):
"""Pick channels from the covariance matrix.

Replaces:
- Covariance.pick_channels
- pick_channels_cov
"""
pass


###############################################################################
# IO
Expand Down
11 changes: 11 additions & 0 deletions mne/forward/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def _repr_html_(self):
def ch_names(self):
return self["info"]["ch_names"]

@legacy
def pick_channels(self, ch_names, ordered=False):
"""Pick channels from this forward operator.

Expand All @@ -271,6 +272,16 @@ def pick_channels(self, ch_names, ordered=False):
self, ch_names, exclude=[], ordered=ordered, copy=False, verbose=False
)

def pick(self, picks, exclude):
"""Pick channels from the forward operator.

Replaces:
- Forward.pick_channels
- pick_channels_forward
- pick_types_forward
"""
pass


def _block_diag(A, n):
"""Construct a block diagonal from a packed structure.
Expand Down