EEG-Conformer icon indicating copy to clipboard operation
EEG-Conformer copied to clipboard

复现的模型精度远低于论文的精度

Open XCZchaos opened this issue 9 months ago • 0 comments

作者您好,我最近在复现您的模型中出了精度偏低的问题,A01acc为0.78,A02acc为0.50,A03也为0.86左右但是S5的acc确意外地达到了75,我是用mne进行预处理的,提取2-6s和ID为768的运动想象数据,并在预处理阶段进行了数据标准化(原论文中的get_data模块已经做了调整改成了自己预处理之后的数据),代码其他部分我并没有修改。avg_acc在73左右 另一个发生是:不做带通滤波似乎对精度有所提升(多次实验结果) 以下是我的python代码(未作滤波操作)

import random
from collections import Counter

import mne
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler,OneHotEncoder
import scipy.io
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from scipy import signal


# 设置随机种子
seed_value = 42  # 可以根据需要选择任意整数作为种子值

# Python内置的随机模块
random.seed(seed_value)

# NumPy随机数生成器
np.random.seed(seed_value)

# PyTorch随机数生成器
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""进行预处理前请检查数据是否存在过多的噪声,或者数据是否稳定"""
"""If you want to use this file, please make sure not have noise in your original data """

# 进行数据预处理 只适合T结尾的data数据,其他的需要修改函数 filename需以.npy结尾
def transform_save_data(filename,save_filename=None):
    raw = mne.io.read_raw_gdf(filename)
    print(raw.info['ch_names'])
    events,event_id = mne.events_from_annotations(raw)
    raw.info['bads'] += ['EOG-left','EOG-central','EOG-right']
    # 运动想象时间2-6秒
    tmin,tmax = 2, 6
    event_id = {'768': 6}
    # 需要重新加载raw对象进行滤波处理
    raw.load_data()
    # iir_params = dict(order=6, ftype='cheby2', rs=60)  # Chebyshev Type II滤波器
    # raw.filter(l_freq=8, h_freq=35, method='iir', iir_params=iir_params)
    # raw.filter(8.0, 35.0, fir_design='IIR')
    picks = mne.pick_types(raw.info,meg=False,eeg=True,stim=False,exclude='bads')
    epochs = mne.Epochs(raw=raw,events=events,event_id=6,tmin=tmin,tmax=tmax,preload=True,baseline=None,picks=picks)
    epoch_data = epochs.get_data()
    # 将最后一位数据进行去除
    epoch_data = epoch_data[:, :, :-1]


    epoch_data = epoch_data.reshape(epoch_data.shape[0], 1, 22, 1000)
    if save_filename is not None:
        np.save(save_filename,epoch_data)


# 进行归一化处理
def data_processing(BCI_IV_2a_data,label_filename):
    Scaler = StandardScaler()
    X_train = BCI_IV_2a_data.reshape(BCI_IV_2a_data.shape[0], 22000)
    X_train_Scaler = Scaler.fit_transform(X_train)
    # 进行reshape第二个维度为channels W H
    acc_train = X_train_Scaler.reshape(BCI_IV_2a_data.shape[0], 1, 22, 1000)
    data_label = scipy.io.loadmat(label_filename)
    print(data_label['classlabel'].reshape(288))
    Label = data_label['classlabel'].reshape(288)


    return acc_train,Label


# 进行转换成Tensor格式的数据  保存的文件的格式应该以pt为后缀
def data_transform_tensor(acc_train,y_oh,save_datafilename=None,save_labelfilename=None):
    # transf = transforms.ToTensor()
    # d = transf(y_oh)
    # # 去除另外四个维度的标签,标签就是最大值
    # label = torch.argmax(d,dim=2).long()

    


    data = torch.tensor(acc_train,dtype=torch.float32)
    labels = torch.tensor(y_oh,dtype=torch.long)
    if save_datafilename is not None:
        torch.save(data,save_datafilename)
    if save_labelfilename is not None:
        torch.save(labels,save_labelfilename)

    return data, labels



# 将数据进行联合
def combine_data(data_list,label_list,data_filename,label_filename):
    """_summary_
    将增强的EEG_data数据进行拼接 并保存为pt后缀文件
    Parameters
    ----------
    data_list : _type_  tensor
        EEGdata list
    label_list : _type_ tensor
        label list
    data_filename : _type_, optional
        _description_, by default None 
    label_filename : _type_, optional
        _description_, by default None 
    """
    data_combine = torch.cat(data_list, axis=0)
    label_combine = torch.cat(label_list, axis=0)
    torch.save(data_combine, data_filename)
    torch.save(label_combine, label_filename)

    return data_combine, label_combine


# 数据滤波 利用巴特沃斯滤波器
def buttferfiter(data):
    Fs = 250
    b, a = signal.butter(6, [8, 30], 'bandpass', fs=Fs)
    data = signal.filtfilt(b, a, data, axis=1)
    return data


# 进行时域上EEG数据增强  通过分割,重构 打乱数据
def interaug(timg, label, batch_size):
    """timg是data label是标签"""
    """
    tmp_aug_data 用于保存生成的增强样本数据,其形状为 (batch_size / 4, 1, 22, 1000),
    即每个增强样本包含8个时间片段,每个时间片段的形状为 (1, 22, 125)
    
    
    rand_idx 是随机选择的8个时间片段的索引,用于从原始数据中获取时间片段。
    aug_data 和 aug_label 分别保存所有类别的增强样本和对应的标签。
    aug_shuffle 对增强样本和标签进行随机打乱。

    """
    aug_data = []
    aug_label = []
    for cls4aug in range(4):
        # 条件判断 找出对应的label和data
        cls_idx = np.where(label == cls4aug+1)  # label == cls4aug + 1
        tmp_data = timg[cls_idx]
        tmp_label = label[cls_idx]
        # 分epoch
        tmp_aug_data = np.zeros((int(batch_size / 4), 1, 22, 1000))
        for ri in range(int(batch_size / 4)):
            # 随机取8个时间片段
            for rj in range(8):
                rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
                # 进行数据的打乱重构
                tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125]

        aug_data.append(tmp_aug_data)
        aug_label.append(tmp_label[:int(batch_size / 4)])
    aug_data = np.concatenate(aug_data)
    aug_label = np.concatenate(aug_label)
    aug_shuffle = np.random.permutation(len(aug_data))
    aug_data = aug_data[aug_shuffle, :, :]
    aug_label = aug_label[aug_shuffle]

    aug_data = torch.from_numpy(aug_data).cuda()
    aug_data = aug_data.float()
    aug_label = torch.from_numpy(aug_label).cuda()  # aug_label - 1
    aug_label = aug_label.long()
    return aug_data, aug_label



# 切分部分数据进行test
def split_EEGdata(data, label):
    data = data.view(data.shape[0], 1, 22, 1000)
    data = data[:100]
    label = label[:100]
    torch.save(data, 'EEG_data_split.pt')
    torch.save(label, 'EEG_label_split.pt')


#  没有进行滤波处理原数据获取
def transform_save_data_version2(filename,save_filename=None):
    """
    The processing gain the raw data from the EEG
    :param filename: data filename
    :param save_filename: save data filename
    :return: save a np style file
    """
    raw = mne.io.read_raw_gdf(filename)
    print(raw.info['ch_names'])
    events,event_id = mne.events_from_annotations(raw)
    raw.info['bads'] += ['EOG-left','EOG-central','EOG-right']
    # 运动想象时间2-6秒
    tmin,tmax = 2,6
    event_id = {'768': 6}
    # 需要重新加载raw对象进行滤波处理
    raw.load_data()
    picks = mne.pick_types(raw.info,meg=False,eeg=True,stim=False,exclude='bads')
    epochs = mne.Epochs(raw=raw,events=events,event_id=event_id,tmin=tmin,tmax=tmax,preload=True,baseline=None,picks=picks)
    epoch_data = epochs.get_data(copy=True)
    # 将最后一位数据进行去除
    epoch_data = epoch_data[:,:,:-1]


    epoch_data = epoch_data.reshape(epoch_data.shape[0], 1, 22, 1000)
    if save_filename is not None:
        np.save(save_filename,epoch_data)

    return epoch_data












if __name__ == '__main__':
    count = input('please input your subject ID:')
    filename = 'C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\A0' + count + 'T.gdf'
    BCI_data = transform_save_data_version2(filename)
    label_filename = 'C:\\Users\\24242\\Desktop\\AI_Reference\\data_bag\\BCICIV_2a_gdf\\A0' + count + 'T.mat'
    acc_train, y_oh = data_processing(BCI_data, label_filename)
    data, label = data_transform_tensor(acc_train, y_oh)

    # data = data.reshape(data.shape[0], 22, 1000)
    data = np.array(data)
    label = np.array(label)
    # print(label)
    data1, label1 = interaug(data, label, batch_size=288)

    data = torch.from_numpy(data)
    label = torch.from_numpy(label)
    print(data.type())
    print(label.type())
    # data1 = torch.from_numpy(data1)
    # label1 = torch.from_numpy(label1)
    data_list = [data.to('cuda'), data1.to('cuda')]
    label_list = [label.to('cuda'), label1.to('cuda')]

    data_filename = '../EEG-dataprocessing/2a/paper_data_label/A0'+ count + '_combine/A0'+ count + '_combine_data.pt'
    label_filename = '../EEG-dataprocessing/2a/paper_data_label/A0'+ count + '_combine/A0'+ count +'_combine_label.pt'

    data_combine, label_combine = combine_data(data_list, label_list, data_filename, label_filename)


    data_combine = data_combine.detach().cpu().numpy()
    label_combine = label_combine.detach().cpu().numpy()
    print(data_combine.shape)
    print(label_combine.shape)
    train_data, test_data, train_label, test_label = train_test_split(data_combine, label_combine, test_size=0.2, train_size=0.8, shuffle=True)
    train_data = torch.from_numpy(train_data).float()
    test_data = torch.from_numpy(test_data).float()
    train_label = torch.from_numpy(train_label).long()
    test_label = torch.from_numpy(test_label).long()
    torch.save(train_data, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/train_data_A0' + count + '.pt')
    torch.save(test_data, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/test_data_A0' + count + '.pt')
    torch.save(train_label, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/train_label_A0' + count + '.pt')
    torch.save(test_label, '../EEG-dataprocessing/2a/paper_data_label/A0' + count + '_combine/test_label_A0' + count + '.pt')
    print(train_data.shape)
    print(test_data.shape)
    print(train_label.shape)
    print(test_label.shape)

XCZchaos avatar May 18 '24 11:05 XCZchaos