PyTorch-BYOL
Training on CIFAR10
Hello,
Thank you for this excellent repository!
Do you have any suggestions for changes needed to train BYOL on the CIFAR10 dataset?
Here is how I am doing it in main.py (I am also training my own custom models, but I don't think that is relevant here):
from torchvision import datasets

# data_transform and MultiViewDataInjector come from the existing main.py setup
DATASET = 'CIFAR10'  # Can change to 'STL10'

if DATASET == 'STL10':
    train_dataset = datasets.STL10('/workspace/STLDataset', split='train+unlabeled', download=True,
                                   transform=MultiViewDataInjector([data_transform, data_transform]))
elif DATASET == 'CIFAR10':
    train_dataset = datasets.CIFAR10('/workspace/CIFAR10Dataset', train=True, download=True,
                                     transform=MultiViewDataInjector([data_transform, data_transform]))
else:
    # Fail loudly with a non-zero exit instead of exit(0), which signals success
    raise ValueError("Dataset not supported; choose CIFAR10 or STL10")
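MultiViewDataInjector is what turns each training image into the two augmented views that BYOL compares. A minimal pure-Python sketch of that idea, assuming it simply applies each transform in its list to the same input sample. `MultiViewSketch` and `noisy` are hypothetical names for illustration, not the repository's classes:

```python
import random
from typing import Any, Callable, List, Sequence


class MultiViewSketch:
    """Hypothetical stand-in for MultiViewDataInjector: one sample in, one view per transform out."""

    def __init__(self, transforms: Sequence[Callable[[Any], Any]]):
        self.transforms = list(transforms)

    def __call__(self, sample: Any) -> List[Any]:
        # Each transform is applied independently to the same original sample,
        # so passing the same random augmentation twice yields two different views.
        return [t(sample) for t in self.transforms]


_rng = random.Random(0)


def noisy(pixels: List[float]) -> List[float]:
    # Toy random "augmentation": add small uniform noise to every pixel value.
    return [p + _rng.uniform(-0.1, 0.1) for p in pixels]


injector = MultiViewSketch([noisy, noisy])
view1, view2 = injector([0.5, 0.5, 0.5])
print(len(view1), len(view2))  # prints "3 3"
```

Because the two views come from independent random draws, the network never sees the same pixels twice for one image, which is what makes the BYOL prediction task non-trivial.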
I also changed the config to set input_shape: (32,32,3).
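In a YAML config, (32,32,3) is not a tuple: YAML reads it as the plain string "(32,32,3)", so the training code has to parse it. A minimal sketch of doing that safely with ast.literal_eval (rather than eval), assuming the first dimension is used as the crop size of the augmentations. `DATASET_INPUT_SHAPES`, `parse_input_shape`, and `crop_size_for` are hypothetical names; the shapes are the native image sizes of STL10 (96x96) and CIFAR10 (32x32):

```python
import ast
from typing import Tuple

# Hypothetical per-dataset image shapes (height, width, channels), stored as YAML would give them.
DATASET_INPUT_SHAPES = {
    "STL10": "(96,96,3)",
    "CIFAR10": "(32,32,3)",
}


def parse_input_shape(value: str) -> Tuple[int, int, int]:
    """Parse a string such as '(32,32,3)' into a tuple of ints without calling eval()."""
    shape = ast.literal_eval(value)
    if not (isinstance(shape, tuple) and len(shape) == 3
            and all(isinstance(d, int) for d in shape)):
        raise ValueError(f"expected (H, W, C), got {value!r}")
    return shape


def crop_size_for(dataset: str) -> int:
    # Assumes square images; the crop size should match the dataset's native resolution.
    height, width, _ = parse_input_shape(DATASET_INPUT_SHAPES[dataset])
    if height != width:
        raise ValueError("non-square input_shape is not handled in this sketch")
    return height


print(crop_size_for("CIFAR10"))  # prints 32
```

Keeping the crop size at 32 for CIFAR10 avoids upsampling the small images to the 96-pixel crops used for STL10.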
Also, I may not have looked very deeply into this codebase, but how do we obtain the 'STL10 Top 1' accuracy (75.2%) after training the model on the self-supervised task? Do we take the trained model and fine-tune it on the labeled STL10 data? I assume that code is not included in this repository?
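One common way to get a top-1 number from a self-supervised encoder is the linear evaluation protocol used in the BYOL paper: freeze the encoder, extract features for the labeled train and test splits, train a linear (softmax) classifier on the train features, and report top-1 accuracy on the test features. Whether the README's 75.2% was produced with exactly this procedure is an assumption here. A minimal pure-Python sketch of the classifier and the metric, with toy vectors standing in for encoder features (`train_linear_probe`, `top1_accuracy`, and `logits_for` are hypothetical names):

```python
import math
import random
from typing import List, Sequence, Tuple

# Hypothetical helper names; toy vectors below stand in for frozen encoder features.
Matrix = List[List[float]]


def logits_for(W: Matrix, b: List[float], x: Sequence[float]) -> List[float]:
    # One score per class: W[c] . x + b[c]
    return [sum(w * xi for w, xi in zip(row, x)) + bc for row, bc in zip(W, b)]


def train_linear_probe(features: Sequence[Sequence[float]], labels: Sequence[int],
                       num_classes: int, epochs: int = 200, lr: float = 0.1,
                       seed: int = 0) -> Tuple[Matrix, List[float]]:
    """Softmax regression trained with plain SGD; the features themselves are never updated."""
    rng = random.Random(seed)
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    order = list(range(len(features)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            x, y = features[i], labels[i]
            z = logits_for(W, b, x)
            m = max(z)  # subtract the max for numerical stability
            exps = [math.exp(v - m) for v in z]
            total = sum(exps)
            for c in range(num_classes):
                # Gradient of cross-entropy w.r.t. logit c: softmax_c - one_hot_c
                g = exps[c] / total - (1.0 if c == y else 0.0)
                b[c] -= lr * g
                for j in range(dim):
                    W[c][j] -= lr * g * x[j]
    return W, b


def top1_accuracy(W: Matrix, b: List[float], features: Sequence[Sequence[float]],
                  labels: Sequence[int]) -> float:
    # Fraction of samples whose highest-scoring class equals the true label.
    correct = 0
    for x, y in zip(features, labels):
        z = logits_for(W, b, x)
        if max(range(len(z)), key=z.__getitem__) == y:
            correct += 1
    return correct / len(labels)


# Toy "frozen features": two well-separated clusters.
train_x = [[1.0, 0.0], [0.9, 0.2], [0.0, 1.0], [0.2, 0.9]]
train_y = [0, 0, 1, 1]
W, b = train_linear_probe(train_x, train_y, num_classes=2)
print(top1_accuracy(W, b, [[0.8, 0.1], [0.1, 0.8]], [0, 1]))  # prints 1.0
```

Only the linear layer is trained, so the resulting accuracy measures how linearly separable the frozen self-supervised features are, rather than how well the whole network can be fine-tuned.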
Thank you!
Hi Akhauriyash, you can just modify the input shape and the dataset name. I am testing the model, but it does not work well on CIFAR10: I get only about 54% top-1 accuracy. Should the config, in particular the learning rate, be the same as for STL10 or different? Thank you!