Skip to content

Commit

Permalink
animate_topomap - CSD fix (#12605)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
  • Loading branch information
3 people committed May 15, 2024
1 parent 44c69f5 commit 4cffc34
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 35 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12605.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug where :meth:`mne.Evoked.animate_topomap` did not work with :func:`mne.preprocessing.compute_current_source_density` - modified data, by `Michal Žák`_.
2 changes: 1 addition & 1 deletion mne/_fiff/pick.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ def _picks_by_type(info, meg_combined=False, ref_meg=False, exclude="bads"):
exclude = _check_info_exclude(info, exclude)
if meg_combined == "auto":
meg_combined = _mag_grad_dependent(info)
picks_list = []

picks_list = {ch_type: list() for ch_type in _DATA_CH_TYPES_SPLIT}
for k in range(info["nchan"]):
if info["chs"][k]["ch_name"] not in exclude:
Expand Down
39 changes: 21 additions & 18 deletions mne/viz/tests/test_topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from matplotlib.colors import PowerNorm, TwoSlopeNorm
from matplotlib.patches import Circle
from numpy.testing import assert_almost_equal, assert_array_equal, assert_equal

Expand Down Expand Up @@ -43,7 +44,11 @@
)
from mne.datasets import testing
from mne.io import RawArray, read_info, read_raw_fif
from mne.preprocessing import compute_bridged_electrodes
from mne.preprocessing import (
ICA,
compute_bridged_electrodes,
compute_current_source_density,
)
from mne.time_frequency.tfr import AverageTFRArray
from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap
from mne.viz.tests.test_raw import _proj_status
Expand Down Expand Up @@ -179,7 +184,21 @@ def test_plot_topomap_animation(capsys):
anim._func(1) # _animate has to be tested separately on 'Agg' backend.
out, _ = capsys.readouterr()
assert "extrapolation mode local to 0" in out
plt.close("all")


def test_plot_topomap_animation_csd(capsys):
"""Test topomap plotting of CSD data."""
# evoked
evoked = read_evokeds(evoked_fname, "Left Auditory", baseline=(None, 0))
evoked_csd = compute_current_source_density(evoked)

# Test animation
_, anim = evoked_csd.animate_topomap(
ch_type="csd", times=[0, 0.1], butterfly=False, time_unit="s", verbose="debug"
)
anim._func(1) # _animate has to be tested separately on 'Agg' backend.
out, _ = capsys.readouterr()
assert "extrapolation mode head to 0" in out


@pytest.mark.filterwarnings("ignore:.*No contour levels.*:UserWarning")
Expand All @@ -190,7 +209,6 @@ def test_plot_topomap_animation_nirs(fnirs_evoked, capsys):
out, _ = capsys.readouterr()
assert "extrapolation mode head to 0" in out
assert len(fig.axes) == 2
plt.close("all")


def test_plot_evoked_topomap_errors(evoked, monkeypatch):
Expand Down Expand Up @@ -553,7 +571,6 @@ def patch():
orig_bads = evoked_grad.info["bads"]
evoked_grad.plot_topomap(ch_type="grad", times=[0], time_unit="ms")
assert_array_equal(evoked_grad.info["bads"], orig_bads)
plt.close("all")


def test_plot_tfr_topomap():
Expand Down Expand Up @@ -685,8 +702,6 @@ def test_plot_topomap_neuromag122():

def test_plot_topomap_bads():
"""Test plotting topomap with bad channels (gh-7213)."""
import matplotlib.pyplot as plt

data = np.random.RandomState(0).randn(3, 1000)
raw = RawArray(data, create_info(3, 1000.0, "eeg"))
ch_pos_dict = {name: pos for name, pos in zip(raw.ch_names, np.eye(3))}
Expand All @@ -695,7 +710,6 @@ def test_plot_topomap_bads():
raw.info["bads"] = raw.ch_names[:count]
raw.info._check_consistency()
plot_topomap(data[:, 0], raw.info)
plt.close("all")


def test_plot_topomap_channel_distance():
Expand All @@ -713,35 +727,28 @@ def test_plot_topomap_channel_distance():
evoked.set_montage(ten_five)

evoked.plot_topomap(sphere=0.05, res=8)
plt.close("all")


def test_plot_topomap_bads_grad():
"""Test plotting topomap with bad gradiometer channels (gh-8802)."""
import matplotlib.pyplot as plt

data = np.random.RandomState(0).randn(203)
info = read_info(evoked_fname)
info["bads"] = ["MEG 2242"]
picks = pick_types(info, meg="grad")
info = pick_info(info, picks)
assert len(info["chs"]) == 203
plot_topomap(data, info, res=8)
plt.close("all")


def test_plot_topomap_nirs_overlap(fnirs_epochs):
"""Test plotting nirs topomap with overlapping channels (gh-7414)."""
fig = fnirs_epochs["A"].average(picks="hbo").plot_topomap()
assert len(fig.axes) == 5
plt.close("all")


def test_plot_topomap_nirs_ica(fnirs_epochs):
"""Test plotting nirs ica topomap."""
pytest.importorskip("sklearn")
from mne.preprocessing import ICA

fnirs_epochs = fnirs_epochs.load_data().pick(picks="hbo")
fnirs_epochs = fnirs_epochs.pick(picks=range(30))

Expand All @@ -754,7 +761,6 @@ def test_plot_topomap_nirs_ica(fnirs_epochs):
ica = ICA().fit(fnirs_epochs)
fig = ica.plot_components()
assert len(fig[0].axes) == 20
plt.close("all")


def test_plot_cov_topomap():
Expand All @@ -763,13 +769,10 @@ def test_plot_cov_topomap():
info = read_info(evoked_fname)
cov.plot_topomap(info)
cov.plot_topomap(info, noise_cov=cov)
plt.close("all")


def test_plot_topomap_cnorm():
"""Test colormap normalization."""
from matplotlib.colors import PowerNorm, TwoSlopeNorm

rng = np.random.default_rng(42)
v = rng.uniform(low=-1, high=2.5, size=64)
v[:3] = [-1, 0, 2.5]
Expand Down
17 changes: 2 additions & 15 deletions mne/viz/topomap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3249,21 +3249,8 @@ def _topomap_animation(
from matplotlib import pyplot as plt

if ch_type is None:
ch_type = _picks_by_type(evoked.info)[0][0]
if ch_type not in (
"mag",
"grad",
"eeg",
"hbo",
"hbr",
"fnirs_od",
"fnirs_cw_amplitude",
):
raise ValueError(
"Channel type not supported. Supported channel "
"types include 'mag', 'grad', 'eeg'. 'hbo', 'hbr', "
"'fnirs_cw_amplitude', and 'fnirs_od'."
)
ch_type = _get_plot_ch_type(evoked, ch_type)

time_unit, _ = _check_time_unit(time_unit, evoked.times)
if times is None:
times = np.linspace(evoked.times[0], evoked.times[-1], 10)
Expand Down
4 changes: 3 additions & 1 deletion mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2821,5 +2821,7 @@ def _get_plot_ch_type(inst, ch_type, allow_ref_meg=False):
ch_type = type_
break
else:
raise RuntimeError("No plottable channel types found")
raise RuntimeError(
f"No plottable channel types found. Allowed types are: {allowed_types}"
)
return ch_type

0 comments on commit 4cffc34

Please sign in to comment.