Question: How to get template waveform for each spike?
I'd like to scale the templates in templates.npy to fit each spike waveform (i.e., scale each template by the corresponding amplitude in amplitudes.npy).
Based on this older issue: https://github.com/cortex-lab/KiloSort/issues/35#issuecomment-262824645, is it still correct to multiply the template by the amplitude and then by the inverse whitening matrix?
For example:
import numpy as np
import matplotlib.pyplot as plt

# results_dir is the Kilosort output directory (a pathlib.Path)
amplitudes = np.load(results_dir / 'amplitudes.npy')
templates = np.load(results_dir / 'templates.npy')
spike_templates = np.load(results_dir / 'spike_templates.npy')
whitening_mat_inv = np.load(results_dir / 'whitening_mat_inv.npy')

s = 0  # pick a spike
fig, ax = plt.subplots(1, 2)
ax[0].plot(amplitudes[s] * templates[spike_templates[s], :, :] @ whitening_mat_inv)
ax[0].set_title('scaled and un-whitened')
ax[1].plot(amplitudes[s] * templates[spike_templates[s], :, :])
ax[1].set_title('scaled')
The scale of the un-whitened templates seems closer to the microvolt range? (The recordings are from Neuropixels 2.0 probes; I'm not sure if there's a data scaling parameter I'm missing.)
Yes, the picture on the right is what you want. Here's some code you can use to verify by comparing to the spike waveforms:
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
from kilosort.io import BinaryFiltered, load_ops
# Path to binary file
filename = Path('C:/users/jacob/.kilosort/.test_data/ZFM-02370_mini.imec0.ap.bin')
results_dir = filename.parent / 'kilosort4'
device = torch.device('cpu')
amplitudes = np.load(results_dir / 'amplitudes.npy')
templates = np.load(results_dir / 'templates.npy')
spike_templates = np.load(results_dir / 'spike_templates.npy')
spike_times = np.load(results_dir / 'spike_times.npy')
clu = np.load(results_dir / 'spike_clusters.npy')
whitening_mat_inv = np.load(results_dir / 'whitening_mat_inv.npy')
ops = load_ops(results_dir / 'ops.npy', device=device)
# Get largest channel for each cluster id, for plotting waveform
chan_map = ops['chanMap']
chan_best = (templates**2).sum(axis=1).argmax(axis=-1)
bfile = BinaryFiltered(
    filename, n_chan_bin=ops['n_chan_bin'], chan_map=chan_map, device=device,
    hp_filter=ops['fwav'], whiten_mat=ops['Wrot'], dshift=ops['dshift'],
)
# You may need to change other arguments for bfile, like NT or do_CAR, if you
# changed those when sorting.
# Exclude hp_filter, whiten_mat, and dshift to read raw data (no preprocessing).
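# For example, a raw-data reader might look like this (a sketch: the same
# arguments as the call above, minus the preprocessing ones):
# bfile_raw = BinaryFiltered(
#     filename, n_chan_bin=ops['n_chan_bin'], chan_map=chan_map, device=device,
# )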
s = 17 # pick a spike
cluster_id = clu[s] # get cluster id it was assigned to
c = chan_best[cluster_id] # largest channel for that cluster
# Get mean waveform for 100 random spikes from this cluster, aligned to nt0min.
spikes = spike_times[clu == cluster_id]
# Sample without replacement so no spike is counted twice in the mean.
subset = np.random.choice(spikes, min(spikes.size, 100), replace=False)
waves = []
for t in subset:
    tmin = t - bfile.nt0min
    tmax = t + (bfile.nt - bfile.nt0min) + 1
    # Un-whiten the preprocessed snippet so it's comparable to the templates.
    w = whitening_mat_inv @ bfile[tmin:tmax].cpu().numpy()
    waves.append(w)
wv = np.stack(waves, axis=-1).mean(axis=-1)
fig, ax = plt.subplots(1, 2)
ax[0].plot(amplitudes[s] * templates[spike_templates[s], :, :] @ whitening_mat_inv)
ax[0].plot(wv[c], c='black', linestyle='dashed')
ax[0].set_title('scaled and un-whitened')
ax[1].plot(amplitudes[s] * templates[spike_templates[s], :, :])
ax[1].plot(wv[c], c='black', linestyle='dashed')
ax[1].set_title('scaled')
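On the microvolt question: the binary file stores int16 values, so to get physical units you would multiply by the probe's µV-per-bit conversion factor from your acquisition software's metadata. A minimal sketch, assuming a hypothetical uv_per_bit value (the number below is the factor commonly quoted for Neuropixels 2.0, i.e. 0.5 V range / (gain 80 × 8192 levels); verify it against your own metadata rather than taking it from here):
# NOTE: uv_per_bit is an assumed example value -- read the real conversion
# factor from your SpikeGLX / Open Ephys metadata for your probe.
uv_per_bit = 0.5e6 / (80 * 8192)   # ~0.763 uV/bit, often quoted for NP2.0
wv_uv = wv * uv_per_bit            # mean waveform in microvolts
template_uv = uv_per_bit * amplitudes[s] * templates[spike_templates[s], :, :]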
This is super useful, thank you!