
Question: How to get template waveform for each spike?

chris-angeloni opened this issue 1 year ago

I'd like to scale the templates in templates.npy to fit each spike waveform (i.e., scaled by the template amplitude in amplitudes.npy).

Based on this older issue: https://github.com/cortex-lab/KiloSort/issues/35#issuecomment-262824645, is it still correct to multiply the template by the amplitude, then by the inverse whitening matrix?

For example:

import numpy as np
import matplotlib.pyplot as plt

# results_dir is a pathlib.Path to the Kilosort output directory
amplitudes = np.load(results_dir / 'amplitudes.npy')
templates = np.load(results_dir / 'templates.npy')
spike_templates = np.load(results_dir / 'spike_templates.npy')
whitening_mat_inv = np.load(results_dir / 'whitening_mat_inv.npy')

s = 0
fig, ax = plt.subplots(1,2)
ax[0].plot(amplitudes[s] * templates[spike_templates[s],:,:] @ whitening_mat_inv)
ax[0].set_title('scaled and un-whitened')
ax[1].plot(amplitudes[s] * templates[spike_templates[s],:,:])
ax[1].set_title('scaled')
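(For reference, the same scaling can be applied to all spikes at once via broadcasting. A shape-only sketch with made-up dimensions, not real Kilosort output:)

```python
import numpy as np

rng = np.random.default_rng(1)
n_templates, nt, n_chan, n_spikes = 3, 61, 8, 10

# Synthetic stand-ins with the same shapes as the Kilosort .npy files
templates = rng.normal(size=(n_templates, nt, n_chan))
amplitudes = rng.uniform(0.5, 2.0, size=n_spikes)
spike_templates = rng.integers(0, n_templates, size=n_spikes)
whitening_mat_inv = np.eye(n_chan)  # identity, just for shape checking

# One scaled, un-whitened template per spike: (n_spikes, nt, n_chan)
scaled = amplitudes[:, None, None] * (templates[spike_templates] @ whitening_mat_inv)
print(scaled.shape)  # (10, 61, 8)
```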

[image: the two plots]

The scale of the un-whitened templates seems closer to the microvolt range? (The recordings are from Neuropixels 2.0 probes; I'm not sure if there is a data-scaling parameter I'm missing.)

chris-angeloni avatar Oct 21 '24 15:10 chris-angeloni

Yes, the picture on the right is what you want. Here's some code you can use to verify by comparing to the spike waveforms:

from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from kilosort.io import BinaryFiltered, load_ops

# Path to binary file
filename = Path('C:/users/jacob/.kilosort/.test_data/ZFM-02370_mini.imec0.ap.bin')
results_dir = filename.parent / 'kilosort4'
device = torch.device('cpu')

amplitudes = np.load(results_dir / 'amplitudes.npy')
templates = np.load(results_dir / 'templates.npy')
spike_templates = np.load(results_dir / 'spike_templates.npy')
spike_times = np.load(results_dir / 'spike_times.npy')
clu = np.load(results_dir / 'spike_clusters.npy')
whitening_mat_inv = np.load(results_dir / 'whitening_mat_inv.npy')
ops = load_ops(results_dir / 'ops.npy', device=device)

# Get largest channel for each cluster id, for plotting waveform
chan_map = ops['chanMap']
chan_best = (templates**2).sum(axis=1).argmax(axis=-1)

bfile = BinaryFiltered(
    filename, n_chan_bin=ops['n_chan_bin'], chan_map=chan_map, device=device,
    hp_filter=ops['fwav'], whiten_mat=ops['Wrot'], dshift=ops['dshift'],
    )
# May need to change other arguments for bfile, like NT or do_CAR, if you changed 
# those when sorting.
# Exclude hp_filter, whiten_mat, dshift for raw data (no preprocessing).

s = 17                            # pick a spike
cluster_id = clu[s]               # get cluster id it was assigned to
c = chan_best[cluster_id]         # largest channel for that cluster

# Get mean waveform for 100 random spikes from this cluster, aligned to nt0min.
spikes = spike_times[clu == cluster_id]
subset = np.random.choice(spikes, min(spikes.size, 100))
waves = []
for t in subset:
    tmin = t - bfile.nt0min
    tmax = t + (bfile.nt - bfile.nt0min) + 1
    w = whitening_mat_inv @ bfile[tmin:tmax].cpu().numpy()
    waves.append(w)
wv = np.stack(waves, axis=-1).mean(axis=-1)

fig, ax = plt.subplots(1,2)
ax[0].plot(amplitudes[s] * templates[spike_templates[s],:,:] @ whitening_mat_inv)
ax[0].plot(wv[c], c='black', linestyle='dashed')
ax[0].set_title('scaled and un-whitened')
ax[1].plot(amplitudes[s] * templates[spike_templates[s],:,:])
ax[1].plot(wv[c], c='black', linestyle='dashed')
ax[1].set_title('scaled')
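(As a side note, the reason multiplying by whitening_mat_inv takes you back to channel space is just the linear-algebra round trip: the inverse undoes the whitening transform exactly. A minimal self-contained sketch with synthetic data, illustrative only:)

```python
import numpy as np

rng = np.random.default_rng(0)
n_chan = 4

# Synthetic "whitening" matrix; any well-conditioned invertible matrix
# works for demonstrating the round trip
W = rng.normal(size=(n_chan, n_chan)) + n_chan * np.eye(n_chan)
W_inv = np.linalg.inv(W)

x = rng.normal(size=(n_chan, 50))  # fake channels-by-time snippet
x_white = W @ x                    # whitened data (analogous to what's stored)
x_back = W_inv @ x_white           # applying the inverse recovers the original

print(np.allclose(x_back, x))  # True
```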

[image: template vs. mean spike waveform overlay]

jacobpennington avatar Oct 22 '24 01:10 jacobpennington

This is super useful, thank you!

chris-angeloni avatar Oct 22 '24 13:10 chris-angeloni