EEG-Conformer
EEG-Conformer copied to clipboard
复现的模型精度远低于论文的精度
作者您好,我最近在复现您的模型中出了精度偏低的问题,A01acc为0.78,A02acc为0.50,A03也为0.86左右但是S5的acc确意外地达到了75,我是用mne进行预处理的,提取2-6s和ID为768的运动想象数据,并在预处理阶段进行了数据标准化(原论文中的get_data模块已经做了调整改成了自己预处理之后的数据),代码其他部分我并没有修改。avg_acc在73左右 另一个发生是:不做带通滤波似乎对精度有所提升(多次实验结果) 以下是我的python代码(未作滤波操作)
import random
from collections import Counter
import mne
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler,OneHotEncoder
import scipy.io
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from scipy import signal
# 设置随机种子
seed_value = 42 # 可以根据需要选择任意整数作为种子值
# Python内置的随机模块
random.seed(seed_value)
# NumPy随机数生成器
np.random.seed(seed_value)
# PyTorch随机数生成器
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"""进行预处理前请检查数据是否存在过多的噪声,或者数据是否稳定"""
"""If you want to use this file, please make sure not have noise in your original data """
# 进行数据预处理 只适合T结尾的data数据,其他的需要修改函数 filename需以.npy结尾
def transform_save_data(filename,save_filename=None):
raw = mne.io.read_raw_gdf(filename)
print(raw.info['ch_names'])
events,event_id = mne.events_from_annotations(raw)
raw.info['bads'] += ['EOG-left','EOG-central','EOG-right']
# 运动想象时间2-6秒
tmin,tmax = 2, 6
event_id = {'768': 6}
# 需要重新加载raw对象进行滤波处理
raw.load_data()
# iir_params = dict(order=6, ftype='cheby2', rs=60) # Chebyshev Type II滤波器
# raw.filter(l_freq=8, h_freq=35, method='iir', iir_params=iir_params)
# raw.filter(8.0, 35.0, fir_design='IIR')
picks = mne.pick_types(raw.info,meg=False,eeg=True,stim=False,exclude='bads')
epochs = mne.Epochs(raw=raw,events=events,event_id=6,tmin=tmin,tmax=tmax,preload=True,baseline=None,picks=picks)
epoch_data = epochs.get_data()
# 将最后一位数据进行去除
epoch_data = epoch_data[:, :, :-1]
epoch_data = epoch_data.reshape(epoch_data.shape[0], 1, 22, 1000)
if save_filename is not None:
np.save(save_filename,epoch_data)
# 进行归一化处理
def data_processing(BCI_IV_2a_data,label_filename):
Scaler = StandardScaler()
X_train = BCI_IV_2a_data.reshape(BCI_IV_2a_data.shape[0], 22000)
X_train_Scaler = Scaler.fit_transform(X_train)
# 进行reshape第二个维度为channels W H
acc_train = X_train_Scaler.reshape(BCI_IV_2a_data.shape[0], 1, 22, 1000)
data_label = scipy.io.loadmat(label_filename)
print(data_label['classlabel'].reshape(288))
Label = data_label['classlabel'].reshape(288)
return acc_train,Label
# 进行转换成Tensor格式的数据 保存的文件的格式应该以pt为后缀
def data_transform_tensor(acc_train,y_oh,save_datafilename=None,save_labelfilename=None):
# transf = transforms.ToTensor()
# d = transf(y_oh)
# # 去除另外四个维度的标签,标签就是最大值
# label = torch.argmax(d,dim=2).long()
data = torch.tensor(acc_train,dtype=torch.float32)
labels = torch.tensor(y_oh,dtype=torch.long)
if save_datafilename is not None:
torch.save(data,save_datafilename)
if save_labelfilename is not None:
torch.save(labels,save_labelfilename)
return data, labels
# 将数据进行联合
def combine_data(data_list,label_list,data_filename,label_filename):
"""_summary_
将增强的EEG_data数据进行拼接 并保存为pt后缀文件
Parameters
----------
data_list : _type_ tensor
EEGdata list
label_list : _type_ tensor
label list
data_filename : _type_, optional
_description_, by default None
label_filename : _type_, optional
_description_, by default None
"""
data_combine = torch.cat(data_list, axis=0)
label_combine = torch.cat(label_list, axis=0)
torch.save(data_combine, data_filename)
torch.save(label_combine, label_filename)
return data_combine, label_combine
# 数据滤波 利用巴特沃斯滤波器
def buttferfiter(data):
Fs = 250
b, a = signal.butter(6, [8, 30], 'bandpass', fs=Fs)
data = signal.filtfilt(b, a, data, axis=1)
return data
# 进行时域上EEG数据增强 通过分割,重构 打乱数据
def interaug(timg, label, batch_size):
"""timg是data label是标签"""
"""
tmp_aug_data 用于保存生成的增强样本数据,其形状为 (batch_size / 4, 1, 22, 1000),
即每个增强样本包含8个时间片段,每个时间片段的形状为 (1, 22, 125)
rand_idx 是随机选择的8个时间片段的索引,用于从原始数据中获取时间片段。
aug_data 和 aug_label 分别保存所有类别的增强样本和对应的标签。
aug_shuffle 对增强样本和标签进行随机打乱。
"""
aug_data = []
aug_label = []
for cls4aug in range(4):
# 条件判断 找出对应的label和data
cls_idx = np.where(label == cls4aug+1) # label == cls4aug + 1
tmp_data = timg[cls_idx]
tmp_label = label[cls_idx]
# 分epoch
tmp_aug_data = np.zeros((int(batch_size / 4), 1, 22, 1000))
for ri in range(int(batch_size / 4)):
# 随机取8个时间片段
for rj in range(8):
rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
# 进行数据的打乱重构
tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125]
aug_data.append(tmp_aug_data)
aug_label.append(tmp_label[:int(batch_size / 4)])
aug_data = np.concatenate(aug_data)
aug_label = np.concatenate(aug_label)
aug_shuffle = np.random.permutation(len(aug_data))
aug_data = aug_data[aug_shuffle, :, :]
aug_label = aug_label[aug_shuffle]
aug_data = torch.from_numpy(aug_data).cuda()
aug_data = aug_data.float()
aug_label = torch.from_numpy(aug_label).cuda() # aug_label - 1
aug_label = aug_label.long()
return aug_data, aug_label
# 切分部分数据进行test
def split_EEGdata(data, label):
data = data.view(data.shape[0], 1, 22, 1000)
data = data[:100]
label = label[:100]
torch.save(data, 'EEG_data_split.pt')
torch.save(label, 'EEG_label_split.pt')
# 没有进行滤波处理原数据获取
def transform_save_data_version2(filename,save_filename=None):
"""
The processing gain the raw data from the EEG
:param filename: data filename
:param save_filename: save data filename
:return: save a np style file
"""
raw = mne.io.read_raw_gdf(filename)
print(raw.info['ch_names'])
events,event_id = mne.events_from_annotations(raw)
raw.info['bads'] += ['EOG-left','EOG-central','EOG-right']
# 运动想象时间2-6秒
tmin,tmax = 2,6
event_id = {'768': 6}
# 需要重新加载raw对象进行滤波处理
raw.load_data()
picks = mne.pick_types(raw.info,meg=False,eeg=True,stim=False,exclude='bads')
epochs = mne.Epochs(raw=raw,events=events,event_id=event_id,tmin=tmin,tmax=tmax,preload=True,baseline=None,picks=picks)
epoch_data = epochs.get_data(copy=True)
# 将最后一位数据进行去除
epoch_data = epoch_data[:,:,:-1]
epoch_data = epoch_data.reshape(epoch_data.shape[0], 1, 22, 1000)
if save_filename is not None:
np.save(save_filename,epoch_data)
return epoch_data
if __name__ == '__main__':
count = input('please input your subject ID:')
filename = 'C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\A0' + count + 'T.gdf'
BCI_data = transform_save_data_version2(filename)
label_filename = 'C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\A0' + count + 'T.mat'
acc_train, y_oh = data_processing(BCI_data, label_filename)
data, label = data_transform_tensor(acc_train, y_oh)
# data = data.reshape(data.shape[0], 22, 1000)
data = np.array(data)
label = np.array(label)
# print(label)
data1, label1 = interaug(data, label, batch_size=288)
data = torch.from_numpy(data)
label = torch.from_numpy(label)
print(data.type())
print(label.type())
# data1 = torch.from_numpy(data1)
# label1 = torch.from_numpy(label1)
data_list = [data.to('cuda'), data1.to('cuda')]
label_list = [label.to('cuda'), label1.to('cuda')]
data_filename = '../EEG-dataprocessing/2a/paper_data_label/A0'+ count + '_combine/A0'+ count + '_combine_data.pt'
label_filename = '../EEG-dataprocessing/2a/paper_data_label/A0'+ count + '_combine/A0'+ count +'_combine_label.pt'
data_combine, label_combine = combine_data(data_list, label_list, data_filename, label_filename)
data_combine = data_combine.detach().cpu().numpy()
label_combine = label_combine.detach().cpu().numpy()
print(data_combine.shape)
print(label_combine.shape)
train_data, test_data, train_label, test_label = train_test_split(data_combine, label_combine, test_size=0.2, train_size=0.8, shuffle=True)
train_data = torch.from_numpy(train_data).float()
test_data = torch.from_numpy(test_data).float()
train_label = torch.from_numpy(train_label).long()
test_label = torch.from_numpy(test_label).long()
torch.save(train_data, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/train_data_A0' + count + '.pt')
torch.save(test_data, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/test_data_A0' + count + '.pt')
torch.save(train_label, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/train_label_A0' + count + '.pt')
torch.save(test_label, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/test_label_A0' + count + '.pt')
print(train_data.shape)
print(test_data.shape)
print(train_label.shape)
print(test_label.shape)