Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dictionary-type ref_channels in set_eeg_reference() #12366

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12366.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for `dict` type argument ``ref_channels`` to :func:`mne.set_eeg_reference`, to allow flexible re-referencing (e.g. ``raw.set_eeg_reference(ref_channels={'A1': ['A2', 'A3']})`` will set the new A1 data to be ``A1 - (A2 + A3)/2``), by :newcontrib:`Alex Lepauvre` and `Qian Chu`_
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

.. _Alex Kiefer: https://home.alexk101.dev

.. _Alex Lepauvre: https://github.com/AlexLepauvre

.. _Alex Rockhill: https://github.com/alexrockhill/

.. _Alexander Rudiuk: https://github.com/ARudiuk
Expand Down
80 changes: 79 additions & 1 deletion mne/_fiff/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,56 @@ def _check_before_reference(inst, ref_from, ref_to, ch_type):
return ref_to


def _check_before_dict_reference(inst, ref_dict):
ref_from_channels = set()
for key, value in ref_dict.items():
# Check keys
# Check that keys are strings
assert isinstance(key, str), (
"Keys in dict-type ref_channels must be strings, " f"got {type(key)}"
)
# Check that keys are not repeated
assert key not in ref_from_channels, (
"Keys in dict-type ref_channels must be unique, " f"got repeated key {key}"
)
# Check that keys are in ch_names
assert (
key in inst.ch_names
), f"Channel {key} in ref_channels is not in the instance"
ref_from_channels.add(key)

# Check values
if isinstance(value, str):
# Check that value is in ch_names
assert (
value in inst.ch_names
), f"Channel {value} in ref_channels is not in the instance"
# If value is a bad channel, issue a warning
if value in inst.info["bads"]:
msg = f"Channel {value} in ref_channels is marked as bad!"
_on_missing("warns", msg)
elif isinstance(value, list):
for val in value:
# Check that values are strings
assert isinstance(val, str), (
"Values in dict-type ref_channels must be strings or "
f"lists of strings, got {type(val)}"
)
# Check that values are in ch_names
assert (
val in inst.ch_names
), f"Channel {val} in ref_channels is not in the instance"
# If value is a bad channel, issue a warning
if val in inst.info["bads"]:
msg = f"Channel {val} in ref_channels is marked as bad!"
_on_missing("warns", msg)
else:
raise ValueError(
"Values in dict-type ref_channels must be strings or "
f"lists of strings, got {type(value)}"
)


def _apply_reference(inst, ref_from, ref_to=None, forward=None, ch_type="auto"):
"""Apply a custom EEG referencing scheme."""
ref_to = _check_before_reference(inst, ref_from, ref_to, ch_type)
Expand Down Expand Up @@ -155,6 +205,31 @@ def _apply_reference(inst, ref_from, ref_to=None, forward=None, ch_type="auto"):
return inst, ref_data


def _apply_dict_reference(inst, ref_dict):
"""Apply a dict-based custom EEG referencing scheme."""
_check_before_dict_reference(inst, ref_dict)

# Copy the data instance to re-reference:
ref_to_data = inst.copy()._data
if len(ref_dict) > 0:
# Loop through each channel to re-reference:
for ch in ref_dict.keys():
assert len(ref_dict[ch]) > 0, f"No channel to re-reference ch-{ch}"
# Get indices of the channels to use as reference
ref_from = pick_channels(inst.ch_names, ref_dict[ch], ordered=True)
# Get indice of channel to re.reference:
ref_to = pick_channels(inst.ch_names, ch, ordered=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can add a warning around here if ref_from and ref_to belong to different channel types

# Compute the reference data:
ref_data = inst._data[..., ref_from, :].mean(-2, keepdims=True)
# Subtract the reference data to the channel to re-reference:
ref_to_data[..., ref_to, :] -= ref_data
# Add the data back to the instance:
inst._data = ref_to_data
# Set that custom reference was applied:
inst.info["custom_ref_applied"] = FIFF.FIFFV_MNE_CUSTOM_REF_ON
return inst, ref_to_data


@fill_doc
def add_reference_channels(inst, ref_channels, copy=True):
"""Add reference channels to data that consists of all zeros.
Expand Down Expand Up @@ -430,7 +505,10 @@ def set_eeg_reference(
"reference."
)

return _apply_reference(inst, ref_channels, ch_sel, forward, ch_type=ch_type)
if isinstance(ref_channels, dict):
return _apply_dict_reference(inst, ref_channels)
else:
return _apply_reference(inst, ref_channels, ch_sel, forward, ch_type=ch_type)


def _get_ch_type(inst, ch_type):
Expand Down
22 changes: 20 additions & 2 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3660,13 +3660,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
"""

docdict["ref_channels_set_eeg_reference"] = """
ref_channels : list of str | str
ref_channels : list of str | str | dict
Can be:

- The name(s) of the channel(s) used to construct the reference.
- The name(s) of the channel(s) used to construct the reference for
every channel of ``ch_type``.
- ``'average'`` to apply an average reference (default)
- ``'REST'`` to use the Reference Electrode Standardization Technique
infinity reference :footcite:`Yao2001`.
- A dictionary mapping names of channels to be referenced to (a list of)
names of channels to use as reference. This is the most flexible
re-referencing approaching. For example, {'A1': 'A3'} would replace the
data in channel 'A1' with the difference between 'A1' and 'A3'. To take
the average of multiple channels as reference, supply a list of channel
names as the dictionary value, e.g. {'A1': ['A2', 'A3']}.
- An empty list, in which case MNE will not attempt any re-referencing of
the data
"""
Expand Down Expand Up @@ -3995,6 +4002,17 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
The given EEG electrodes are referenced to a point at infinity using the
lead fields in ``forward``, which helps standardize the signals.

- Different references for different channels
Set ``ref_channels`` to a dictionary mapping source channel names (str)
to the reference channel names (str or list of str). Unlike the other
approaches where the same reference is applied globally, you can set
different references for different channels with this method. For example,
to re-reference channel 'A1' to 'A2' and 'B1' to the average of 'B2' and
'B3', set ``ref_channels={'A1': 'A2', 'B1': ['B2', 'B3']}``. Keys in the
dictionary must be unique. Warnings are issued when a bad channel is
used as a reference and when a mapping involves channels of different
types.

1. If a reference is requested that is not the average reference, this
function removes any pre-existing average reference projections.

Expand Down