Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement channel specific epoch rejection #12219

Draft
wants to merge 48 commits into
base: main
Choose a base branch
from

Conversation

CarinaFo
Copy link
Contributor

#11705

implement a epochs method that allows for channel specific epoch rejection

@dominikwelke
Copy link
Contributor

apart from the main functionality and some test (tbd), we would probably also want some kind of record that this was done.
drop_log comes to mind, but i think as is it wouldnt be suitable because no epoch is dropped.

any good idea?

@dominikwelke
Copy link
Contributor

btw, i dont know about the state of the PR but for me this code doesnt work so far..
i think it's the overwriting with nan values using the mask.. but behavior was not really consistent, i have to track this down :)

mne/epochs.py Outdated Show resolved Hide resolved
CarinaFo and others added 2 commits November 17, 2023 17:06
@CarinaFo
Copy link
Contributor Author

Update: the code works after implementing a preload_check. TBD: do we want channel specific epoch rejection, meaning we kick out epoch outliers based on the channel specific standard deviation OR do we calculate the standard deviation over all channels to define outlier epochs?

@dominikwelke
Copy link
Contributor

related to one of my previous comments:
currently data has to be preloaded, so that chosen epochs can be overwritten with nan. as mentioned, currently this change is not logged anywhere which we should probably do :).
if we were to extend drop_log in a way that it can handle individual channel/epoch combos, this would enable us to process without preloading, as i think is the case for all other rejection methods.

but this would be a deeper change, and basically a design decision.
any thoughts @drammock @larsoner ?

@dominikwelke
Copy link
Contributor

TBD: do we want channel specific epoch rejection, meaning we kick out epoch outliers based on the channel specific standard deviation OR do we calculate the standard deviation over all channels to define outlier epochs?

+1 for this question!
you are right @carina - the requested functionality in #11705 was std within channels, and not across all data..

@drammock
Copy link
Member

Sorry it took me so long to look at this PR

TBD: do we want channel specific epoch rejection, meaning we kick out epoch outliers based on the channel specific standard deviation OR do we calculate the standard deviation over all channels to define outlier epochs?

This implementation is actually a bit more specific than I was expecting ("specific" in the sense that it bakes in the idea of using standard deviation to decide what to do). I was expecting @larsoner's idea (2) from here, as I tried to indicate in my subsequent comment:

marking as NaN and using nanmean / interpolation seems like the right approach to me.

In other words, I was thinking the implementation should be something like:

  • make it easy to replace channel data with nan on a per-epoch basis. Currently if a channel is bad it is bad for all epochs in the Epochs object, so epochs.interpolate_bads can't do nan-replacement on a per-trial basis. The solution might be to modify interpolate_bads or it might be a new function, not sure.
  • modify epochs.average to use np.nanmean
  • deal with the "effective n_averaged" issue somehow

That way, the user's job is to somehow create the inputs to that new function (or the modified interpolate_bads function) --- I'm imagining a dict of {epoch_idx: list_of_bad_chs} as one possible way --- and they can decide how to populate that (std deviation or somethign else, within channel vs across-channel, etc).

@larsoner was this more or less what you were expecting, or am I way off base here?

@larsoner
Copy link
Member

larsoner commented Dec 1, 2023

Yes that sounds reasonable to me!

@CarinaFo CarinaFo closed this Feb 7, 2024
@CarinaFo CarinaFo force-pushed the channel_specific_epoch_rejection branch from 7d17f89 to 9f0dfef Compare February 7, 2024 04:27
…o/mne-python into channel_specific_epoch_rejection
@CarinaFo
Copy link
Contributor Author

@drammock @dominikwelke

I continued working on this and decided to include a method to the BaseEpochs class called set_bad_epochs_to_NaN(). I decided against adjusting interpolate_bads(), as this method is within the class InterpolationMixin, which allows for Raws and Evokeds and would create a conflict with epoch-based rejection.
If you agree with the epochs method, I would add the tests and start thinking about a attribute which states the number of epochs per channel after marking bad epochs. epochs.average() does currently not work with the epochs including NaNs, even though we use np.nanmean().
I think it is related to the complex object check in the lines before (epochs.py, _compute_aggregate(), ll. 1206-1210). I am not super sure why we need those lines and why we can't add epochs.get_data() into np.nanmean(), straight away?

@drammock
Copy link
Member

I continued working on this and decided to include a method to the BaseEpochs class called set_bad_epochs_to_NaN(). I decided against adjusting interpolate_bads(), as this method is within the class InterpolationMixin, which allows for Raws and Evokeds and would create a conflict with epoch-based rejection.

I'll be away for the next ~1.5 weeks, but I just wanted to note that another option would be to define a method of the same name for BaseEpochs (i.e., BaseEpochs.interpolate_bads()) which will then take precedence over the version of that method defined in the mixin class. Within that "overriding" method, you can do things like handle any special cases and then (optionally) pass off remaining computations to the mixin class's method by doing something like

if some_custom_case_holds:  # only relevant for Epochs
    # insert new code here
    return my_new_result
else:
    return super().interpolate_bads(copy_the_relevant_input_params_here)

I'm not sure if this actually will help / work / make it more efficient to implement than your proposed approach of writing a brand new function... but I wanted you to know it's an option, in case you want to work more on this PR before I have a chance to look in greater depth.

@dominikwelke
Copy link
Contributor

hi @CarinaFo ill try to have a deeper look into it soon :)

initially your new approach makes sense to me, as the method is really only intended for epochs, but I'm not deep enough in the codebase to judge potential caveats (other than the api change, which i guess is generally disliked).

@dominikwelke
Copy link
Contributor

and if you didnt see it, i already flagged two minor things that i believe would break the doc built

@dominikwelke
Copy link
Contributor

hi @CarinaFo
i created a pull request into your branch with my revisions (though this was easier than pasting everything here in the chat)

here my initial revisions that make it work (i think).
main changes:

  • the nanmean function had to be swapped somewhere else,
  • applying your method is now only allowed if data are preloaded (which makes things easier and i think is valid).
  • i added an initial test

it's just quick and dirty fixes.. e.g., the test could be much better.
also not sure if we should really change mean to nanmean in the global util function, or whether it makes more sense to allow nanmean as an additional separate option and call it based on a (to be implemented) flag, that epochs were rejected (_check_combine is only used in epochs.py and tfr.py, probably for a similar case, but who knows..).

@dominikwelke
Copy link
Contributor

dominikwelke commented Feb 26, 2024

another comment: currently the input defining bad epochs is quite sparse - a list of lists for all channels, and if a channel doesnt contain bad epochs it has to be an empty list..

another option would be a dictionary s.a. {"CH_NAME/CH_IDX":[LIST OF BAD EPOCHS], "CH_NAME/CH_IDX":[LIST OF BAD EPOCHS],...}
this would make the input easier to create (if user defined), more concise, and also eliminate one layer of potential errors (order of channels / removing channels during preprocessing if we use the channel name instead of index)

identification of epochs would still not happen/be based on sequence.. not optimal (bc other functions might kick epochs out and thereby change order) but its not easy to get around it.

a mask would be another option (an empty/sparse array with the same shape as the data that holds NaN in the wanted places).
this would be the safest, as any possible change in the data shape due to intermediate processing steps would cause an error.
(it would also make the application more efficient than the current loop. not sure how significant, though)
but it's not nice for users to define. we could make a helper function.
or maybe the core devs dont want bad epochs to be user defined, anyway, but only internally by MNE functions :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants