
Alter `nn_isolation` to use existing principal components

chrishalcrow opened this issue 6 months ago • 3 comments

Updated the nn_isolation metric to use the already-calculated PCs.

The previous implementation of the nn_isolation metric does the following (a rough sketch is given after the list):

  1. Take two units, A and B, and their spikes and waveforms
  2. Using the waveforms of just the spikes from units A and B, run a principal component analysis
  3. Compute the isolation score based on the PCs of this analysis
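For concreteness, here is a minimal sketch of that per-pair PCA step. This is illustrative only: the function name, the n_components default, and the use of scikit-learn are assumptions for the example, not the actual removed code.

import numpy as np
from sklearn.decomposition import PCA

def per_pair_projections(waveforms_a, waveforms_b, n_components=10):
    """Fit a fresh PCA on the pooled waveforms of units A and B only."""
    # Flatten each waveform (samples x channels) into one feature vector
    pooled = np.concatenate([waveforms_a, waveforms_b])
    pooled = pooled.reshape(len(pooled), -1)
    # One PCA fit per unit pair: this is the expensive part
    projections = PCA(n_components=n_components).fit_transform(pooled)
    return projections[:len(waveforms_a)], projections[len(waveforms_a):]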

This is prohibitively slow; as a result, the metric is currently excluded from the default quality_metric list.

Instead, this PR implements the following (sketched below):

  1. Take the PCs already calculated by compute_principal_components, which are computed from all spikes
  2. Compute the isolation score based on the PCs of this analysis
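And a minimal sketch of what step 2 can look like once the projections already exist. Again illustrative only: the helper name, the 4-neighbour default, and scikit-learn's NearestNeighbors are assumptions; the real metric lives in spikeinterface's quality-metric code.

import numpy as np
from sklearn.neighbors import NearestNeighbors

def nn_isolation_sketch(pcs_a, pcs_b, n_neighbors=4):
    """Fraction of nearest neighbours that share the query spike's unit."""
    pooled = np.concatenate([pcs_a, pcs_b])
    labels = np.concatenate([np.zeros(len(pcs_a)), np.ones(len(pcs_b))])
    # n_neighbors + 1 because each point's nearest neighbour is itself
    nn = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(pooled)
    _, indices = nn.kneighbors(pooled)
    # Drop the self-match in column 0, then count same-unit neighbours
    same_unit = labels[indices[:, 1:]] == labels[:, None]
    return same_unit.mean()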

I think the new implementation is consistent with the references describing the quality metrics (https://pubmed.ncbi.nlm.nih.gov/28910621 and https://pubmed.ncbi.nlm.nih.gov/33473216/). Please correct me if you disagree!

It’s also:

  • Much faster (~150x on a single core), since we don’t need to redo the PCA
  • Fits into the parallelisation scheme used by the other PC metrics => more than 150x faster in practice! (See the example after this list.)
  • Uses the sparsity scheme already used by the other PC metrics, rather than a custom one => easier to maintain
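If nn_isolation now goes through the same job handling as the other PC metrics, running it in parallel should look roughly like this (n_jobs is spikeinterface's standard job keyword; that compute forwards it to this metric is my assumption):

sorting_analyzer.compute(
    {"quality_metrics": {"metric_names": ["nn_isolation"]}},
    n_jobs=4,  # standard job keyword; assumed to apply to this metric
)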

The isolation scores are generally lower (worse) in the new implementation, because the PCA is applied to all spikes, not just those in the two clusters being compared.

Also updated docstrings and docs to (hopefully) clarify what the metric is calculating.

Benchmark code follows. Note: num_units is the most important parameter, since there is a max_spikes limit (so a longer duration doesn’t affect timings) and the method uses sparsity (so num_channels doesn’t affect timings either).

import spikeinterface.full as si
import numpy as np
from time import perf_counter

all_times = {}

for num_units in [10, 20, 30, 40]:

    # A short recording is fine: nn_isolation subsamples spikes via max_spikes
    recording, sorting = si.generate_ground_truth_recording(durations=[10], num_channels=10, num_units=num_units)
    sorting_analyzer = si.create_sorting_analyzer(recording=recording, sorting=sorting)

    # Extensions needed before the quality metrics can run
    sorting_analyzer.compute(["random_spikes", "noise_levels", "waveforms", "templates", "principal_components", "spike_locations", "spike_amplitudes"])

    # Time three runs of nn_isolation and keep the median
    times = []
    for _ in range(3):
        t_start = perf_counter()
        sorting_analyzer.compute({"quality_metrics": {"metric_names": ["nn_isolation"]}})
        t_end = perf_counter()
        times.append(t_end - t_start)

    all_times[num_units] = np.median(times)

Old times:

all_times
>>> {10: 33.851581208989955, 20: 141.90579045796767, 30: 380.2453615410486, 40: 620.6475296249846}

New times:

all_times
>>> {10: 0.2444504169980064, 20: 0.968019749969244, 30: 2.3160873339511454, 40: 4.065173499984667}
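As a sanity check on the quoted ~150x figure, dividing the old medians by the new ones (values rounded from the outputs above):

old = {10: 33.85, 20: 141.91, 30: 380.25, 40: 620.65}
new = {10: 0.244, 20: 0.968, 30: 2.316, 40: 4.065}
{k: round(old[k] / new[k]) for k in old}
# -> {10: 139, 20: 147, 30: 164, 40: 153}, i.e. roughly 150x throughout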

chrishalcrow · Aug 27 '24 15:08