spikeinterface
about bug
This is my code. I have run into a serious problem and hope you can help. When I process different files, the number of units is the same, which can't be right.
import os
import time
from pprint import pprint
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import spikeinterface
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.core as score
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw
import spikeinterface.full as si

from spikeinterface.widgets.isi_distribution import ISIDistributionWidget
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.exporters import export_to_phy

global_job_kwargs = dict(n_jobs=1, chunk_duration="1s")
si.set_global_job_kwargs(**global_job_kwargs)
def process_recording(recording, data_location, suffix):
    # Preprocessing
    recording_f = si.bandpass_filter(recording, freq_min=300, freq_max=6000)
    recording_cmr = si.common_reference(recording_f, reference='global', operator='median')
    recording_preprocessed = recording_cmr.save(format='binary')

    kilosort2_params = ss.get_default_sorter_params('kilosort4')
    print("Updated Kilosort4 params:", kilosort2_params)

    # Run Kilosort2 in Docker
    sorting_KS2 = ss.run_sorter(sorter_name="kilosort2",
                                recording=recording_preprocessed,
                                extra_requirements=["numpy==1.26"],
                                docker_image=True,
                                verbose=True)
    # Extract waveforms
    we_KS2 = si.extract_waveforms(recording_preprocessed, sorting_KS2,
                                  os.path.join(data_location, f'waveforms_folder_{suffix}'),
                                  overwrite=None)

    # Save waveforms as .npy files
    npy_waveform_folder = os.path.join(data_location, f'waveform_npy_{suffix}')
    os.makedirs(npy_waveform_folder, exist_ok=True)
    for unit_id in we_KS2.unit_ids:
        waveforms = we_KS2.get_waveforms(unit_id)
        npy_filename = f"unit_{unit_id}_waveforms.npy"
        npy_filepath = os.path.join(npy_waveform_folder, npy_filename)
        np.save(npy_filepath, waveforms)
        print(f"Unit {unit_id} waveforms saved to {npy_filepath}")
    # Compute metrics
    amplitudes = spost.compute_spike_amplitudes(we_KS2)
    unit_locations = spost.compute_unit_locations(we_KS2)
    spike_locations = spost.compute_spike_locations(we_KS2)
    correlograms, bins = spost.compute_correlograms(we_KS2)
    similarity = spost.compute_template_similarity(we_KS2)
    ISI = spost.compute_isi_histograms(we_KS2, window_ms=100.0, bin_ms=2.0, method="auto")
    metric = spost.compute_template_metrics(we_KS2, include_multi_channel_metrics=True)
    metric_names = spost.get_template_metric_names()
    print(we_KS2.get_available_extension_names())
    waveform_folder = os.path.join(data_location, f'waveforms_folder_{suffix}')
    if not os.path.isdir(waveform_folder):
        print(f"Waveform folder does not exist: {waveform_folder}")
        return
    we_loaded = si.load_waveforms(waveform_folder)
    print(we_loaded.get_available_extension_names())
    # Compute quality metrics
    qm_params = sqm.get_default_qm_params()
    qm_params["presence_ratio"]["bin_duration_s"] = 1
    qm_params["amplitude_cutoff"]["num_histogram_bins"] = 5
    qm_params["drift"]["interval_s"] = 2
    qm_params["drift"]["min_spikes_per_interval"] = 2
    qm = sqm.compute_quality_metrics(we_KS2, qm_params=qm_params)
    print(f"Quality metrics for {suffix}:", qm)
    # Save spike trains
    spike_trains = {}
    for unit_id in sorting_KS2.unit_ids:
        spike_train = sorting_KS2.get_unit_spike_train(unit_id, start_frame=None, end_frame=None)
        spike_trains[unit_id] = spike_train
    np.save(os.path.join(data_location, f'aligned_spike_trains_{suffix}.npy'), spike_trains)

    # Load and check spike trains
    loaded_spike_trains = np.load(os.path.join(data_location, f'aligned_spike_trains_{suffix}.npy'),
                                  allow_pickle=True).item()
    print(f"Loaded spike train data type for {suffix}:", type(loaded_spike_trains))
    print(f"Loaded spike train dimensions for {suffix}:", {k: np.shape(v) for k, v in loaded_spike_trains.items()})
    # Save to CSV
    data = []
    for unit_id, spike_train in spike_trains.items():
        for spike in spike_train:
            data.append([unit_id, spike])
    df = pd.DataFrame(data, columns=['unit_id', 'spike_time'])
    df.to_csv(os.path.join(data_location, f'aligned_spike_trains_{suffix}.csv'), index=False)
    # Export to phy
    sorting_analyzer = si.create_sorting_analyzer(sorting=sorting_KS2, recording=recording_preprocessed,
                                                  format="memory")
    sorting_analyzer.compute(['random_spikes', 'waveforms', 'templates', 'noise_levels'])
    _ = sorting_analyzer.compute('correlograms')
    _ = sorting_analyzer.compute('spike_amplitudes')
    _ = sorting_analyzer.compute('principal_components', n_components=5, mode="by_channel_local")
    # phy_folder = os.path.join(data_location, f'phy_folder_{suffix}')
    # os.makedirs(phy_folder, exist_ok=True)
    si.export_to_phy(sorting_analyzer=sorting_analyzer,
                     output_folder=os.path.join(data_location, f'phy_folder_{suffix}'))
    # Plot and save images
    sorting_analyzer.compute('unit_locations')
    w2 = sw.plot_sorting_summary(sorting_analyzer, display=False, curation=True, backend="sortingview")
    plt.savefig(os.path.join(data_location, f'sorting_summary_{suffix}.png'))

    # Inspect all attributes and methods of the SortingSummaryWidget object
    print(dir(w2))
    url = w2.url  # make sure the URL can be retrieved from w2
    url_file_path = os.path.join(data_location, f'plot_url_{suffix}.txt')
    with open(url_file_path, 'w') as f:
        f.write(f"URL: {url}\n")
    print(f"URL saved to {url_file_path}")
    print(f"Processing complete for {suffix} in:", data_location)

    w_rs = sw.plot_rasters(sorting_KS2)
    plt.savefig(os.path.join(data_location, f'rasters_{suffix}.png'))
    w_pr = sw.plot_unit_presence(sorting_KS2)
    plt.savefig(os.path.join(data_location, f'unit_presence_{suffix}.png'))
def process_with_artifact_removal(recording, stim_times, recording_start_time, ms_after=1000):
    list_of_artifacts = []
    for stim_time in stim_times:
        stim_start_time = stim_time
        stim_end_time = stim_time + timedelta(milliseconds=ms_after)
        start_time_sec = (stim_start_time - recording_start_time).total_seconds()
        list_of_artifacts.append(start_time_sec)
        rec_segment = recording.time_slice(start_time=start_time_sec, end_time=start_time_sec + ms_after / 1000.0)
        cleaned_segment = spre.remove_artifacts(rec_segment,
                                                list_triggers=[start_time_sec],
                                                ms_before=0, ms_after=ms_after)
    cleaned_recording = spre.remove_artifacts(recording,
                                              list_triggers=list_of_artifacts,
                                              ms_before=0, ms_after=ms_after)
    return cleaned_recording
def split_recording_at_midpoint(recording, recording_start_time):
    sampling_rate = recording.get_sampling_frequency()
    total_duration_sec = recording.get_duration()
    midpoint_sec = total_duration_sec / 2
    midpoint_frame = int(midpoint_sec * sampling_rate)
    total_frames = int(total_duration_sec * sampling_rate)
    if not (0 <= midpoint_frame < total_frames):
        raise ValueError("Midpoint frame is out of bounds.")
    rec_first_half = recording.frame_slice(start_frame=0, end_frame=midpoint_frame)
    rec_second_half = recording.frame_slice(start_frame=midpoint_frame, end_frame=total_frames)
    return rec_first_half, rec_second_half
def process_files(data_location, csv_file, recording_start_time):
    # Load stim times
    stim_times_df = pd.read_csv(csv_file)
    stim_times_df['End Time'] = pd.to_datetime(stim_times_df['End Time'])
    stim_times = stim_times_df['End Time'].tolist()
    print("Stim Times:", stim_times)

    data_name = "data.raw.h5"
    recording = si.read_maxwell(os.path.join(data_location, data_name))
    cleaned_recording = process_with_artifact_removal(recording, stim_times, recording_start_time)
    cleaned_recording_first_half, cleaned_recording_second_half = split_recording_at_midpoint(cleaned_recording,
                                                                                              recording_start_time)
    process_recording(cleaned_recording_first_half, data_location, 'KS2-first_half')
    process_recording(cleaned_recording_second_half, data_location, 'KS2-second_half')
# List of file sets with paths and start times
file_sets = [
    {"data_location": , "csv_file": , "recording_start_time": },
]

# Process each file set
for file_set in file_sets:
    process_files(
        data_location=file_set["data_location"],
        csv_file=file_set["csv_file"],
        recording_start_time=file_set["recording_start_time"]
    )
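To confirm the runs really differ, a small post-hoc check could be appended at the end (a minimal sketch; it only relies on the aligned_spike_trains_*.npy files that process_recording saves above):

# Count the units saved by each run, per data location
import os
import numpy as np

for file_set in file_sets:
    for suffix in ('KS2-first_half', 'KS2-second_half'):
        path = os.path.join(file_set["data_location"], f'aligned_spike_trains_{suffix}.npy')
        if os.path.isfile(path):
            spike_trains = np.load(path, allow_pickle=True).item()
            print(file_set["data_location"], suffix, "->", len(spike_trains), "units")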
Hi, are you sure you're processing different files?
Can you print(file_set["data_location"])?
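For example, something like this around the loop (a minimal sketch of the check I mean, reusing your file_sets and process_files):

for file_set in file_sets:
    # Confirm each iteration really points at a different raw file
    print(file_set["data_location"])
    print(os.path.join(file_set["data_location"], "data.raw.h5"))
    process_files(
        data_location=file_set["data_location"],
        csv_file=file_set["csv_file"],
        recording_start_time=file_set["recording_start_time"],
    )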
cleaned_recording_first_half and cleaned_recording_second_half are different parts of my file, so they are different recordings, right? I also have file_sets = [
{
"data_location": "D:\\sjwlab\\bxy\\0828shang\\24395\\1card-3interval-near\\game\\",
"csv_file": r"D:\sjwlab\bxy\0828shang\24395\1card-3interval-near\game\stim_region_times_20240828_105130.csv",
"recording_start_time": datetime.strptime("2024/8/28 10:51:32", "%Y/%m/%d %H:%M:%S")
},
{
"data_location": "D:\\sjwlab\\bxy\\game\\24395\\01\\01-1\\",
"csv_file": r"D:\sjwlab\bxy\game\24395\01\01-1\stim_region_times_20240830_112443.csv",
"recording_start_time": datetime.strptime("2024/8/30 11:24:46", "%Y/%m/%d %H:%M:%S")
},
{
"data_location": "D:\\sjwlab\\bxy\\game\\24395\\01\\4.4-1\\",
"csv_file": r"D:\sjwlab\bxy\game\24395\01\01-2\stim_region_times_20240830_113653.csv",
"recording_start_time": datetime.strptime("2024/8/30 11:36:56", "%Y/%m/%d %H:%M:%S")
},
]
They are 2 halves of the same recording! The neurons recorded will be the same so the sorter might find the same number of units.
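To see the difference, here is a minimal sketch (assuming you keep the two sortings returned by run_sorter for the two halves; sorting_first and sorting_second are placeholder names):

# Same number of units does not mean the same output:
# the unit count can match while the spike trains differ.
print("first half units:", sorting_first.get_num_units())
print("second half units:", sorting_second.get_num_units())

# Rough check only: unit ids are not guaranteed to correspond across runs.
for unit_id in sorting_first.unit_ids:
    if unit_id in sorting_second.unit_ids:
        n1 = len(sorting_first.get_unit_spike_train(unit_id))
        n2 = len(sorting_second.get_unit_spike_train(unit_id))
        print(unit_id, n1, "vs", n2, "spikes")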
However, the three file sets listed above point to different data locations and recording start times, and they also gave the same number of units.
Did you get the same spike trains, or just the same number of units? If the spike trains are the same, then the files are 99% the same.
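If you want to check that quickly, here is a minimal sketch comparing the spike-train files saved by two of your runs (the two paths below are hypothetical placeholders):

import numpy as np

# Hypothetical paths to the saved spike trains of two different runs
a = np.load(r"D:\run_a\aligned_spike_trains_KS2-first_half.npy", allow_pickle=True).item()
b = np.load(r"D:\run_b\aligned_spike_trains_KS2-first_half.npy", allow_pickle=True).item()

same_unit_ids = set(a.keys()) == set(b.keys())
same_trains = same_unit_ids and all(np.array_equal(a[u], b[u]) for u in a)
print("same unit ids:", same_unit_ids)
print("identical spike trains:", same_trains)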
I'll close for being stale. But feel free to reach back out if you have additional questions.