Skip to content

Commit

Permalink
BUG: Fix epochs interpolation for sEEG and don't interpolate over spa…
Browse files Browse the repository at this point in the history
…ns in electrodes and don't include stray contacts from other electrodes circumstantially in a line with an electrode (#12593)
  • Loading branch information
alexrockhill committed May 2, 2024
1 parent 7fd22d6 commit 356e854
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 25 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12593.bugfix.rst
@@ -0,0 +1 @@
Fix error causing :meth:`mne.Epochs.interpolate_bads` not to work for ``seeg`` channels and fix a single contact on neighboring shafts sometimes being included in interpolation, by `Alex Rockhill`_
100 changes: 77 additions & 23 deletions mne/channels/interpolation.py
Expand Up @@ -292,14 +292,16 @@ def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
return inst


def _find_seeg_electrode_shaft(pos, tol=2e-3):
def _find_seeg_electrode_shaft(pos, tol_shaft=0.002, tol_spacing=1):
# 1) find nearest neighbor to define the electrode shaft line
# 2) find all contacts on the same line
# 3) remove contacts with large distances

dist = squareform(pdist(pos))
np.fill_diagonal(dist, np.inf)

shafts = list()
shaft_ts = list()
for i, n1 in enumerate(pos):
if any([i in shaft for shaft in shafts]):
continue
Expand All @@ -308,12 +310,59 @@ def _find_seeg_electrode_shaft(pos, tol=2e-3):
shaft_dists = np.linalg.norm(
np.cross((pos - n1), (pos - n2)), axis=1
) / np.linalg.norm(n2 - n1)
shafts.append(np.where(shaft_dists < tol)[0]) # 2
return shafts
shaft = np.where(shaft_dists < tol_shaft)[0] # 2
shaft_prev = None
for _ in range(10): # avoid potential cycles
if np.array_equal(shaft, shaft_prev):
break
shaft_prev = shaft
# compute median shaft line
v = np.median(
[
pos[i] - pos[j]
for idx, i in enumerate(shaft)
for j in shaft[idx + 1 :]
],
axis=0,
)
c = np.median(pos[shaft], axis=0)
# recompute distances
shaft_dists = np.linalg.norm(
np.cross((pos - c), (pos - c + v)), axis=1
) / np.linalg.norm(v)
shaft = np.where(shaft_dists < tol_shaft)[0]
ts = np.array([np.dot(c - n0, v) / np.linalg.norm(v) ** 2 for n0 in pos[shaft]])
shaft_order = np.argsort(ts)
shaft = shaft[shaft_order]
ts = ts[shaft_order]

# only include the largest group with spacing with the error tolerance
# avoid interpolating across spans between contacts
t_diffs = np.diff(ts)
t_diff_med = np.median(t_diffs)
spacing_errors = (t_diffs - t_diff_med) / t_diff_med
groups = list()
group = [shaft[0]]
for j in range(len(shaft) - 1):
if spacing_errors[j] > tol_spacing:
groups.append(group)
group = [shaft[j + 1]]
else:
group.append(shaft[j + 1])
groups.append(group)
group = [group for group in groups if i in group][0]
ts = ts[np.isin(shaft, group)]
shaft = np.array(group, dtype=int)

shafts.append(shaft)
shaft_ts.append(ts)
return shafts, shaft_ts


@verbose
def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None):
def _interpolate_bads_seeg(
inst, exclude=None, tol_shaft=0.002, tol_spacing=1, verbose=None
):
if exclude is None:
exclude = list()
picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude)
Expand All @@ -328,38 +377,43 @@ def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None):
# Make sure only sEEG are used
bads_idx_pos = bads_idx[picks]

shafts = _find_seeg_electrode_shaft(pos, tol=tol)
shafts, shaft_ts = _find_seeg_electrode_shaft(
pos, tol_shaft=tol_shaft, tol_spacing=tol_spacing
)

# interpolate the bad contacts
picks_bad = list(np.where(bads_idx_pos)[0])
for shaft in shafts:
for shaft, ts in zip(shafts, shaft_ts):
bads_shaft = np.array([idx for idx in picks_bad if idx in shaft])
if bads_shaft.size == 0:
continue
goods_shaft = shaft[np.isin(shaft, bads_shaft, invert=True)]
if goods_shaft.size < 2:
if goods_shaft.size < 4: # cubic spline requires 3 channels
msg = "No shaft" if shaft.size < 4 else "Not enough good channels"
no_shaft_chs = " and ".join(np.array(inst.ch_names)[bads_shaft])
raise RuntimeError(
f"{goods_shaft.size} good contact(s) found in a line "
f" with {np.array(inst.ch_names)[bads_shaft]}, "
"at least 2 are required for interpolation. "
"Dropping this channel/these channels is recommended."
f"{msg} found in a line with {no_shaft_chs} "
"at least 3 good channels on the same line "
f"are required for interpolation, {goods_shaft.size} found. "
f"Dropping {no_shaft_chs} is recommended."
)
logger.debug(
f"Interpolating {np.array(inst.ch_names)[bads_shaft]} using "
f"data from {np.array(inst.ch_names)[goods_shaft]}"
)
bads_shaft_idx = np.where(np.isin(shaft, bads_shaft))[0]
goods_shaft_idx = np.where(~np.isin(shaft, bads_shaft))[0]
n1, n2 = pos[shaft][:2]
ts = np.array(
[
-np.dot(n1 - n0, n2 - n1) / np.linalg.norm(n2 - n1) ** 2
for n0 in pos[shaft]
]

z = inst._data[..., goods_shaft, :]
is_epochs = z.ndim == 3
if is_epochs:
z = z.swapaxes(0, 1)
z = z.reshape(z.shape[0], -1)
y = np.arange(z.shape[-1])
out = RectBivariateSpline(x=ts[goods_shaft_idx], y=y, z=z)(
x=ts[bads_shaft_idx], y=y
)
if np.any(np.diff(ts) < 0):
ts *= -1
y = np.arange(inst._data.shape[-1])
inst._data[bads_shaft] = RectBivariateSpline(
x=ts[goods_shaft_idx], y=y, z=inst._data[goods_shaft]
)(x=ts[bads_shaft_idx], y=y) # 3
if is_epochs:
out = out.reshape(bads_shaft.size, inst._data.shape[0], -1)
out = out.swapaxes(0, 1)
inst._data[..., bads_shaft, :] = out
25 changes: 23 additions & 2 deletions mne/channels/tests/test_interpolation.py
Expand Up @@ -364,8 +364,6 @@ def test_interpolation_seeg():
# check that interpolation changes the data in raw
raw_seeg = RawArray(data=epochs_seeg._data[0], info=epochs_seeg.info)
raw_before = raw_seeg.copy()
with pytest.raises(RuntimeError, match="1 good contact"):
raw_seeg.interpolate_bads(method=dict(seeg="spline"))
montage = raw_seeg.get_montage()
pos = montage.get_positions()
ch_pos = pos.pop("ch_pos")
Expand All @@ -378,6 +376,29 @@ def test_interpolation_seeg():
assert not np.all(raw_before._data[bads_mask] == raw_after._data[bads_mask])
assert_array_equal(raw_before._data[~bads_mask], raw_after._data[~bads_mask])

# check interpolation on epochs
epochs_seeg.set_montage(make_dig_montage(ch_pos, **pos))
epochs_before = epochs_seeg.copy()
epochs_after = epochs_seeg.interpolate_bads(method=dict(seeg="spline"))
assert not np.all(
epochs_before._data[:, bads_mask] == epochs_after._data[:, bads_mask]
)
assert_array_equal(
epochs_before._data[:, ~bads_mask], epochs_after._data[:, ~bads_mask]
)

# test shaft all bad
epochs_seeg.info["bads"] = epochs_seeg.ch_names
with pytest.raises(RuntimeError, match="Not enough good channels"):
epochs_seeg.interpolate_bads(method=dict(seeg="spline"))

# test bad not on shaft
ch_pos[bads[0]] = np.array([10, 10, 10])
epochs_seeg.info["bads"] = bads
epochs_seeg.set_montage(make_dig_montage(ch_pos, **pos))
with pytest.raises(RuntimeError, match="No shaft found"):
epochs_seeg.interpolate_bads(method=dict(seeg="spline"))


def test_nan_interpolation(raw):
"""Test 'nan' method for interpolating bads."""
Expand Down

0 comments on commit 356e854

Please sign in to comment.