From 356e8546890b8f798a14777991ba7de6cd9ab9bb Mon Sep 17 00:00:00 2001 From: Alex Rockhill Date: Wed, 1 May 2024 17:01:13 -0700 Subject: [PATCH] BUG: Fix epochs interpolation for sEEG and don't interpolate over spans in electrodes and don't include stray contacts from other electrodes circumstantially in a line with an electrode (#12593) --- doc/changes/devel/12593.bugfix.rst | 1 + mne/channels/interpolation.py | 100 +++++++++++++++++------ mne/channels/tests/test_interpolation.py | 25 +++++- 3 files changed, 101 insertions(+), 25 deletions(-) create mode 100644 doc/changes/devel/12593.bugfix.rst diff --git a/doc/changes/devel/12593.bugfix.rst b/doc/changes/devel/12593.bugfix.rst new file mode 100644 index 00000000000..e43d6110716 --- /dev/null +++ b/doc/changes/devel/12593.bugfix.rst @@ -0,0 +1 @@ +Fix error causing :meth:`mne.Epochs.interpolate_bads` not to work for ``seeg`` channels and fix a single contact on neighboring shafts sometimes being included in interpolation, by `Alex Rockhill`_ \ No newline at end of file diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index 6c5042d1d04..28a5058b3ac 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -292,14 +292,16 @@ def _interpolate_bads_nirs(inst, exclude=(), verbose=None): return inst -def _find_seeg_electrode_shaft(pos, tol=2e-3): +def _find_seeg_electrode_shaft(pos, tol_shaft=0.002, tol_spacing=1): # 1) find nearest neighbor to define the electrode shaft line # 2) find all contacts on the same line + # 3) remove contacts with large distances dist = squareform(pdist(pos)) np.fill_diagonal(dist, np.inf) shafts = list() + shaft_ts = list() for i, n1 in enumerate(pos): if any([i in shaft for shaft in shafts]): continue @@ -308,12 +310,59 @@ def _find_seeg_electrode_shaft(pos, tol=2e-3): shaft_dists = np.linalg.norm( np.cross((pos - n1), (pos - n2)), axis=1 ) / np.linalg.norm(n2 - n1) - shafts.append(np.where(shaft_dists < tol)[0]) # 2 - return shafts + shaft = np.where(shaft_dists < tol_shaft)[0] # 2 + shaft_prev = None + for _ in range(10): # avoid potential cycles + if np.array_equal(shaft, shaft_prev): + break + shaft_prev = shaft + # compute median shaft line + v = np.median( + [ + pos[i] - pos[j] + for idx, i in enumerate(shaft) + for j in shaft[idx + 1 :] + ], + axis=0, + ) + c = np.median(pos[shaft], axis=0) + # recompute distances + shaft_dists = np.linalg.norm( + np.cross((pos - c), (pos - c + v)), axis=1 + ) / np.linalg.norm(v) + shaft = np.where(shaft_dists < tol_shaft)[0] + ts = np.array([np.dot(c - n0, v) / np.linalg.norm(v) ** 2 for n0 in pos[shaft]]) + shaft_order = np.argsort(ts) + shaft = shaft[shaft_order] + ts = ts[shaft_order] + + # only include the largest group with spacing with the error tolerance + # avoid interpolating across spans between contacts + t_diffs = np.diff(ts) + t_diff_med = np.median(t_diffs) + spacing_errors = (t_diffs - t_diff_med) / t_diff_med + groups = list() + group = [shaft[0]] + for j in range(len(shaft) - 1): + if spacing_errors[j] > tol_spacing: + groups.append(group) + group = [shaft[j + 1]] + else: + group.append(shaft[j + 1]) + groups.append(group) + group = [group for group in groups if i in group][0] + ts = ts[np.isin(shaft, group)] + shaft = np.array(group, dtype=int) + + shafts.append(shaft) + shaft_ts.append(ts) + return shafts, shaft_ts @verbose -def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None): +def _interpolate_bads_seeg( + inst, exclude=None, tol_shaft=0.002, tol_spacing=1, verbose=None +): if exclude is None: exclude = list() picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude) @@ -328,21 +377,25 @@ def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None): # Make sure only sEEG are used bads_idx_pos = bads_idx[picks] - shafts = _find_seeg_electrode_shaft(pos, tol=tol) + shafts, shaft_ts = _find_seeg_electrode_shaft( + pos, tol_shaft=tol_shaft, tol_spacing=tol_spacing + ) # interpolate the bad contacts picks_bad = list(np.where(bads_idx_pos)[0]) - for shaft in shafts: + for shaft, ts in zip(shafts, shaft_ts): bads_shaft = np.array([idx for idx in picks_bad if idx in shaft]) if bads_shaft.size == 0: continue goods_shaft = shaft[np.isin(shaft, bads_shaft, invert=True)] - if goods_shaft.size < 2: + if goods_shaft.size < 4: # cubic spline requires 3 channels + msg = "No shaft" if shaft.size < 4 else "Not enough good channels" + no_shaft_chs = " and ".join(np.array(inst.ch_names)[bads_shaft]) raise RuntimeError( - f"{goods_shaft.size} good contact(s) found in a line " - f" with {np.array(inst.ch_names)[bads_shaft]}, " - "at least 2 are required for interpolation. " - "Dropping this channel/these channels is recommended." + f"{msg} found in a line with {no_shaft_chs} " + "at least 3 good channels on the same line " + f"are required for interpolation, {goods_shaft.size} found. " + f"Dropping {no_shaft_chs} is recommended." ) logger.debug( f"Interpolating {np.array(inst.ch_names)[bads_shaft]} using " @@ -350,16 +403,17 @@ def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None): ) bads_shaft_idx = np.where(np.isin(shaft, bads_shaft))[0] goods_shaft_idx = np.where(~np.isin(shaft, bads_shaft))[0] - n1, n2 = pos[shaft][:2] - ts = np.array( - [ - -np.dot(n1 - n0, n2 - n1) / np.linalg.norm(n2 - n1) ** 2 - for n0 in pos[shaft] - ] + + z = inst._data[..., goods_shaft, :] + is_epochs = z.ndim == 3 + if is_epochs: + z = z.swapaxes(0, 1) + z = z.reshape(z.shape[0], -1) + y = np.arange(z.shape[-1]) + out = RectBivariateSpline(x=ts[goods_shaft_idx], y=y, z=z)( + x=ts[bads_shaft_idx], y=y ) - if np.any(np.diff(ts) < 0): - ts *= -1 - y = np.arange(inst._data.shape[-1]) - inst._data[bads_shaft] = RectBivariateSpline( - x=ts[goods_shaft_idx], y=y, z=inst._data[goods_shaft] - )(x=ts[bads_shaft_idx], y=y) # 3 + if is_epochs: + out = out.reshape(bads_shaft.size, inst._data.shape[0], -1) + out = out.swapaxes(0, 1) + inst._data[..., bads_shaft, :] = out diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index 7e282562955..31315343ddc 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -364,8 +364,6 @@ def test_interpolation_seeg(): # check that interpolation changes the data in raw raw_seeg = RawArray(data=epochs_seeg._data[0], info=epochs_seeg.info) raw_before = raw_seeg.copy() - with pytest.raises(RuntimeError, match="1 good contact"): - raw_seeg.interpolate_bads(method=dict(seeg="spline")) montage = raw_seeg.get_montage() pos = montage.get_positions() ch_pos = pos.pop("ch_pos") @@ -378,6 +376,29 @@ def test_interpolation_seeg(): assert not np.all(raw_before._data[bads_mask] == raw_after._data[bads_mask]) assert_array_equal(raw_before._data[~bads_mask], raw_after._data[~bads_mask]) + # check interpolation on epochs + epochs_seeg.set_montage(make_dig_montage(ch_pos, **pos)) + epochs_before = epochs_seeg.copy() + epochs_after = epochs_seeg.interpolate_bads(method=dict(seeg="spline")) + assert not np.all( + epochs_before._data[:, bads_mask] == epochs_after._data[:, bads_mask] + ) + assert_array_equal( + epochs_before._data[:, ~bads_mask], epochs_after._data[:, ~bads_mask] + ) + + # test shaft all bad + epochs_seeg.info["bads"] = epochs_seeg.ch_names + with pytest.raises(RuntimeError, match="Not enough good channels"): + epochs_seeg.interpolate_bads(method=dict(seeg="spline")) + + # test bad not on shaft + ch_pos[bads[0]] = np.array([10, 10, 10]) + epochs_seeg.info["bads"] = bads + epochs_seeg.set_montage(make_dig_montage(ch_pos, **pos)) + with pytest.raises(RuntimeError, match="No shaft found"): + epochs_seeg.interpolate_bads(method=dict(seeg="spline")) + def test_nan_interpolation(raw): """Test 'nan' method for interpolating bads."""