vit-pytorch
Loss doesn't drop when training on ImageNet
Hi, thanks a lot for sharing the code! I found that the loss always plateaus at around 7 when I train it on ImageNet on a single RTX 3090. Have you successfully trained vit-pytorch on ImageNet?
These are my hyperparameters:
batch_size = 192
image_size = 256
patch_size = 16
num_layers = 8
num_head = 8
mlp_dim = 512
dim_model = 512
num_class = 1000
channel = 3
dropout = 0.4
learning_rate = 3e-4
beta1 = 0.9
beta2 = 0.999
weight_decay = 0.01
epoches = 20
num_workers = 4
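The hyperparameters above fix a few derived quantities that are worth checking against the logs. A small sketch in plain Python, assuming the standard ImageNet-1k (ILSVRC-2012) training split of 1,281,167 images:

```python
import math

# Values copied from the hyperparameter list above.
batch_size = 192
image_size = 256
patch_size = 16
channel = 3

# Assumption: the standard ImageNet-1k (ILSVRC-2012) training split.
num_train_images = 1_281_167

assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
num_patches = (image_size // patch_size) ** 2              # tokens per image, excluding the class token
patch_dim = channel * patch_size ** 2                       # length of one flattened patch
iters_per_epoch = math.ceil(num_train_images / batch_size)  # keeps the final partial batch

print(num_patches, patch_dim, iters_per_epoch)
```

With batch_size = 192 this gives 6673 iterations per epoch if the last partial batch is kept, or 6672 if it is dropped, which would be consistent with the two iteration counts that appear in the logs; whether the training scripts set drop_last is an assumption.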
Here are some logs from training:
2020-12-01 20:12:48,410 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6519/6672 - Iter loss : 27.8335 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:50,900 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6520/6672 - Iter loss : 33.7667 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:52,467 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6521/6672 - Iter loss : 16.5766 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:53,450 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6522/6672 - Iter loss : 9.5950 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:54,919 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6523/6672 - Iter loss : 12.1596 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:57,193 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6524/6672 - Iter loss : 8.5739 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:58,304 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6525/6672 - Iter loss : 8.2490 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:12:59,285 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6526/6672 - Iter loss : 6.9780 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:05,256 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6527/6672 - Iter loss : 6.9443 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:07,644 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6528/6672 - Iter loss : 7.0152 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:09,009 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6529/6672 - Iter loss : 7.0805 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:09,993 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6530/6672 - Iter loss : 7.1097 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:13,525 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6531/6672 - Iter loss : 7.1950 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:20,461 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6532/6672 - Iter loss : 10.4532 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:23,267 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6533/6672 - Iter loss : 11.6160 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:24,247 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6534/6672 - Iter loss : 14.5333 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:25,253 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6535/6672 - Iter loss : 49.1107 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:27,575 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6536/6672 - Iter loss : 36.7893 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:29,475 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6537/6672 - Iter loss : 22.1718 - Iter acc: 0.0104 - Num correct: 2
2020-12-01 20:13:30,457 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6538/6672 - Iter loss : 7.3235 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:31,441 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6539/6672 - Iter loss : 7.0744 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:38,798 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6540/6672 - Iter loss : 8.2893 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:45,064 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6541/6672 - Iter loss : 7.2673 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:46,039 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6542/6672 - Iter loss : 8.7027 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:49,190 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6543/6672 - Iter loss : 7.0044 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:59,167 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6544/6672 - Iter loss : 6.9146 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:05,451 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6545/6672 - Iter loss : 7.1504 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:06,431 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6546/6672 - Iter loss : 6.9319 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:07,437 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6547/6672 - Iter loss : 6.9291 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:09,486 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6548/6672 - Iter loss : 6.8974 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:13,987 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6549/6672 - Iter loss : 6.9187 - Iter acc: 0.0104 - Num correct: 2
2020-12-01 20:14:14,967 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6550/6672 - Iter loss : 7.0641 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:15,952 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6551/6672 - Iter loss : 7.0853 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:18,097 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6552/6672 - Iter loss : 7.2872 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:14:23,045 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6553/6672 - Iter loss : 7.0384 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:24,026 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6554/6672 - Iter loss : 6.9754 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:25,033 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6555/6672 - Iter loss : 7.0169 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:14:26,194 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6556/6672 - Iter loss : 6.9243 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:29,584 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6557/6672 - Iter loss : 6.9077 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:30,563 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6558/6672 - Iter loss : 6.9486 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:31,734 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6559/6672 - Iter loss : 6.9292 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:14:41,763 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6560/6672 - Iter loss : 6.9972 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:47,781 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6561/6672 - Iter loss : 7.0302 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:48,764 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6562/6672 - Iter loss : 6.9987 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:50,913 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6563/6672 - Iter loss : 6.9165 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:57,334 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6564/6672 - Iter loss : 6.9228 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:15:03,513 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6565/6672 - Iter loss : 7.0496 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:04,494 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6566/6672 - Iter loss : 6.9134 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:05,499 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6567/6672 - Iter loss : 7.1323 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:06,459 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6568/6672 - Iter loss : 6.9309 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:10,015 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6569/6672 - Iter loss : 6.9452 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:10,995 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6570/6672 - Iter loss : 6.9216 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:11,978 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6571/6672 - Iter loss : 7.0044 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:12,938 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6572/6672 - Iter loss : 6.9816 - Iter acc: 0.0000 - Num correct: 0
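A loss that settles around 6.9 is not arbitrary for 1000 classes: it equals ln(1000) ≈ 6.908, the cross-entropy of a model that gives every class the same probability, i.e. one that has learned nothing from the image. The sketch below checks this in plain Python; `cross_entropy` is a hypothetical helper written for illustration, not part of vit-pytorch.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one sample, -log softmax(logits)[target], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

num_class = 1000
uniform_logits = [0.0] * num_class  # a model that scores every class equally, whatever the input
loss = cross_entropy(uniform_logits, target=3)
print(round(loss, 4), round(math.log(num_class), 4))
```

The isolated spikes to 27 through 49 in the log are far above this chance-level value, meaning the model was confidently wrong on those batches. That pattern often points to unstable optimization (for example a learning rate that is too high early in training) rather than to a model that is merely too large, though that reading is a hedge and not verified here.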
I tried the default ViT configuration on my own dataset, and the loss didn't drop either. Then I used a smaller 'depth' and 'dim', and it worked. Maybe you could try modifying the model architecture.
After I decreased dim from 1024 to 256, and num_layers and num_head from 8 to 2, the number of parameters dropped to 679,016, which is about 100 times smaller than the original network setting. But it still does not work. Is that still too many parameters, or could you share your parameter settings? Thanks!
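To see where the parameters go, here is a rough counter for a pre-norm ViT. The layout it assumes (bias-free QKV projection, dim_head = 64 per head, a LayerNorm before attention, before the MLP and before the classifier) is a hypothetical approximation and may not match any particular vit-pytorch version, so its totals will not necessarily reproduce the 679,016 figure above.

```python
def count_vit_params(image_size, patch_size, channels, dim, depth, heads,
                     mlp_dim, num_classes, dim_head=64):
    """Rough parameter count for a pre-norm ViT (hypothetical layout, see lead-in)."""
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    inner = heads * dim_head  # total width of the attention projections

    # Linear patch projection, learned position embedding (patches + class token), class token.
    embed = patch_dim * dim + dim + (num_patches + 1) * dim + dim
    # Per block: LayerNorm, bias-free QKV projection, output projection with bias.
    attn = 2 * dim + dim * 3 * inner + inner * dim + dim
    # Per block: LayerNorm, then two linear layers with biases.
    mlp = 2 * dim + dim * mlp_dim + mlp_dim + mlp_dim * dim + dim
    # Final LayerNorm and linear classifier.
    head = 2 * dim + dim * num_classes + num_classes
    return embed + depth * (attn + mlp) + head

# The configuration from the first post.
print(count_vit_params(image_size=256, patch_size=16, channels=3, dim=512,
                       depth=8, heads=8, mlp_dim=512, num_classes=1000))
```

One thing the formula makes visible: the patch projection (patch_dim × dim) and the classifier (dim × num_classes) do not shrink with depth, so once depth is small they make up a large share of the total.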
Here are my logs of the 5th epoch:
2020-12-03 06:39:21,075 - meter.py[line:34] - INFO: Epoch: [3][ 0/6673] Time 2.843 ( 2.843) Data 2.753 ( 2.753) Loss 6.9075e+00 (6.9075e+00) Acc@1 0.00 ( 0.00) Acc@5 0.00 ( 0.00)
2020-12-03 06:44:17,016 - meter.py[line:34] - INFO: Epoch: [3][ 500/6673] Time 1.920 ( 0.596) Data 1.860 ( 0.531) Loss 6.9071e+00 (6.9077e+00) Acc@1 0.00 ( 0.11) Acc@5 0.00 ( 0.54)
2020-12-03 06:49:20,588 - meter.py[line:34] - INFO: Epoch: [3][1000/6673] Time 0.082 ( 0.602) Data 0.022 ( 0.537) Loss 6.9074e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.52 ( 0.53)
2020-12-03 06:54:45,241 - meter.py[line:34] - INFO: Epoch: [3][1500/6673] Time 0.658 ( 0.618) Data 0.600 ( 0.552) Loss 6.9077e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 1.04 ( 0.52)
2020-12-03 07:00:25,042 - meter.py[line:34] - INFO: Epoch: [3][2000/6673] Time 0.107 ( 0.633) Data 0.019 ( 0.567) Loss 6.9081e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.00 ( 0.52)
2020-12-03 07:06:20,696 - meter.py[line:34] - INFO: Epoch: [3][2500/6673] Time 0.106 ( 0.649) Data 0.020 ( 0.580) Loss 6.9073e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.00 ( 0.51)
2020-12-03 07:12:51,716 - meter.py[line:34] - INFO: Epoch: [3][3000/6673] Time 0.110 ( 0.671) Data 0.021 ( 0.602) Loss 6.9077e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 1.56 ( 0.51)
2020-12-03 07:19:59,395 - meter.py[line:34] - INFO: Epoch: [3][3500/6673] Time 0.861 ( 0.697) Data 0.800 ( 0.628) Loss 6.9069e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 1.04 ( 0.50)
2020-12-03 07:27:54,941 - meter.py[line:34] - INFO: Epoch: [3][4000/6673] Time 0.111 ( 0.729) Data 0.019 ( 0.660) Loss 6.9075e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.52 ( 0.50)
2020-12-03 07:37:17,196 - meter.py[line:34] - INFO: Epoch: [3][4500/6673] Time 0.108 ( 0.773) Data 0.018 ( 0.704) Loss 6.9070e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 1.04 ( 0.50)
2020-12-03 07:48:40,004 - meter.py[line:34] - INFO: Epoch: [3][5000/6673] Time 0.117 ( 0.832) Data 0.018 ( 0.763) Loss 6.9078e+00 (6.9077e+00) Acc@1 0.52 ( 0.10) Acc@5 0.52 ( 0.50)
2020-12-03 08:02:10,024 - meter.py[line:34] - INFO: Epoch: [3][5500/6673] Time 0.110 ( 0.904) Data 0.019 ( 0.834) Loss 6.9072e+00 (6.9077e+00) Acc@1 0.52 ( 0.09) Acc@5 0.52 ( 0.50)
2020-12-03 08:16:39,075 - meter.py[line:34] - INFO: Epoch: [3][6000/6673] Time 5.605 ( 0.973) Data 5.545 ( 0.903) Loss 6.9067e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 1.04 ( 0.50)
2020-12-03 08:31:23,624 - meter.py[line:34] - INFO: Epoch: [3][6500/6673] Time 0.109 ( 1.035) Data 0.019 ( 0.963) Loss 6.9074e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 1.04 ( 0.50)
2020-12-03 08:36:24,116 - vit_train.py[line:193] - INFO: ++++++++++++++++++++++++Training for one epoch +++++++++++++++++++++++++++++++++++
2020-12-03 08:36:27,453 - meter.py[line:34] - INFO: Test: [ 0/261] Time 3.336 ( 3.336) Loss 6.9085e+00 (6.9085e+00) Acc@1 0.00 ( 0.00) Acc@5 0.00 ( 0.00)
2020-12-03 08:39:17,536 - vit_train.py[line:153] - INFO: * Acc@1 0.100 Acc@5 0.500
2020-12-03 08:39:17,537 - vit_train.py[line:207] - INFO: save checkpoint.................
2020-12-03 08:39:20,401 - meter.py[line:34] - INFO: Epoch: [4][ 0/6673] Time 2.782 ( 2.782) Data 2.714 ( 2.714) Loss
2020-12-03 08:44:16,053 - meter.py[line:34] - INFO: Epoch: [4][ 500/6673] Time 0.912 ( 0.596) Data 0.843 ( 0.517) Loss 6.9077e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 1.04 ( 0.50)
2020-12-03 08:49:24,166 - meter.py[line:34] - INFO: Epoch: [4][1000/6673] Time 2.126 ( 0.606) Data 2.041 ( 0.527) Loss 6.9080e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.00 ( 0.48)
2020-12-03 08:55:15,165 - meter.py[line:34] - INFO: Epoch: [4][1500/6673] Time 4.573 ( 0.638) Data 4.507 ( 0.564) Loss 6.9084e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.00 ( 0.48)
2020-12-03 09:01:45,592 - meter.py[line:34] - INFO: Epoch: [4][2000/6673] Time 0.659 ( 0.674) Data 0.601 ( 0.602) Loss 6.9079e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.00 ( 0.49)
2020-12-03 09:08:26,356 - meter.py[line:34] - INFO: Epoch: [4][2500/6673] Time 0.109 ( 0.699) Data 0.019 ( 0.627) Loss 6.9072e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.00 ( 0.49)
2020-12-03 09:15:38,532 - meter.py[line:34] - INFO: Epoch: [4][3000/6673] Time 0.310 ( 0.727) Data 0.225 ( 0.654) Loss 6.9079e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.00 ( 0.48)
2020-12-03 09:23:49,919 - meter.py[line:34] - INFO: Epoch: [4][3500/6673] Time 1.382 ( 0.763) Data 1.294 ( 0.690) Loss 6.9079e+00 (6.9077e+00) Acc@1 0.00 ( 0.09) Acc@5 0.00 ( 0.48)
2020-12-03 09:33:21,487 - meter.py[line:34] - INFO: Epoch: [4][4000/6673] Time 3.452 ( 0.811) Data 3.392 ( 0.737) Loss 6.9071e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.52 ( 0.48)
2020-12-03 09:44:09,826 - meter.py[line:34] - INFO: Epoch: [4][4500/6673] Time 0.089 ( 0.865) Data 0.022 ( 0.792) Loss 6.9078e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.52 ( 0.49)
2020-12-03 09:57:39,927 - meter.py[line:34] - INFO: Epoch: [4][5000/6673] Time 3.217 ( 0.940) Data 3.156 ( 0.868) Loss 6.9077e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.00 ( 0.49)
2020-12-03 10:13:39,403 - meter.py[line:34] - INFO: Epoch: [4][5500/6673] Time 0.102 ( 1.029) Data 0.019 ( 0.957) Loss 6.9079e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.00 ( 0.49)
2020-12-03 10:31:27,529 - meter.py[line:34] - INFO: Epoch: [4][6000/6673] Time 0.108 ( 1.121) Data 0.019 ( 1.049) Loss 6.9074e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.00 ( 0.49)
2020-12-03 10:49:44,418 - meter.py[line:34] - INFO: Epoch: [4][6500/6673] Time 0.085 ( 1.204) Data 0.019 ( 1.131) Loss 6.9085e+00 (6.9077e+00) Acc@1 0.00 ( 0.10) Acc@5 0.00 ( 0.49)
2020-12-03 10:56:05,657 - vit_train.py[line:193] - INFO: ++++++++++++++++++++++++Training for one epoch +++++++++++++++++++++++++++++++++++
2020-12-03 10:56:14,022 - meter.py[line:34] - INFO: Test: [ 0/261] Time 8.363 ( 8.363) Loss 6.9058e+00 (6.9058e+00) Acc@1 0.00 ( 0.00) Acc@5 0.00 ( 0.00)
2020-12-03 10:59:03,169 - vit_train.py[line:153] - INFO: * Acc@1 0.100 Acc@5 0.500
2020-12-03 10:59:03,170 - vit_train.py[line:207] - INFO: save checkpoint.................
2020-12-03 10:59:05,996 - meter.py[line:34] - INFO: Epoch: [5][ 0/6673] Time 2.762 ( 2.762) Data 2.681 ( 2.681) Loss 6.9081e+00 (6.9081e+00) Acc@1 0.00 ( 0.00) Acc@5 0.52 ( 0.52)
2020-12-03 11:05:38,925 - meter.py[line:34] - INFO: Epoch: [5][ 500/6673] Time 0.089 ( 0.790) Data 0.021 ( 0.713) Loss 6.9079e+00 (6.9077e+00) Acc@1 0.00 ( 0.11) Acc@5 0.52 ( 0.51)
2020-12-03 11:12:54,162 - meter.py[line:34] - INFO: Epoch: [5][1000/6673] Time 0.079 ( 0.830) Data 0.019 ( 0.754) Loss 6.9076e+00 (6.9077e+00) Acc@1 0.00 ( 0.12) Acc@5 0.52 ( 0.51)
2020-12-03 11:20:38,089 - meter.py[line:34] - INFO: Epoch: [5][1500/6673] Time 0.099 ( 0.863) Data 0.024 ( 0.787) Loss 6.9078e+00 (6.9077e+00) Acc@1 0.52 ( 0.12) Acc@5 0.52 ( 0.51)
2020-12-03 11:28:54,603 - meter.py[line:34] - INFO: Epoch: [5][2000/6673] Time 2.232 ( 0.895) Data 2.162 ( 0.819) Loss 6.9078e+00 (6.9077e+00) Acc@1 0.00 ( 0.12) Acc@5 1.04 ( 0.52)
2020-12-03 11:37:28,869 - meter.py[line:34] - INFO: Epoch: [5][2500/6673] Time 0.096 ( 0.922) Data 0.020 ( 0.846) Loss 6.9080e+00 (6.9077e+00) Acc@1 0.00 ( 0.11) Acc@5 0.00 ( 0.51)
2020-12-03 11:46:50,222 - meter.py[line:34] - INFO: Epoch: [5][3000/6673] Time 0.087 ( 0.955) Data 0.020 ( 0.879) Loss 6.9076e+00 (6.9077e+00) Acc@1 0.00 ( 0.11) Acc@5 0.00 ( 0.50)
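Taken together, the validation line "* Acc@1 0.100 Acc@5 0.500" and the flat running loss of 6.9077 are what blind guessing over 1000 balanced classes produces: top-k accuracy of k/1000 (0.1% and 0.5%) and a cross-entropy of ln(1000). A hypothetical helper like the one below (not part of vit-pytorch or the poster's scripts; the tolerances are arbitrary choices) could flag this state after the first validation pass instead of several epochs later.

```python
import math

def looks_like_chance(avg_loss, acc1_pct, acc5_pct, num_classes, rel_tol=0.01):
    """Return True if loss and accuracy all sit at the level of blind guessing.

    For C balanced classes, guessing gives cross-entropy ln(C) and an expected
    top-k accuracy of k / C (expressed in percent here, as in the logs).
    """
    chance_loss = math.log(num_classes)
    chance_acc1 = 100.0 * 1 / num_classes
    chance_acc5 = 100.0 * 5 / num_classes
    return (math.isclose(avg_loss, chance_loss, rel_tol=rel_tol)
            and math.isclose(acc1_pct, chance_acc1, rel_tol=0.5)
            and math.isclose(acc5_pct, chance_acc5, rel_tol=0.5))

# Numbers copied from the epoch 3 and 4 logs above.
print(looks_like_chance(6.9077, 0.100, 0.500, num_classes=1000))
```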
I'm also having trouble with this issue. Are there any tips on how to train ViT on ImageNet?
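Nothing in this thread confirms a fix, but one commonly used remedy for ViT runs that stall or spike early is a learning-rate warmup followed by decay; the ViT paper and DeiT both train with a warmup phase. The sketch below is a plain-Python linear-warmup plus cosine-decay schedule. The 1-epoch warmup length and the zero floor are illustrative assumptions, not settings known to work on ImageNet with vit-pytorch.

```python
import math

def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at `step`: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr past total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Illustrative values: learning_rate = 3e-4 and 20 epochs from the first post,
# 6673 iterations per epoch from the logs, and a hypothetical 1-epoch warmup.
iters_per_epoch = 6673
total = 20 * iters_per_epoch
for step in (0, iters_per_epoch - 1, iters_per_epoch, total // 2, total):
    print(step, warmup_cosine_lr(step, 3e-4, iters_per_epoch, total))
```

In PyTorch, one way to apply it is torch.optim.lr_scheduler.LambdaLR with lr_lambda returning warmup_cosine_lr(step, base_lr, warmup_steps, total_steps) / base_lr, calling scheduler.step() once per iteration; gradient clipping with torch.nn.utils.clip_grad_norm_ is often combined with it.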