FEATURE: Saving variables to disk between major steps
Feature you'd like to see:
Hi, I am running Kilosort on long recordings (>1 day), which causes occasional "CUDA out of memory" crashes. To debug this issue, it would be helpful if data were written to disk between major processing steps (after spike detection, after the first clustering, etc.). Doing so would allow starting a new debugging session without redoing earlier steps that can take hours or even days to complete. Special care is needed to serialize CUDA data, both when writing it and when reading it back later. Thanks, Anan
Some variables will likely be written this way in the future, but we do not plan to write every variable in this way. If you would like help debugging the issue, or recommendations on how to handle resource usage, please upload kilosort4.log from the results directory. If you're intent on saving the variables and debugging on your own, you can always do that yourself in a script or notebook by using the individual functions in kilosort.run_kilosort.
Example script below, using our test dataset. To save the relevant variables, add calls to np.save (after converting from tensor with .cpu().numpy() if needed). For saving and loading ops, you can use io.save_ops and io.load_ops.
import time
from pathlib import Path
import numpy as np
import torch
from kilosort.run_kilosort import (
    initialize_ops, compute_preprocessing, compute_drift_correction,
    detect_spikes, cluster_spikes, save_sorting, setup_logger, DEFAULT_SETTINGS
)
from kilosort import io
do_CAR = True                  # apply common average referencing
invert_sign = False            # do not flip the sign of the data
tic0 = time.time()             # reference time for step-by-step timing logs
device = torch.device('cuda')
progress_bar = None
data_dtype = 'int16'
filename = Path('~/.kilosort/.test_data/ZFM-02370_mini.imec0.ap.short.bin').expanduser()
results_dir = filename.parent / 'kilosort4'
probe_path = filename.parent.parent / 'probes' / 'NeuroPix1_default.mat'
probe = io.load_probe(probe_path)
settings = {**DEFAULT_SETTINGS}   # copy, so the module-level defaults are not mutated
settings['n_chan_bin'] = 385
settings['filename'] = filename
settings['data_dir'] = filename.parent
results_dir.mkdir(exist_ok=True, parents=True)
setup_logger(results_dir, verbose_console=True)
ops, settings = initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
                               device, save_preprocessed_copy=False)
ops = compute_preprocessing(ops, device, tic0=tic0)
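# Checkpoint: persist ops so a later session can skip preprocessing.
# io.save_ops / io.load_ops are mentioned above; the exact arguments used here
# (the ops dict plus the results directory) are an assumption -- check kilosort.io.
io.save_ops(ops, results_dir)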
# Seed all RNGs so drift correction is reproducible across sessions.
np.random.seed(1)
torch.cuda.manual_seed_all(1)
torch.random.manual_seed(1)
ops, bfile, st0 = compute_drift_correction(
    ops, device, tic0=tic0, progress_bar=progress_bar, verbose=True
)
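# Checkpoint: save the spikes used for drift correction. Per the note above the
# script, convert with .cpu().numpy() first if st0 is a torch tensor; np.save
# handles a numpy array directly.
st0_arr = st0.cpu().numpy() if torch.is_tensor(st0) else np.asarray(st0)
np.save(results_dir / 'st0.npy', st0_arr)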
st, tF, Wall, clu = detect_spikes(ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
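# Checkpoint the detection outputs. tF and Wall are typically torch tensors
# while st and clu are numpy arrays; the helper below handles either case.
def to_numpy(x):
    # Convert a torch tensor to numpy; pass numpy-compatible data through.
    return x.cpu().numpy() if torch.is_tensor(x) else np.asarray(x)
for name, var in [('st', st), ('tF', tF), ('Wall', Wall), ('clu', clu)]:
    np.save(results_dir / f'{name}_detect.npy', to_numpy(var))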
clu, Wall, st, tF = cluster_spikes(st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar)
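# Checkpoint the clustering outputs under separate filenames, since
# cluster_spikes returns updated versions of the same variables.
for name, var in [('st', st), ('tF', tF), ('Wall', Wall), ('clu', clu)]:
    np.save(results_dir / f'{name}_cluster.npy', to_numpy(var))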
ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
    save_sorting(ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
                 save_extra_vars=False)
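To resume in a new session without redoing the earlier steps, load the saved arrays back and convert to tensors where needed. A minimal sketch, assuming the checkpoint filenames used above, and assuming io.load_ops takes the path to the saved ops.npy (check kilosort.io for the exact signature):

from pathlib import Path
import numpy as np
import torch
from kilosort import io

device = torch.device('cuda')
results_dir = Path('~/.kilosort/.test_data/kilosort4').expanduser()

# Reload ops (the path argument is an assumption; see kilosort.io).
ops = io.load_ops(results_dir / 'ops.npy')

# Reload the detection checkpoints and move tensor-valued variables to the GPU.
st = np.load(results_dir / 'st_detect.npy')
clu = np.load(results_dir / 'clu_detect.npy')
tF = torch.from_numpy(np.load(results_dir / 'tF_detect.npy')).to(device)
Wall = torch.from_numpy(np.load(results_dir / 'Wall_detect.npy')).to(device)
# Note: cluster_spikes and save_sorting also need the bfile object, which is
# not serialized here; the simplest way to recreate it is to rerun
# compute_drift_correction as in the script above.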