Skip to content

Commit

Permalink
wip epochs spectrum
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrockhill committed Dec 19, 2023
1 parent 797d6a5 commit 9e4eaf1
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,7 @@ def _get_data(
if self.preload:
self._data = data

# Now update our properties (excepd data, which is already fixed)
# Now update our properties (except data, which is already fixed)
self._getitem(
good_idx, None, copy=False, drop_event_id=False, select_data=False
)
Expand Down
2 changes: 1 addition & 1 deletion mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2178,7 +2178,7 @@ def compute_psd(
self._set_legacy_nfft_default(tmin, tmax, method, method_kw)
_validate_type(ragged_epochs, (bool, str, list), 'ragged_epochs')
kwargs = dict(
self,
inst=self,
method=method,
fmin=fmin,
fmax=fmax,
Expand Down
47 changes: 37 additions & 10 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,29 @@ def _returns_complex_tapers(self, **method_kw):
def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, annotations=None, verbose=None):
# make the spectra
if annotations:
result = list()
for start, duration_s in zip(self.events[:, 0], annotations.duration):
# fencepost
start = self.events[0, 0] - self.inst.first_samp
duration = int(round(annotations.duration[0] * self.sfreq))
result_tmp = self._psd_func(
data[:, start: start + duration], self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose
)
data_out = np.zeros((self.events.shape[0],) + result_tmp[0].shape)
data_out[0] = result_tmp[0]
freqs = result_tmp[1]
if self._returns_complex_tapers(**method_kw):
weights = np.zeros((self.events.shape[0],) + result_tmp[2].shape)
weights[0] = result_tmp[2]
for i, (start, duration_s) in enumerate(zip(self.events[1:, 0], annotations.duration[1:])):
start -= self.inst.first_samp
duration = int(round(duration_s * self.sfreq))
result = self._psd_func(
data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose
result_tmp = self._psd_func(
data[:, start: start + duration], self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose
)
data_out[i + 1] = result_tmp[0]
assert np.array_equal(result_tmp[1], freqs)
if self._returns_complex_tapers(**method_kw):
weights[i + 1] = result_tmp[2]
result = (data_out, freqs, weights) if self._returns_complex_tapers(**method_kw) else (data_out, freqs)
else:
result = self._psd_func(
data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose
Expand Down Expand Up @@ -1294,6 +1311,8 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
%(proj_psd)s
%(remove_dc)s
%(ragged_epochs)s
%(reject_by_annotation_psd)s
Only used for ``ragged_epochs=True``.
%(n_jobs)s
%(verbose)s
%(method_kw_psd)s
Expand Down Expand Up @@ -1332,6 +1351,7 @@ def __init__(
proj,
remove_dc,
ragged_epochs=False,
reject_by_annotation=False,
*,
n_jobs=None,
verbose=None,
Expand All @@ -1356,19 +1376,25 @@ def __init__(
verbose=verbose,
**method_kw,
)
from ..io import BaseRaw
# get just the data we want
data = np.take(self.inst._get_data(picks=self._picks, on_empty="raise"),
self._time_mask, axis=-1)
if isinstance(self.inst, BaseRaw):
data = self.inst.get_data(self._picks)
else:
data = self.inst._get_data(picks=self._picks, on_empty="raise")[
:, :, self._time_mask]
# set metadata
if ragged_epochs:
self.events, self.event_id = events_from_annotations(self.inst.annotations)
self.events, self.event_id = events_from_annotations(self.inst)
if reject_by_annotation: # remove bad annotations
self.events = np.delete(self.events, [descr.lower().startswith("bad")
for descr in self.inst.annotations.description],
axis=0)
# select only events in the time interval
self.events = self.events[self._time_mask[self.events[:, 0] - self.inst.first_samp]]
self._shape = ((self._events.shape[0])) + self._shape
self.selection = np.arange(self.events.shape[0])
self.drop_log = [()] * self.events.shape[0]
else:
self._shape = (len(self.inst),) + self._shape
# we need these for to_data_frame()
self.event_id = self.inst.event_id.copy()
self.events = self.inst.events.copy()
Expand All @@ -1377,11 +1403,12 @@ def __init__(
self.drop_log = deepcopy(self.inst.drop_log)
# compute the spectra
self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, self.inst.annotations if ragged_epochs else None, verbose)
self._shape = ((self.events.shape[0] if ragged_epochs else len(self.inst)),) + self._shape
self._dims = ("epoch",) + self._dims
# check for correct shape and bad values
self._check_values()
del self._shape
self._metadata = self.inst.metadata
self._metadata = self.inst.metadata if hasattr(self.inst, "metadata") else None
# save memory
del self.inst

Expand Down

0 comments on commit 9e4eaf1

Please sign in to comment.