
Training on a small dataset with pre-trained weights doesn't give good results.

Open JamesQFreeman opened this issue 4 years ago • 12 comments

import timm
import torch.nn as nn

pretrained_v = timm.create_model('vit_base_patch16_224', pretrained=True)
pretrained_v.head = nn.Linear(768, 2)  # 2-class head (768 = ViT-B embed dim)

I tried the Kaggle Cats vs. Dogs dataset for binary classification. It didn't work: the output is all cat or all dog.

Any idea how to make it work on a small dataset (fewer than 10,000 or even fewer than 1,000 images)?

PS: Adam optimizer, lr = 1e-2.
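
For reference, here is a minimal sketch of the training setup described above (the actual train.py was never shared, so the loop and the `loader` are assumptions):

import timm
import torch
import torch.nn as nn

# The setup from this issue: pre-trained ViT-B/16 with a 2-class head,
# trained with Adam at lr = 1e-2.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(768, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

def train_one_epoch(model, loader, device='cuda'):
    # `loader` is a hypothetical DataLoader yielding (image, label) batches
    # of 224x224 crops from the Cats vs. Dogs dataset.
    model.to(device).train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()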

JamesQFreeman avatar Nov 26 '20 05:11 JamesQFreeman

update: "Didn't work, output is all cat or all dog." was trained on only 1k images. Now I train the ViT on whole dataset which have 20k images and it kind of works.

0.73 accuracy at 10 epochs, 45 minutes on a Titan RTX (an identical run took 100 minutes on a Titan X). Not very impressive compared with a CNN so far.

JamesQFreeman avatar Nov 26 '20 07:11 JamesQFreeman

from vit_pytorch import ViT

v = ViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 2,   # cats vs. dogs
    dim = 512,
    depth = 4,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)
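
Following the usage in this repo's README, the model above can be sanity-checked on a dummy batch (a quick sketch; `v` is the model defined just above):

import torch

img = torch.randn(1, 3, 224, 224)  # dummy batch: one 224x224 RGB image
preds = v(img)                      # logits of shape (1, 2) for cat vs. dog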

BTW, using the non-pre-trained model above, I got around 0.8 accuracy in the same amount of time.

JamesQFreeman avatar Nov 26 '20 07:11 JamesQFreeman

The pre-trained model had a peak accuracy of 0.796 after 100 epochs of training. On this dataset, resnet50 can reach 0.90 without any modification. Is there any tuning trick I can use?

JamesQFreeman avatar Nov 27 '20 05:11 JamesQFreeman

Hi James, attention excels in the big-data regime, as shown in the paper. However, I am curious why fine-tuning did not work. Are you using Ross' model? Perhaps submit an issue at his repository?

lucidrains avatar Nov 27 '20 21:11 lucidrains

Yes, Ross' model (the one uploaded to timm) is used. Does a pretrained model always work on a small dataset?

JamesQFreeman avatar Nov 29 '20 15:11 JamesQFreeman

@JamesQFreeman I think fine-tuning from a pretrained model should generally work well. Maybe you should raise the issue with him.

lucidrains avatar Nov 29 '20 18:11 lucidrains

@JamesQFreeman ohh... well, I think I spot the error, your learning rate is way too high 1e-2, try Karpathy's favorite LR, 3e-4
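
In code, that is a one-line change to the optimizer from the first comment (a sketch; `pretrained_v` as defined there):

import torch

# Lower the fine-tuning learning rate from 1e-2 to 3e-4.
optimizer = torch.optim.Adam(pretrained_v.parameters(), lr=3e-4)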

lucidrains avatar Nov 29 '20 18:11 lucidrains

> @JamesQFreeman ohh... well, I think I spot the error, your learning rate is way too high 1e-2, try Karpathy's favorite LR, 3e-4

Thanks! I'll give it a try.

JamesQFreeman avatar Nov 30 '20 06:11 JamesQFreeman

I also tried the experiment, with lr = 3e-5 and batch_size = 8:

Epoch : 1 - loss : 0.0648 - acc: 0.9752 - val_loss : 0.0592 - val_acc: 0.9782
Epoch : 2 - loss : 0.0561 - acc: 0.9773 - val_loss : 0.0531 - val_acc: 0.9790
Epoch : 3 - loss : 0.0513 - acc: 0.9795 - val_loss : 0.0677 - val_acc: 0.9750
Epoch : 4 - loss : 0.0473 - acc: 0.9809 - val_loss : 0.0479 - val_acc: 0.9804
Epoch : 5 - loss : 0.0473 - acc: 0.9800 - val_loss : 0.0567 - val_acc: 0.9780
Epoch : 6 - loss : 0.0466 - acc: 0.9806 - val_loss : 0.0526 - val_acc: 0.9780
Epoch : 7 - loss : 0.0413 - acc: 0.9826 - val_loss : 0.0615 - val_acc: 0.9774
Epoch : 8 - loss : 0.0430 - acc: 0.9833 - val_loss : 0.0619 - val_acc: 0.9746
Epoch : 9 - loss : 0.0411 - acc: 0.9832 - val_loss : 0.0616 - val_acc: 0.9784
Epoch : 10 - loss : 0.0450 - acc: 0.9824 - val_loss : 0.0483 - val_acc: 0.9830
Epoch : 11 - loss : 0.0374 - acc: 0.9842 - val_loss : 0.0598 - val_acc: 0.9746
Epoch : 12 - loss : 0.0393 - acc: 0.9844 - val_loss : 0.1202 - val_acc: 0.9602
Epoch : 13 - loss : 0.0418 - acc: 0.9830 - val_loss : 0.0547 - val_acc: 0.9806
Epoch : 14 - loss : 0.0380 - acc: 0.9846 - val_loss : 0.0578 - val_acc: 0.9760
Epoch : 15 - loss : 0.0376 - acc: 0.9852 - val_loss : 0.0557 - val_acc: 0.9786
Epoch : 16 - loss : 0.0372 - acc: 0.9845 - val_loss : 0.0595 - val_acc: 0.9790
Epoch : 17 - loss : 0.0379 - acc: 0.9846 - val_loss : 0.0560 - val_acc: 0.9802
Epoch : 18 - loss : 0.0353 - acc: 0.9859 - val_loss : 0.0561 - val_acc: 0.9818
Epoch : 19 - loss : 0.0361 - acc: 0.9860 - val_loss : 0.0482 - val_acc: 0.9810
Epoch : 20 - loss : 0.0349 - acc: 0.9864 - val_loss : 0.0547 - val_acc: 0.9792

Emmmmm, not bad. I think it will be better if I can tune the hyperparameters.

Lin-Zhipeng avatar Dec 03 '20 03:12 Lin-Zhipeng

A lower learning rate and SGD are better for fine-tuning; don't use Adam.
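
A sketch of that suggestion (lr = 1e-3 and momentum = 0.9 are assumed values, common defaults for fine-tuning; the comment itself names no numbers):

import torch

# SGD with a low learning rate and momentum instead of Adam for fine-tuning.
optimizer = torch.optim.SGD(pretrained_v.parameters(), lr=1e-3, momentum=0.9)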

XA-kirino avatar Jan 14 '21 02:01 XA-kirino

> I also tried the experiment, with lr = 3e-5 and batch_size = 8:
> [20-epoch training log quoted above]
> Emmmmm, not bad. I think it will be better if I can tune the hyperparameters.

Hi, you really got good accuracy results. I wonder how you did that? I mean, which dataset and pre-trained weights did you use?

myt889 avatar Apr 20 '21 09:04 myt889

@JamesQFreeman @myt889 @Lin-Zhipeng can you share your train.py for loading from the pretrained model? It would be very helpful. Did you try loading the variants of the models?

abhigoku10 avatar Aug 24 '21 02:08 abhigoku10