
[Proposal] Generic preprocessing function for arbitrary preprocessing steps

h-mayorquin opened this issue 9 months ago · 2 comments

This PR implements a generic preprocessing step defined by a user-supplied function. The user writes a function that is applied to a chunk of data, and that function is then called every time get_traces is called. The implementation uses functools.partial from the standard library to separate the (possibly) heavy computation at __init__ from the work done at get_traces, but that is all; otherwise it is very simple.
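
To make the idea concrete, here is a minimal sketch of what such a wrapper could look like. This is not the actual code of the PR; it only illustrates the pattern, assuming the existing BasePreprocessor / BasePreprocessorSegment machinery, and the GenericPreprocessorSegment name below is purely illustrative:

from functools import partial

from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment


class GenericPreprocessor(BasePreprocessor):

    def __init__(self, recording, function, **function_kwargs):
        BasePreprocessor.__init__(self, recording)
        # Bind the keyword arguments once at __init__ ...
        bound_function = partial(function, **function_kwargs)
        for parent_segment in recording._recording_segments:
            segment = GenericPreprocessorSegment(parent_segment, bound_function)
            self.add_recording_segment(segment)


class GenericPreprocessorSegment(BasePreprocessorSegment):

    def __init__(self, parent_segment, bound_function):
        BasePreprocessorSegment.__init__(self, parent_segment)
        self._function = bound_function

    def get_traces(self, start_frame, end_frame, channel_indices):
        traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
        # ... and apply the user function to each requested chunk here
        return self._function(traces)
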

Here is a quick example of how it would look. I am working on a project where I need a bandpass filter that is a bit different from the one in the library:

from scipy.signal import ellip, filtfilt

def bandpass_filter(signal, f_sampling, f_low, f_high):
    wl = f_low / (f_sampling / 2.)
    wh = f_high / (f_sampling / 2.)
    wn = [wl, wh]

    # Designs a 2nd-order Elliptic band-pass filter which passes
    # frequencies between normalized f_low and f_high, and with 0.1 dB of ripple
    # in the passband, and 40 dB of attenuation in the stopband.
    b, a = ellip(2, 0.1, 40, wn, 'bandpass', analog=False)
    # To match Matlab output, we change default padlen from
    # 3*(max(len(a), len(b))) to 3*(max(len(a), len(b)) - 1)
    padlen = 3 * (max(len(a), len(b)) - 1)
    return filtfilt(b, a, signal, axis=0, padlen=padlen)

I am aware that we could use the filter-design API of spikeinterface to implement something similar, but let me use this example as an illustration of how to integrate an arbitrary pre-processing step. What I want is to test this specific way of filtering with a couple of spikeinterface functions, say peak detection.

With this PR, that looks like the following:


f_sampling = recording.get_sampling_frequency()
function_kwargs = dict(f_sampling=f_sampling, f_low=300., f_high=6000.)
preprocessor = GenericPreprocessor(recording=recording, function=bandpass_filter, **function_kwargs)

# Then peak detection
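
For reference, the peak-detection step would look roughly like this (a sketch assuming detect_peaks from spikeinterface.sortingcomponents.peak_detection; the exact keyword arguments depend on the spikeinterface version):

from spikeinterface.sortingcomponents.peak_detection import detect_peaks

# The wrapped recording behaves like any other recording object
peaks = detect_peaks(preprocessor, method="by_channel", peak_sign="neg", detect_threshold=5)
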

Whereas otherwise I would need to do something like this:


from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment


class MyPreprocessor(BasePreprocessor):

    def __init__(self, recording, f_low, f_high):
        BasePreprocessor.__init__(self, recording)
        self.f_low = f_low
        self.f_high = f_high
        self.f_sampling = recording.get_sampling_frequency()

        # Wrap every segment of the parent recording in a filtering segment
        for parent_segment in recording._recording_segments:
            segment = MyPreprocessorSegment(parent_segment, self.f_sampling, self.f_low, self.f_high)
            self.add_recording_segment(segment)


class MyPreprocessorSegment(BasePreprocessorSegment):

    def __init__(self, parent_segment, f_sampling, f_low, f_high):
        BasePreprocessorSegment.__init__(self, parent_segment)
        self.f_sampling = f_sampling
        self.f_low = f_low
        self.f_high = f_high

    def get_traces(self, start_frame, end_frame, channel_indices):
        traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
        return bandpass_filter(traces, self.f_sampling, self.f_low, self.f_high)


preprocessor = MyPreprocessor(recording, f_low=300., f_high=6000.)

# Then peak detection

I claim that the approach in this PR is much simpler, lets users test their ideas more quickly, and requires less knowledge of spikeinterface internals to get things done.

Some drawbacks:

  • In its current form it is not JSON serializable (but it is picklable; see the quick sketch below).
  • Hard to test because it is too general.
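
On the serialization point, here is a quick illustration of what pickling still allows (a minimal sketch using the preprocessor object from above; note that the user function has to live at module level for pickle to find it):

import pickle

# Dump the wrapped recording; the user function must be importable (module-level)
with open("preprocessor.pkl", "wb") as f:
    pickle.dump(preprocessor, f)

# Load it back and use it as before
with open("preprocessor.pkl", "rb") as f:
    preprocessor_loaded = pickle.load(f)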

Thoughts?

h-mayorquin · May 06 '24 23:05