
embeddings: add support for T5 encoder models


Hi,

this PR adds support for fine-tuning the encoder of T5 models. Supported architectures are T5, mT5 and LongT5. Unfortunately, ByT5 does not work yet, because it would require a change in our internal tokenization logic.
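For reference, a minimal sketch of how the new embeddings could be used from the Flair API once this PR is merged (assuming the usual TransformerWordEmbeddings entry point; exact argument names may differ):

from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# Load the T5 encoder as a (fine-tunable) word embedding.
# Only the encoder stack of the seq2seq model is used.
embeddings = TransformerWordEmbeddings(
    "google/t5-v1_1-base",
    layers="-1",      # use the last encoder layer
    fine_tune=True,   # unfreeze the encoder for fine-tuning
)

# Embed a sentence and inspect the per-token vectors.
sentence = Sentence("George Washington went to Washington .")
embeddings.embed(sentence)

for token in sentence:
    print(token.text, token.embedding.shape)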

I tested fine-tuning the encoder of t5-v1_1-base on CoNLL-2003 using the fine-tuning script in our examples/ner section:

python3 run_ner.py --dataset_name CONLL_03 \
  --model_name_or_path google/t5-v1_1-base \
  --batch_size 4 \
  --learning_rate 5e-05 \
  --num_epochs 10 \
  --output_dir conll-03-t5-base
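For readers who prefer the Python API over the script, the invocation above roughly corresponds to the following sketch (a hypothetical equivalent; the actual run_ner.py script may set additional options such as the scheduler or decoder configuration):

from flair.datasets import CONLL_03
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# 1. Load the CoNLL-2003 corpus and build the NER label dictionary.
corpus = CONLL_03()
label_dict = corpus.make_label_dictionary(label_type="ner")

# 2. T5 encoder embeddings, unfrozen for fine-tuning.
embeddings = TransformerWordEmbeddings("google/t5-v1_1-base", fine_tune=True)

# 3. Simple linear tag head on top of the encoder (no CRF, no RNN).
tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=label_dict,
    tag_type="ner",
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)

# 4. Fine-tune with the hyper-parameters from the command line above.
trainer = ModelTrainer(tagger, corpus)
trainer.fine_tune(
    "conll-03-t5-base",
    learning_rate=5e-05,
    mini_batch_size=4,
    max_epochs=10,
)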

These hyper-parameters already give promising results: the model reaches a micro F1-score of 91.38% on the test set:

Full training log:
2022-08-08 10:49:41,338 ----------------------------------------------------------------------------------------------------
2022-08-08 10:49:41,339 Corpus: "Corpus: 14987 train + 3466 dev + 3684 test sentences"
2022-08-08 10:49:41,339 ----------------------------------------------------------------------------------------------------
2022-08-08 10:49:41,339 Parameters:
2022-08-08 10:49:41,339  - learning_rate: "0.000050"
2022-08-08 10:49:41,339  - mini_batch_size: "4"
2022-08-08 10:49:41,339  - patience: "3"
2022-08-08 10:49:41,340  - anneal_factor: "0.5"
2022-08-08 10:49:41,340  - max_epochs: "10"
2022-08-08 10:49:41,340  - shuffle: "True"
2022-08-08 10:49:41,340  - train_with_dev: "False"
2022-08-08 10:49:41,340  - batch_growth_annealing: "False"
2022-08-08 10:49:41,340 ----------------------------------------------------------------------------------------------------
2022-08-08 10:49:41,340 Model training base path: "conll-03-t5-base"
2022-08-08 10:49:41,341 ----------------------------------------------------------------------------------------------------
2022-08-08 10:49:41,341 Device: cuda:0
2022-08-08 10:49:41,341 ----------------------------------------------------------------------------------------------------
2022-08-08 10:49:41,341 Embeddings storage mode: none
2022-08-08 10:49:41,341 ----------------------------------------------------------------------------------------------------
2022-08-08 10:51:22,246 epoch 1 - iter 374/3747 - loss 2.85495099 - samples/sec: 14.83 - lr: 0.000005
2022-08-08 10:53:02,046 epoch 1 - iter 748/3747 - loss 2.79143992 - samples/sec: 14.99 - lr: 0.000010
2022-08-08 10:54:43,478 epoch 1 - iter 1122/3747 - loss 2.60677163 - samples/sec: 14.75 - lr: 0.000015
2022-08-08 10:56:25,174 epoch 1 - iter 1496/3747 - loss 2.33643808 - samples/sec: 14.71 - lr: 0.000020
2022-08-08 10:58:05,318 epoch 1 - iter 1870/3747 - loss 2.03904417 - samples/sec: 14.94 - lr: 0.000025
2022-08-08 10:59:45,430 epoch 1 - iter 2244/3747 - loss 1.78307621 - samples/sec: 14.95 - lr: 0.000030
2022-08-08 11:01:25,588 epoch 1 - iter 2618/3747 - loss 1.58793961 - samples/sec: 14.94 - lr: 0.000035
2022-08-08 11:03:05,589 epoch 1 - iter 2992/3747 - loss 1.45233487 - samples/sec: 14.96 - lr: 0.000040
2022-08-08 11:04:46,092 epoch 1 - iter 3366/3747 - loss 1.32495761 - samples/sec: 14.89 - lr: 0.000045
2022-08-08 11:06:33,485 epoch 1 - iter 3740/3747 - loss 1.20748517 - samples/sec: 13.93 - lr: 0.000050
2022-08-08 11:06:35,536 ----------------------------------------------------------------------------------------------------
2022-08-08 11:06:35,536 EPOCH 1 done: loss 1.2063 - lr 0.000050
2022-08-08 11:06:59,710 Evaluating as a multi-label problem: False
2022-08-08 11:06:59,762 DEV : loss 0.18262654542922974 - f1-score (micro avg)  0.7874
2022-08-08 11:06:59,857 BAD EPOCHS (no improvement): 4
2022-08-08 11:06:59,859 ----------------------------------------------------------------------------------------------------
2022-08-08 11:08:43,559 epoch 2 - iter 374/3747 - loss 0.34818564 - samples/sec: 14.43 - lr: 0.000049
2022-08-08 11:10:23,639 epoch 2 - iter 748/3747 - loss 0.32684598 - samples/sec: 14.95 - lr: 0.000049
2022-08-08 11:12:03,897 epoch 2 - iter 1122/3747 - loss 0.31373943 - samples/sec: 14.93 - lr: 0.000048
2022-08-08 11:13:43,949 epoch 2 - iter 1496/3747 - loss 0.30196577 - samples/sec: 14.96 - lr: 0.000048
2022-08-08 11:15:24,533 epoch 2 - iter 1870/3747 - loss 0.29414107 - samples/sec: 14.88 - lr: 0.000047
2022-08-08 11:17:04,117 epoch 2 - iter 2244/3747 - loss 0.28818838 - samples/sec: 15.03 - lr: 0.000047
2022-08-08 11:18:43,763 epoch 2 - iter 2618/3747 - loss 0.28146902 - samples/sec: 15.02 - lr: 0.000046
2022-08-08 11:20:23,522 epoch 2 - iter 2992/3747 - loss 0.27523743 - samples/sec: 15.00 - lr: 0.000046
2022-08-08 11:22:04,016 epoch 2 - iter 3366/3747 - loss 0.26903608 - samples/sec: 14.89 - lr: 0.000045
2022-08-08 11:23:43,834 epoch 2 - iter 3740/3747 - loss 0.26496687 - samples/sec: 14.99 - lr: 0.000044
2022-08-08 11:23:45,633 ----------------------------------------------------------------------------------------------------
2022-08-08 11:23:45,633 EPOCH 2 done: loss 0.2648 - lr 0.000044
2022-08-08 11:24:13,577 Evaluating as a multi-label problem: False
2022-08-08 11:24:13,627 DEV : loss 0.07291336357593536 - f1-score (micro avg)  0.924
2022-08-08 11:24:13,693 BAD EPOCHS (no improvement): 4
2022-08-08 11:24:13,695 ----------------------------------------------------------------------------------------------------
2022-08-08 11:25:52,964 epoch 3 - iter 374/3747 - loss 0.20482545 - samples/sec: 15.08 - lr: 0.000044
2022-08-08 11:27:33,043 epoch 3 - iter 748/3747 - loss 0.20323350 - samples/sec: 14.95 - lr: 0.000043
2022-08-08 11:29:12,944 epoch 3 - iter 1122/3747 - loss 0.20492864 - samples/sec: 14.98 - lr: 0.000043
2022-08-08 11:30:52,841 epoch 3 - iter 1496/3747 - loss 0.20793765 - samples/sec: 14.98 - lr: 0.000042
2022-08-08 11:32:36,617 epoch 3 - iter 1870/3747 - loss 0.20995132 - samples/sec: 14.42 - lr: 0.000042
2022-08-08 11:34:25,574 epoch 3 - iter 2244/3747 - loss 0.20857272 - samples/sec: 13.73 - lr: 0.000041
2022-08-08 11:36:05,992 epoch 3 - iter 2618/3747 - loss 0.20741541 - samples/sec: 14.90 - lr: 0.000041
2022-08-08 11:37:45,762 epoch 3 - iter 2992/3747 - loss 0.20632236 - samples/sec: 15.00 - lr: 0.000040
2022-08-08 11:39:25,564 epoch 3 - iter 3366/3747 - loss 0.20443458 - samples/sec: 14.99 - lr: 0.000039
2022-08-08 11:41:05,117 epoch 3 - iter 3740/3747 - loss 0.20433500 - samples/sec: 15.03 - lr: 0.000039
2022-08-08 11:41:06,928 ----------------------------------------------------------------------------------------------------
2022-08-08 11:41:06,928 EPOCH 3 done: loss 0.2044 - lr 0.000039
2022-08-08 11:41:33,071 Evaluating as a multi-label problem: False
2022-08-08 11:41:33,121 DEV : loss 0.0718647688627243 - f1-score (micro avg)  0.9417
2022-08-08 11:41:33,187 BAD EPOCHS (no improvement): 4
2022-08-08 11:41:33,188 ----------------------------------------------------------------------------------------------------
2022-08-08 11:43:13,212 epoch 4 - iter 374/3747 - loss 0.19210315 - samples/sec: 14.96 - lr: 0.000038
2022-08-08 11:44:53,217 epoch 4 - iter 748/3747 - loss 0.19096504 - samples/sec: 14.96 - lr: 0.000038
2022-08-08 11:46:32,781 epoch 4 - iter 1122/3747 - loss 0.19001080 - samples/sec: 15.03 - lr: 0.000037
2022-08-08 11:48:12,899 epoch 4 - iter 1496/3747 - loss 0.19074383 - samples/sec: 14.95 - lr: 0.000037
2022-08-08 11:49:52,676 epoch 4 - iter 1870/3747 - loss 0.18948673 - samples/sec: 15.00 - lr: 0.000036
2022-08-08 11:51:32,621 epoch 4 - iter 2244/3747 - loss 0.18923435 - samples/sec: 14.97 - lr: 0.000036
2022-08-08 11:53:19,835 epoch 4 - iter 2618/3747 - loss 0.18684948 - samples/sec: 13.96 - lr: 0.000035
2022-08-08 11:55:01,867 epoch 4 - iter 2992/3747 - loss 0.18522100 - samples/sec: 14.67 - lr: 0.000034
2022-08-08 11:56:42,730 epoch 4 - iter 3366/3747 - loss 0.18430053 - samples/sec: 14.84 - lr: 0.000034
2022-08-08 11:58:23,167 epoch 4 - iter 3740/3747 - loss 0.18356806 - samples/sec: 14.90 - lr: 0.000033
2022-08-08 11:58:24,973 ----------------------------------------------------------------------------------------------------
2022-08-08 11:58:24,973 EPOCH 4 done: loss 0.1836 - lr 0.000033
2022-08-08 11:58:53,333 Evaluating as a multi-label problem: False
2022-08-08 11:58:53,382 DEV : loss 0.06763585656881332 - f1-score (micro avg)  0.951
2022-08-08 11:58:53,457 BAD EPOCHS (no improvement): 4
2022-08-08 11:58:53,459 ----------------------------------------------------------------------------------------------------
2022-08-08 12:00:33,381 epoch 5 - iter 374/3747 - loss 0.16898301 - samples/sec: 14.98 - lr: 0.000033
2022-08-08 12:02:13,327 epoch 5 - iter 748/3747 - loss 0.16596862 - samples/sec: 14.97 - lr: 0.000032
2022-08-08 12:03:53,915 epoch 5 - iter 1122/3747 - loss 0.16801329 - samples/sec: 14.88 - lr: 0.000032
2022-08-08 12:05:35,374 epoch 5 - iter 1496/3747 - loss 0.16710119 - samples/sec: 14.75 - lr: 0.000031
2022-08-08 12:07:15,115 epoch 5 - iter 1870/3747 - loss 0.16677485 - samples/sec: 15.00 - lr: 0.000031
2022-08-08 12:08:56,254 epoch 5 - iter 2244/3747 - loss 0.16775390 - samples/sec: 14.80 - lr: 0.000030
2022-08-08 12:10:41,461 epoch 5 - iter 2618/3747 - loss 0.16932652 - samples/sec: 14.22 - lr: 0.000029
2022-08-08 12:12:25,711 epoch 5 - iter 2992/3747 - loss 0.16749110 - samples/sec: 14.35 - lr: 0.000029
2022-08-08 12:14:15,495 epoch 5 - iter 3366/3747 - loss 0.16670496 - samples/sec: 13.63 - lr: 0.000028
2022-08-08 12:15:55,633 epoch 5 - iter 3740/3747 - loss 0.16615409 - samples/sec: 14.94 - lr: 0.000028
2022-08-08 12:15:57,411 ----------------------------------------------------------------------------------------------------
2022-08-08 12:15:57,411 EPOCH 5 done: loss 0.1661 - lr 0.000028
2022-08-08 12:16:23,482 Evaluating as a multi-label problem: False
2022-08-08 12:16:23,531 DEV : loss 0.06966154277324677 - f1-score (micro avg)  0.956
2022-08-08 12:16:23,596 BAD EPOCHS (no improvement): 4
2022-08-08 12:16:23,598 ----------------------------------------------------------------------------------------------------
2022-08-08 12:18:03,524 epoch 6 - iter 374/3747 - loss 0.15837342 - samples/sec: 14.98 - lr: 0.000027
2022-08-08 12:19:51,657 epoch 6 - iter 748/3747 - loss 0.16083516 - samples/sec: 13.84 - lr: 0.000027
2022-08-08 12:21:31,987 epoch 6 - iter 1122/3747 - loss 0.15579719 - samples/sec: 14.92 - lr: 0.000026
2022-08-08 12:23:18,939 epoch 6 - iter 1496/3747 - loss 0.15501790 - samples/sec: 13.99 - lr: 0.000026
2022-08-08 12:24:58,791 epoch 6 - iter 1870/3747 - loss 0.15459379 - samples/sec: 14.99 - lr: 0.000025
2022-08-08 12:26:39,016 epoch 6 - iter 2244/3747 - loss 0.15411446 - samples/sec: 14.93 - lr: 0.000024
2022-08-08 12:28:18,862 epoch 6 - iter 2618/3747 - loss 0.15394347 - samples/sec: 14.99 - lr: 0.000024
2022-08-08 12:29:58,172 epoch 6 - iter 2992/3747 - loss 0.15233145 - samples/sec: 15.07 - lr: 0.000023
2022-08-08 12:31:37,537 epoch 6 - iter 3366/3747 - loss 0.15244938 - samples/sec: 15.06 - lr: 0.000023
2022-08-08 12:33:17,700 epoch 6 - iter 3740/3747 - loss 0.15228318 - samples/sec: 14.94 - lr: 0.000022
2022-08-08 12:33:19,598 ----------------------------------------------------------------------------------------------------
2022-08-08 12:33:19,599 EPOCH 6 done: loss 0.1524 - lr 0.000022
2022-08-08 12:33:50,988 Evaluating as a multi-label problem: False
2022-08-08 12:33:51,038 DEV : loss 0.06917252391576767 - f1-score (micro avg)  0.9557
2022-08-08 12:33:51,112 BAD EPOCHS (no improvement): 4
2022-08-08 12:33:51,114 ----------------------------------------------------------------------------------------------------
2022-08-08 12:35:31,345 epoch 7 - iter 374/3747 - loss 0.14431815 - samples/sec: 14.93 - lr: 0.000022
2022-08-08 12:37:11,030 epoch 7 - iter 748/3747 - loss 0.14536920 - samples/sec: 15.01 - lr: 0.000021
2022-08-08 12:38:51,069 epoch 7 - iter 1122/3747 - loss 0.14373482 - samples/sec: 14.96 - lr: 0.000021
2022-08-08 12:40:31,963 epoch 7 - iter 1496/3747 - loss 0.14245254 - samples/sec: 14.83 - lr: 0.000020
2022-08-08 12:42:11,381 epoch 7 - iter 1870/3747 - loss 0.14282917 - samples/sec: 15.05 - lr: 0.000019
2022-08-08 12:43:50,963 epoch 7 - iter 2244/3747 - loss 0.14465897 - samples/sec: 15.03 - lr: 0.000019
2022-08-08 12:45:30,720 epoch 7 - iter 2618/3747 - loss 0.14360484 - samples/sec: 15.00 - lr: 0.000018
2022-08-08 12:47:11,012 epoch 7 - iter 2992/3747 - loss 0.14360124 - samples/sec: 14.92 - lr: 0.000018
2022-08-08 12:48:50,310 epoch 7 - iter 3366/3747 - loss 0.14497951 - samples/sec: 15.07 - lr: 0.000017
2022-08-08 12:50:30,027 epoch 7 - iter 3740/3747 - loss 0.14429212 - samples/sec: 15.01 - lr: 0.000017
2022-08-08 12:50:31,779 ----------------------------------------------------------------------------------------------------
2022-08-08 12:50:31,779 EPOCH 7 done: loss 0.1442 - lr 0.000017
2022-08-08 12:51:00,058 Evaluating as a multi-label problem: False
2022-08-08 12:51:00,108 DEV : loss 0.06839105486869812 - f1-score (micro avg)  0.9569
2022-08-08 12:51:00,166 BAD EPOCHS (no improvement): 4
2022-08-08 12:51:00,168 ----------------------------------------------------------------------------------------------------
2022-08-08 12:52:46,221 epoch 8 - iter 374/3747 - loss 0.13615586 - samples/sec: 14.11 - lr: 0.000016
2022-08-08 12:54:27,060 epoch 8 - iter 748/3747 - loss 0.13583238 - samples/sec: 14.84 - lr: 0.000016
2022-08-08 12:56:06,827 epoch 8 - iter 1122/3747 - loss 0.13283108 - samples/sec: 15.00 - lr: 0.000015
2022-08-08 12:57:54,121 epoch 8 - iter 1496/3747 - loss 0.13253029 - samples/sec: 13.95 - lr: 0.000014
2022-08-08 12:59:37,153 epoch 8 - iter 1870/3747 - loss 0.13498578 - samples/sec: 14.52 - lr: 0.000014
2022-08-08 13:01:16,551 epoch 8 - iter 2244/3747 - loss 0.13333883 - samples/sec: 15.06 - lr: 0.000013
2022-08-08 13:02:59,256 epoch 8 - iter 2618/3747 - loss 0.13356751 - samples/sec: 14.57 - lr: 0.000013
2022-08-08 13:04:38,873 epoch 8 - iter 2992/3747 - loss 0.13437344 - samples/sec: 15.02 - lr: 0.000012
2022-08-08 13:06:20,420 epoch 8 - iter 3366/3747 - loss 0.13522012 - samples/sec: 14.74 - lr: 0.000012
2022-08-08 13:08:00,037 epoch 8 - iter 3740/3747 - loss 0.13532826 - samples/sec: 15.02 - lr: 0.000011
2022-08-08 13:08:01,854 ----------------------------------------------------------------------------------------------------
2022-08-08 13:08:01,854 EPOCH 8 done: loss 0.1351 - lr 0.000011
2022-08-08 13:08:28,038 Evaluating as a multi-label problem: False
2022-08-08 13:08:28,087 DEV : loss 0.06876853853464127 - f1-score (micro avg)  0.9597
2022-08-08 13:08:28,170 BAD EPOCHS (no improvement): 4
2022-08-08 13:08:28,172 ----------------------------------------------------------------------------------------------------
2022-08-08 13:10:07,557 epoch 9 - iter 374/3747 - loss 0.12768286 - samples/sec: 15.06 - lr: 0.000011
2022-08-08 13:11:47,494 epoch 9 - iter 748/3747 - loss 0.12808245 - samples/sec: 14.97 - lr: 0.000010
2022-08-08 13:13:32,131 epoch 9 - iter 1122/3747 - loss 0.13153981 - samples/sec: 14.30 - lr: 0.000009
2022-08-08 13:15:14,694 epoch 9 - iter 1496/3747 - loss 0.12980986 - samples/sec: 14.59 - lr: 0.000009
2022-08-08 13:16:54,505 epoch 9 - iter 1870/3747 - loss 0.13055254 - samples/sec: 14.99 - lr: 0.000008
2022-08-08 13:18:34,603 epoch 9 - iter 2244/3747 - loss 0.13135059 - samples/sec: 14.95 - lr: 0.000008
2022-08-08 13:20:14,213 epoch 9 - iter 2618/3747 - loss 0.13220059 - samples/sec: 15.02 - lr: 0.000007
2022-08-08 13:22:07,833 epoch 9 - iter 2992/3747 - loss 0.13070022 - samples/sec: 13.17 - lr: 0.000007
2022-08-08 13:23:52,194 epoch 9 - iter 3366/3747 - loss 0.13033551 - samples/sec: 14.34 - lr: 0.000006
2022-08-08 13:25:31,989 epoch 9 - iter 3740/3747 - loss 0.13090245 - samples/sec: 15.00 - lr: 0.000006
2022-08-08 13:25:33,787 ----------------------------------------------------------------------------------------------------
2022-08-08 13:25:33,787 EPOCH 9 done: loss 0.1309 - lr 0.000006
2022-08-08 13:26:01,768 Evaluating as a multi-label problem: False
2022-08-08 13:26:01,817 DEV : loss 0.07069706916809082 - f1-score (micro avg)  0.9572
2022-08-08 13:26:01,888 BAD EPOCHS (no improvement): 4
2022-08-08 13:26:01,890 ----------------------------------------------------------------------------------------------------
2022-08-08 13:27:41,675 epoch 10 - iter 374/3747 - loss 0.13143471 - samples/sec: 15.00 - lr: 0.000005
2022-08-08 13:29:21,120 epoch 10 - iter 748/3747 - loss 0.13005545 - samples/sec: 15.05 - lr: 0.000004
2022-08-08 13:31:00,801 epoch 10 - iter 1122/3747 - loss 0.13138708 - samples/sec: 15.01 - lr: 0.000004
2022-08-08 13:32:40,365 epoch 10 - iter 1496/3747 - loss 0.13177414 - samples/sec: 15.03 - lr: 0.000003
2022-08-08 13:34:20,222 epoch 10 - iter 1870/3747 - loss 0.13180984 - samples/sec: 14.99 - lr: 0.000003
2022-08-08 13:35:59,594 epoch 10 - iter 2244/3747 - loss 0.13310145 - samples/sec: 15.06 - lr: 0.000002
2022-08-08 13:37:38,557 epoch 10 - iter 2618/3747 - loss 0.13291887 - samples/sec: 15.12 - lr: 0.000002
2022-08-08 13:39:17,998 epoch 10 - iter 2992/3747 - loss 0.13345536 - samples/sec: 15.05 - lr: 0.000001
2022-08-08 13:40:57,590 epoch 10 - iter 3366/3747 - loss 0.13231566 - samples/sec: 15.03 - lr: 0.000001
2022-08-08 13:42:37,215 epoch 10 - iter 3740/3747 - loss 0.13184462 - samples/sec: 15.02 - lr: 0.000000
2022-08-08 13:42:39,079 ----------------------------------------------------------------------------------------------------
2022-08-08 13:42:39,079 EPOCH 10 done: loss 0.1318 - lr 0.000000
2022-08-08 13:43:05,531 Evaluating as a multi-label problem: False
2022-08-08 13:43:05,581 DEV : loss 0.07155544310808182 - f1-score (micro avg)  0.9586
2022-08-08 13:43:05,639 BAD EPOCHS (no improvement): 4
2022-08-08 13:43:06,848 ----------------------------------------------------------------------------------------------------
2022-08-08 13:43:06,852 Testing using last state of model ...
2022-08-08 13:43:36,798 Evaluating as a multi-label problem: False
2022-08-08 13:43:36,849 0.9064	0.9212	0.9138	0.8759
2022-08-08 13:43:36,849 
Results:
- F-score (micro) 0.9138
- F-score (macro) 0.8981
- Accuracy 0.8759

By class:
              precision    recall  f1-score   support

         ORG     0.8873    0.9097    0.8983      1661
         LOC     0.9320    0.9281    0.9300      1668
         PER     0.9671    0.9635    0.9653      1617
        MISC     0.7660    0.8348    0.7989       702

   micro avg     0.9064    0.9212    0.9138      5648
   macro avg     0.8881    0.9090    0.8981      5648
weighted avg     0.9083    0.9212    0.9145      5648

2022-08-08 13:43:36,850 ----------------------------------------------------------------------------------------------------
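
After training, the fine-tuned tagger can be loaded back for prediction. A short sketch, assuming the trainer wrote final-model.pt into the output directory (the standard Flair behaviour):

from flair.data import Sentence
from flair.models import SequenceTagger

# Load the model saved at the end of training.
tagger = SequenceTagger.load("conll-03-t5-base/final-model.pt")

# Run NER on a new sentence and print the predicted spans.
sentence = Sentence("George Washington went to Washington .")
tagger.predict(sentence)

for entity in sentence.get_spans("ner"):
    print(entity)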

stefan-it · Aug 08 '22 10:08

@stefan-it thanks for adding this!

alanakbik · Aug 10 '22 13:08

Awesome, thanks!

On Wed, Aug 10, 2022 at 7:11 PM Alan Akbik wrote:

Merged #2896 https://github.com/flairNLP/flair/pull/2896 into master.


Madhu000 · Aug 10 '22 17:08