Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [ENH] metadata attribute for Annotations #12213

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
56 changes: 43 additions & 13 deletions mne/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
_datetime = datetime


def _check_o_d_s_c(onset, duration, description, ch_names):
def _check_o_d_s_c(onset, duration, description, ch_names, metadata):
onset = np.atleast_1d(np.array(onset, dtype=float))
if onset.ndim != 1:
raise ValueError(
Expand Down Expand Up @@ -95,13 +95,28 @@ def _check_o_d_s_c(onset, duration, description, ch_names):
_validate_type(name, str, f"ch_names[{ai}][{ci}]")
ch_names = _ndarray_ch_names(ch_names)

if not (len(onset) == len(duration) == len(description) == len(ch_names)):
# metadata:
_validate_type(metadata, (None, tuple, list, np.ndarray), "metadata")
if metadata is None:
metadata = [{}] * len(onset)
metadata = list(metadata)
for ai, md in enumerate(metadata):
_validate_type(md, dict, f"metadata[{ai}]")
metadata = np.array(metadata, dtype=object)

if not (
len(onset)
== len(duration)
== len(description)
== len(ch_names)
== len(metadata)
):
raise ValueError(
"Onset, duration, description, and ch_names must be "
"Onset, duration, description, ch_names, and metadata must be "
f"equal in sizes, got {len(onset)}, {len(duration)}, "
f"{len(description)}, and {len(ch_names)}."
f"{len(description)}, {len(ch_names)}, and {len(metadata)}."
)
return onset, duration, description, ch_names
return onset, duration, description, ch_names, metadata


def _ndarray_ch_names(ch_names):
Expand Down Expand Up @@ -276,11 +291,12 @@ class Annotations:
""" # noqa: E501

def __init__(
self, onset, duration, description, orig_time=None, ch_names=None
self, onset, duration, description, orig_time=None, ch_names=None, metadata=None
): # noqa: D102
self._orig_time = _handle_meas_date(orig_time)
self.onset, self.duration, self.description, self.ch_names = _check_o_d_s_c(
onset, duration, description, ch_names
self.onset, self.duration, self.description, self.ch_names,
self._metadata = _check_o_d_s_c(
onset, duration, description, ch_names, metadata
)
self._sort() # ensure we're sorted

Expand All @@ -298,6 +314,7 @@ def __eq__(self, other):
and np.array_equal(self.duration, other.duration)
and np.array_equal(self.description, other.description)
and np.array_equal(self.ch_names, other.ch_names)
and np.array_equal(self._metadata, other._metadata)
and self.orig_time == other.orig_time
)

Expand Down Expand Up @@ -345,7 +362,11 @@ def __iadd__(self, other):
"(got %s != %s)" % (self.orig_time, other.orig_time)
)
return self.append(
other.onset, other.duration, other.description, other.ch_names
other.onset,
other.duration,
other.description,
other.ch_names,
other._metadata,
)

def __iter__(self):
Expand All @@ -359,12 +380,13 @@ def __iter__(self):
def __getitem__(self, key, *, with_ch_names=None):
"""Propagate indexing and slicing to the underlying numpy structure."""
if isinstance(key, int_like):
out_keys = ("onset", "duration", "description", "orig_time")
out_keys = ("onset", "duration", "description", "orig_time", "metadata")
out_vals = (
self.onset[key],
self.duration[key],
self.description[key],
self.orig_time,
self._metadata[key],
)
if with_ch_names or (with_ch_names is None and self._any_ch_names()):
out_keys += ("ch_names",)
Expand All @@ -378,10 +400,11 @@ def __getitem__(self, key, *, with_ch_names=None):
description=self.description[key],
orig_time=self.orig_time,
ch_names=self.ch_names[key],
metadata=self._metadata[key],
)

@fill_doc
def append(self, onset, duration, description, ch_names=None):
def append(self, onset, duration, description, ch_names=None, metadata=None):
"""Add an annotated segment. Operates inplace.

Parameters
Expand All @@ -395,6 +418,10 @@ def append(self, onset, duration, description, ch_names=None):
Description for the annotation. To reject epochs, use description
starting with keyword 'bad'.
%(ch_names_annot)s
metadata : dict | array-like
Metadata for the annotation. Can be a dict or an array-like
object of dicts. If an array-like object, must be the same length
as ``onset``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with this we'd be exposing metadata to the user --> weren't we discussing keeping it private, and having dedicated functions populate / add / use / change this data?

If we make it public here (exposing it in the docstring), then I think we might as well make it completely public 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> with this we'd be exposing metadata to the user --> weren't we discussing keeping it private, and having dedicated functions populate / add / use / change this data?
>
> If we make it public here (exposing it in the docstring), then I think we might as well make it completely public 🤔

Good point, that's inconsistent. Alternatively, we could disallow manually appending Annotations with metadata, and behind the scenes just do:

self._metadata = np.append(self._metadata, {})

I haven't checked yet whether .append() is used elsewhere in the codebase, though.


.. versionadded:: 0.23

Expand All @@ -409,13 +436,14 @@ def append(self, onset, duration, description, ch_names=None):
to not only ``list.append``, but also
`list.extend <https://docs.python.org/3/library/stdtypes.html#mutable-sequence-types>`__.
""" # noqa: E501
onset, duration, description, ch_names = _check_o_d_s_c(
onset, duration, description, ch_names
onset, duration, description, ch_names, metadata = _check_o_d_s_c(
onset, duration, description, ch_names, metadata
)
self.onset = np.append(self.onset, onset)
self.duration = np.append(self.duration, duration)
self.description = np.append(self.description, description)
self.ch_names = np.append(self.ch_names, ch_names)
self._metadata = np.append(self._metadata, metadata)
self._sort()
return self

Expand All @@ -442,6 +470,7 @@ def delete(self, idx):
self.duration = np.delete(self.duration, idx)
self.description = np.delete(self.description, idx)
self.ch_names = np.delete(self.ch_names, idx)
self._metadata = np.delete(self._metadata, idx)

def to_data_frame(self):
"""Export annotations in tabular structure as a pandas DataFrame.
Expand Down Expand Up @@ -564,6 +593,7 @@ def _sort(self):
self.duration = self.duration[order]
self.description = self.description[order]
self.ch_names = self.ch_names[order]
self._metadata = self._metadata[order]

@verbose
def crop(
Expand Down