
[Proposal] Generic preprocessing function for arbitrary preprocessing steps #2813

Draft
wants to merge 1 commit into main

Conversation

h-mayorquin
Collaborator

This PR implements a generic preprocessing step defined by a function. The user defines a pre-processing step as a function that is applied to a data chunk, and that function is then called every time the user calls get_traces. The implementation uses partial from functools in the standard library to separate (possibly) heavy computation at __init__ from the work done at get_traces, but that's it; otherwise it is very simple.
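The binding pattern can be sketched without any SpikeInterface internals: partial fixes the user's kwargs once at construction time, and the bound function is applied to every chunk that get_traces returns. A minimal standalone sketch (class and helper names here are illustrative, not the PR's actual code):

```python
from functools import partial


class GenericPreprocessorSketch:
    """Illustrative stand-in for the PR's class: wraps a parent that
    exposes get_traces and applies a user function to every chunk."""

    def __init__(self, parent, function, **function_kwargs):
        self.parent = parent
        # Bind the kwargs once at construction time; only the chunk
        # itself is supplied later, on every get_traces call.
        self.bound_function = partial(function, **function_kwargs)

    def get_traces(self, start_frame=None, end_frame=None):
        traces = self.parent.get_traces(start_frame, end_frame)
        return self.bound_function(traces)


# Toy parent and function, just to show the call pattern.
class DummyRecording:
    def get_traces(self, start_frame=None, end_frame=None):
        return [1.0, 2.0, 3.0]


def scale(traces, gain):
    return [t * gain for t in traces]


pre = GenericPreprocessorSketch(DummyRecording(), scale, gain=2.0)
```

Since the kwargs are bound up front, any expensive preparation (e.g. designing filter coefficients) can be done once before construction and passed in, rather than recomputed per chunk.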

Here is a quick example of how it would look. I am working on a project where I need to use a bandpass filter that is a bit different from the one in the library:

from scipy.signal import ellip, filtfilt

def bandpass_filter(signal, f_sampling, f_low, f_high):
    wl = f_low / (f_sampling / 2.)
    wh = f_high / (f_sampling / 2.)
    wn = [wl, wh]

    # Designs a 2nd-order Elliptic band-pass filter which passes
    # frequencies between normalized f_low and f_high, and with 0.1 dB of ripple
    # in the passband, and 40 dB of attenuation in the stopband.
    b, a = ellip(2, 0.1, 40, wn, 'bandpass', analog=False)
    # To match Matlab output, we change default padlen from
    # 3*(max(len(a), len(b))) to 3*(max(len(a), len(b)) - 1)
    padlen = 3 * (max(len(a), len(b)) - 1)
    return filtfilt(b, a, signal, axis=0, padlen=padlen)
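As a quick sanity check on the filter itself (not part of the PR), it can be exercised on synthetic tones: a 50 Hz component below the band should be strongly attenuated, while a 1 kHz component inside the 300-6000 Hz band should pass nearly unchanged. The function is repeated here so the snippet is self-contained:

```python
import numpy as np
from scipy.signal import ellip, filtfilt


def bandpass_filter(signal, f_sampling, f_low, f_high):
    wl = f_low / (f_sampling / 2.0)
    wh = f_high / (f_sampling / 2.0)
    # 2nd-order elliptic band-pass, 0.1 dB passband ripple, 40 dB stopband
    b, a = ellip(2, 0.1, 40, [wl, wh], "bandpass", analog=False)
    padlen = 3 * (max(len(a), len(b)) - 1)
    return filtfilt(b, a, signal, axis=0, padlen=padlen)


fs = 30_000.0
t = np.arange(int(fs)) / fs  # one second of samples
low_tone = np.sin(2 * np.pi * 50.0 * t)      # below the 300 Hz edge
in_band = np.sin(2 * np.pi * 1_000.0 * t)    # inside the passband


def rms(x):
    return float(np.sqrt(np.mean(x**2)))


attenuated = rms(bandpass_filter(low_tone, fs, 300.0, 6000.0))
passed = rms(bandpass_filter(in_band, fs, 300.0, 6000.0))
```

Because filtfilt applies the filter forward and backward, the effective stopband attenuation is doubled, so the out-of-band tone should come out far below its original amplitude.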

I am aware that we could use filter design and the API of spikeinterface to implement something similar, but let me use this example as an illustration of how to integrate any generic pre-processing step. My desire would be to test this specific way of filtering with a couple of functions of spikeinterface, let's say peak detection.

With the PR here this looks like the following:

f_sampling = recording.get_sampling_frequency()
function_kwargs = dict(f_sampling=f_sampling, f_low=300., f_high=6000.)
preprocessor = GenericPreprocessor(recording=recording, function=bandpass_filter, **function_kwargs)

# Then peak detection

Whereas otherwise I would need to do something like this:

class MyPreprocessor(BasePreprocessor):
    
    def __init__(self, recording, f_low, f_high):
        BasePreprocessor.__init__(self, recording)
        self.f_low = f_low
        self.f_high = f_high
        self.f_sampling = recording.get_sampling_frequency()
        
        for parent_segment in recording._recording_segments:
            segment = MyPreprocessorSegment(parent_segment, self.f_sampling, self.f_low, self.f_high)
            self.add_recording_segment(segment)

        
class MyPreprocessorSegment(BasePreprocessorSegment):
    
    def __init__(self, segment, f_sampling, f_low, f_high):
        BasePreprocessorSegment.__init__(self, segment)
        self.f_sampling = f_sampling
        self.f_low = f_low
        self.f_high = f_high
        
    def get_traces(self, start_frame, end_frame, channel_indices):
        
        traces = self.parent_segment.get_traces(start_frame, end_frame, channel_indices)
        
        return bandpass_filter(traces, self.f_sampling, self.f_low, self.f_high)
        
    
preprocessor = MyPreprocessor(recording, f_low, f_high)

# Then peak detection

I claim that the way of this PR is far simpler, allows users to test their ideas more quickly, and requires less knowledge of SpikeInterface internals to do what they want to do.

Some drawbacks:

  • In its current form it is not JSON-serializable (but it is picklable).
  • Hard to test because it is too general.

Thoughts?

@h-mayorquin h-mayorquin added the preprocessing Related to preprocessing module label May 6, 2024
@h-mayorquin h-mayorquin self-assigned this May 6, 2024
@h-mayorquin h-mayorquin changed the title Generic/function preprocessing for arbitrary preprocessing steps [Proposal] Generic preprocessing function for arbitrary preprocessing steps May 7, 2024
@JoeZiminski
Contributor

This is a really nice idea. I don't think testing will be too much of a problem because, as long as the machinery is working well, it is up to the user to ensure their passed functions work correctly. In that case, tests could just check the general machinery with some basic passed function (e.g. mean), as you have done. I can't, off the top of my head, think of any cases where this would cause unintended behaviour.

In some ways the generality of this method is a blessing, as it does not promise anything: it is on the user to check that the behaviour of their function is as intended. That being said, it might be worth adding some checks on the output of the passed function in the get_traces method, just to verify some basic properties (e.g. that it has the right shape and dtype) before it is passed up the preprocessing chain.
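Such checks might look something like the following sketch (illustrative only, not code from the PR; the helper name is made up). Note that some pipelines may legitimately change dtype, e.g. filtering int16 traces into float, so a real implementation might warn rather than raise for that case:

```python
import numpy as np


def validate_processed_traces(processed, original):
    """Basic sanity checks on the output of a user-supplied preprocessing
    function before it is passed up the chain (illustrative sketch)."""
    if not isinstance(processed, np.ndarray):
        raise TypeError(
            f"preprocessing function returned {type(processed)}, expected np.ndarray"
        )
    if processed.shape != original.shape:
        raise ValueError(
            f"preprocessing function changed shape from {original.shape} to {processed.shape}"
        )
    if processed.dtype != original.dtype:
        # Could be a warning instead: dtype changes are sometimes intended.
        raise ValueError(
            f"preprocessing function changed dtype from {original.dtype} to {processed.dtype}"
        )
    return processed


traces = np.zeros((10, 4), dtype="float32")
checked = validate_processed_traces(traces + 1.0, traces)
```

A check like this could run inside the generic get_traces, guarding the rest of the pipeline against malformed user functions at a negligible per-chunk cost.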
