torcheeg
Data split method in SEED dataset
Hello!
I am working with the SEED dataset, and I would like to split the data as the authors did in their paper: the first 9 clips for training and the remaining 6 for testing.
Is there any data splitting method that allows this?
Thanks in advance.
Best regards, Eduardo
Is there any support for the SEED dataset?
I am using this simple code and I get an error:
dataset = SEEDDataset(io_path=f'./SEED_preprocc/DE_feat',
                      root_path=path_to_db,
                      offline_transform=transforms.BandDifferentialEntropy(
                          band_dict={
                              "delta": [1, 4],
                              "theta": [4, 8],
                              "alpha": [8, 14],
                              "beta": [14, 31],
                              "gamma": [31, 49]
                          }),
                      online_transform=transforms.Compose([
                          transforms.ToTensor()
                      ]),
                      label_transform=transforms.Compose([
                          transforms.Select('emotion'),
                          transforms.Lambda(lambda x: x + 1)
                      ]),
                      num_worker=12)
train, test = train_test_split_per_subject_groupby_trial(dataset=dataset, subject=3)
The error obtained:
ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.
Any help is appreciated!
Happy New Year~ You can try the following code:
from torcheeg.datasets import SEEDDataset
from torcheeg.model_selection import KFoldPerSubjectGroupbyTrial
from torcheeg import transforms

dataset = SEEDDataset(
    io_path=f'./SEED_preprocc/DE_feat',
    root_path=path_to_db,
    offline_transform=transforms.BandDifferentialEntropy(
        band_dict={
            "delta": [1, 4],
            "theta": [4, 8],
            "alpha": [8, 14],
            "beta": [14, 31],
            "gamma": [31, 49]
        }),
    online_transform=transforms.Compose([transforms.ToTensor()]),
    label_transform=transforms.Compose(
        [transforms.Select('emotion'),
         transforms.Lambda(lambda x: x + 1)]),
    num_worker=12)

k_fold = KFoldPerSubjectGroupbyTrial()
for subject_id, (train_dataset,
                 test_dataset) in enumerate(k_fold.split(dataset)):
    if subject_id == 3:
        print(len(train_dataset))
        print(len(test_dataset))
        break
I've updated an internal version to meet your requirement. You can install the latest version of TorchEEG by running:
pip install git+https://github.com/torcheeg/torcheeg.git
Use train_test_split_per_subject_cross_trial to divide the dataset into training and testing sets based on video clips (trials), where some video clips are used for the training set and others for the testing set. On the other hand, train_test_split_per_subject_groupby_trial divides the dataset such that some periods within a video clip are used for the training set and other periods for the testing set.
from torcheeg.datasets import SEEDDataset
from torcheeg.model_selection import train_test_split_per_subject_cross_trial
from torcheeg import transforms

dataset = SEEDDataset(
    io_path=f'./SEED_preprocc/DE_feat',
    root_path=path_to_db,
    offline_transform=transforms.BandDifferentialEntropy(
        band_dict={
            "delta": [1, 4],
            "theta": [4, 8],
            "alpha": [8, 14],
            "beta": [14, 31],
            "gamma": [31, 49]
        }),
    online_transform=transforms.Compose([transforms.ToTensor()]),
    label_transform=transforms.Compose(
        [transforms.Select('emotion'),
         transforms.Lambda(lambda x: x + 1)]),
    num_worker=12)

train, test = train_test_split_per_subject_cross_trial(dataset=dataset,
                                                       subject=3)
print(len(train))
print(len(test))
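The difference between the two strategies described above can be illustrated on toy data without TorchEEG. The following is a minimal sketch, not the library's implementation: the field names `trial_id` and `segment`, both helper functions, and the choice of taking the earliest trials or segments for training are assumptions made purely for illustration.

```python
from collections import defaultdict

# Toy metadata for one subject: 5 trials (video clips), 4 segments each.
# The field names are hypothetical and only serve this illustration.
samples = [
    {"trial_id": trial, "segment": segment}
    for trial in range(5)
    for segment in range(4)
]


def split_cross_trial(samples, n_train_trials):
    """Whole trials go to either train or test; no trial is shared."""
    trial_ids = sorted({s["trial_id"] for s in samples})
    train_trials = set(trial_ids[:n_train_trials])
    train = [s for s in samples if s["trial_id"] in train_trials]
    test = [s for s in samples if s["trial_id"] not in train_trials]
    return train, test


def split_within_trial(samples, train_fraction):
    """Every trial contributes some segments to train and the rest to test."""
    by_trial = defaultdict(list)
    for s in samples:
        by_trial[s["trial_id"]].append(s)
    train, test = [], []
    for trial_samples in by_trial.values():
        ordered = sorted(trial_samples, key=lambda s: s["segment"])
        cut = int(len(ordered) * train_fraction)
        train.extend(ordered[:cut])
        test.extend(ordered[cut:])
    return train, test


def shared_trials(train, test):
    """Sorted list of trial ids that appear in both subsets."""
    train_ids = {s["trial_id"] for s in train}
    test_ids = {s["trial_id"] for s in test}
    return sorted(train_ids & test_ids)


cross_train, cross_test = split_cross_trial(samples, n_train_trials=3)
within_train, within_test = split_within_trial(samples, train_fraction=0.75)

print(shared_trials(cross_train, cross_test))    # no trial appears in both sets
print(shared_trials(within_train, within_test))  # every trial appears in both sets
```

In the cross-trial case the test clips are never seen during training, which is the setting the original question asks for. In the within-trial case every clip is represented in both subsets.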
Sorry, I used your code, but there is a problem:
train_trial_ids = np.array(trial_ids)[train_index_trial_ids].tolist()
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
I don't know how to fix it.
I checked the code and found that it runs well. I couldn't reproduce your problem. I removed the irrelevant functionality. Can you check what's different between your code and the code below?
# pip install git+https://github.com/torcheeg/torcheeg.git
from torcheeg.datasets import SEEDDataset
from torcheeg.model_selection import train_test_split_per_subject_cross_trial
from torcheeg import transforms

dataset = SEEDDataset(
    root_path='./tmp_in/Preprocessed_EEG',
    # unzip the Preprocessed_EEG folder from the downloaded dataset
    online_transform=transforms.Compose([transforms.ToTensor()]),
    label_transform=transforms.Compose(
        [transforms.Select('emotion'),
         transforms.Lambda(lambda x: x + 1)]),
    num_worker=12)

train, test = train_test_split_per_subject_cross_trial(dataset=dataset,
                                                       subject=3)
print(len(train))
print(len(test))
# 8187
# 1995
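The thread does not state which trials the library call assigns to each subset, so the exact "first 9 clips for training, remaining 6 for testing" protocol from the original question may need to be checked or built by hand. The following is a minimal sketch of that fixed split on hypothetical per-sample metadata. The key names `subject_id` and `trial_id`, the toy table, and the helper `first_n_trials_split` are assumptions for illustration and are not TorchEEG API.

```python
# Hypothetical per-sample metadata: 2 subjects x 15 trials x 3 segments.
# Key names are illustrative and not taken from any library.
metadata = [
    {"subject_id": subject, "trial_id": trial, "segment": segment}
    for subject in (1, 2)
    for trial in range(1, 16)
    for segment in range(3)
]


def first_n_trials_split(metadata, subject_id, n_train_trials=9):
    """Return (train_indices, test_indices) for one subject.

    Trials are ordered by id; the first `n_train_trials` go to training
    and all remaining trials go to testing.
    """
    subject_rows = [
        (index, row) for index, row in enumerate(metadata)
        if row["subject_id"] == subject_id
    ]
    # Fail early if the subject id matches nothing (for example, an
    # int passed where the table stores strings).
    if not subject_rows:
        raise ValueError(f"no samples found for subject {subject_id!r}")
    ordered_trials = sorted({row["trial_id"] for _, row in subject_rows})
    train_trials = set(ordered_trials[:n_train_trials])
    train_indices = [i for i, row in subject_rows
                     if row["trial_id"] in train_trials]
    test_indices = [i for i, row in subject_rows
                    if row["trial_id"] not in train_trials]
    return train_indices, test_indices


train_idx, test_idx = first_n_trials_split(metadata, subject_id=1)
print(len(train_idx), len(test_idx))  # 27 18
```

If the dataset object can be indexed in the same order as the metadata rows (an assumption to verify for the dataset in use), the two index lists could be passed to a generic wrapper such as `torch.utils.data.Subset`.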