COSINE
This repo contains the code for our paper Fine-Tuning Pre-trained Language Model with Weak Supervision: A Contrastive-Regularized Self-Training Approach (in Proc. of NAACL-HLT 2021).
COSINE is also included in the WRENCH benchmark 🔧. Feel free to check out that repo and the preprint for details!
Model Framework
Benchmark
The results on different datasets are summarized as follows:
Method | AGNews | IMDB | Yelp | MIT-R | TREC | Chemprot | WiC (dev) |
---|---|---|---|---|---|---|---|
Full Supervision (RoBERTa-base) | 91.41 | 94.26 | 97.27 | 88.51 | 96.68 | 79.65 | 70.53 |
Direct Fine-tuning with Weak Supervision (RoBERTa-base) | 82.25 | 72.60 | 74.89 | 70.95 | 62.25 | 44.80 | 59.36 |
Previous SOTA | 86.28 | 86.98 | 92.05 | 74.41 | 80.20 | 53.48 | 64.88 |
COSINE | 87.52 | 90.54 | 95.97 | 76.61 | 82.59 | 54.36 | 67.71 |
- Previous SOTA: Self-ensemble/FreeLB/Mixup/SMART (Fine-tuning Approach); Snorkel/WeSTClass/ImplyLoss/Denoise/UST (Weakly-supervised Approach).
Data
The weakly labeled datasets used in our experiments are available here: dataset. The dataset statistics are summarized as follows:
Dataset | AGNews | IMDB | Yelp | TREC | MIT-R | Chemprot | WiC (dev) |
---|---|---|---|---|---|---|---|
Type | Topic | Sentiment | Sentiment | Slot Filling | Question | Relation | Word Sense Disambiguation |
# of Training Samples | 96k | 20k | 30.4k | 6.6k | 4.8k | 12.6k | 5.4k |
# of Validation Samples | 12k | 2.5k | 3.8k | 1.0k | 0.6k | 1.6k | 0.6k |
# of Test Samples | 12k | 2.5k | 3.8k | 1.5k | 0.6k | 1.6k | 1.4k |
Coverage of Weak Labels | 56.4% | 87.5% | 82.8% | 13.5% | 95.0% | 85.9% | 63.4% |
Accuracy of Weak Labels | 83.1% | 74.5% | 71.5% | 80.7% | 63.8% | 46.5% | 58.8% |
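In the table above, coverage and accuracy describe the weak labels: the fraction of training samples that receive any weak label, and how often those weak labels agree with the ground truth. A minimal sketch of how these two statistics can be computed is below; it assumes weak labels are stored as integers with -1 marking samples no labeling rule matched, which is an illustrative convention rather than the repository's exact format.

```python
import numpy as np

def weak_label_stats(weak_labels: np.ndarray, gold_labels: np.ndarray):
    """Illustrative computation of weak-label coverage and accuracy.

    Assumes -1 marks samples that no labeling rule matched (an assumption
    made for this sketch, not necessarily the dataset's storage format).
    """
    covered = weak_labels != -1
    coverage = covered.mean()                                         # fraction of samples with a weak label
    accuracy = (weak_labels[covered] == gold_labels[covered]).mean()  # agreement with gold on covered samples
    return coverage, accuracy
```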
Package
- PyTorch 1.2
- Python 3.6
- Transformers v2.8.0
- tqdm
Code
- `main.py`: the main entry point that runs self-training.
- `dataloader.py`: preprocesses the text data and tokenizes it.
- `utils.py`: helper functions (computing accuracy, saving data, etc.).
- `modeling_roberta.py`: modifies the base RoBERTa model for our task (we need RoBERTa to directly output the feature vector).
- `model.py`: the RoBERTa model for classification tasks. See `BERT_model` for details.
- `trainer.py`: trains the RoBERTa model under different settings.
  - `train(self)`: training for stage 1.
  - `selftrain(self, soft=True)`: self-training based on pseudo-labeling with periodic updates.
  - `soft_frequency`: reweights the pseudo-label values, following WeSTClass (see the sketch after this list).
  - `calc_loss`: computes the prediction loss for self-training.
  - `contrastive_loss`: contrastive loss on sample pairs.
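The `soft_frequency` reweighting mentioned above follows the WeSTClass/DEC-style target distribution: predictions are raised to a power and divided by the soft class frequency, then renormalized. Below is a minimal, illustrative sketch of that computation; the function name, signature, and default power are placeholders rather than the repository's exact implementation.

```python
import torch

def soft_frequency_sketch(probs: torch.Tensor, power: int = 2) -> torch.Tensor:
    """Sharpen soft pseudo-labels (WeSTClass/DEC-style), illustrative only.

    probs: (N, C) tensor of predicted class probabilities.
    Returns an (N, C) tensor of sharpened soft pseudo-labels.
    """
    freq = probs.sum(dim=0, keepdim=True)                 # soft class frequency f_j
    weighted = probs.pow(power) / freq                    # p_ij^power / f_j
    return weighted / weighted.sum(dim=1, keepdim=True)   # renormalize each row
```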
Run the Code
Please use `run_agnews.sh` to run the code on the AGNews dataset as an example.
Key Parameters
For each model, we summarize the key parameters as follows (note that some parameters defined in the args are obsolete, and we will clean them up later):
General
- Use `--method` to choose the training method:
  - `clean`: train on clean data
  - `noisy`: train directly on weakly labeled data
  - `selftrain`: self-training
- Use `--task` to choose the dataset. Choices include 'agnews', 'imdb', 'yelp', 'mit-r', 'trec', 'chemprot', 'wic'.
- Use `--task_type` to choose the training task: `tc` stands for text classification, `re` means relation classification.
- Use `--gpu` to allocate GPU resources to speed up training.
- Use `--max_seq_len` to set the maximum number of tokens per sentence.
- Use `--auto_load` to automatically load the cached training data; otherwise, the training/dev/test sets are regenerated.
- Change the code in `utils.py` (line 26) to add special tokens; see the sketch after this list.
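For the last bullet above, here is a minimal sketch of how extra special tokens (e.g., entity markers for the relation classification setting) could be registered with a Hugging Face tokenizer. The marker strings are placeholders chosen for illustration, not the exact tokens used in `utils.py`.

```python
from transformers import RobertaTokenizer

# Illustrative only: the marker strings below are placeholders,
# not the tokens hard-coded in utils.py.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<e1>", "</e1>", "<e2>", "</e2>"]}
)

# After adding tokens, the model's embedding matrix must be resized to match,
# e.g. model.resize_token_embeddings(len(tokenizer)).
```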
For self-training-based model
- Use `--self_training_eps` to set the confidence threshold, usually around 0.6-0.7 (see the sketch after this list).
- Use `--self_training_power` to control the power used when computing pseudo-labels.
- Use `--self_training_contrastive_weight` to control the weight of the contrastive loss.
- Use `--self_training_confreg` to control the strength of the confidence regularization.
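For intuition, the sketch below shows one way a confidence threshold like `--self_training_eps` can be used to select samples for the self-training loss: keep only samples whose top predicted probability exceeds the threshold. The helper name and the way the mask is applied are illustrative, not the repository's exact logic.

```python
import torch

def confident_mask(probs: torch.Tensor, eps: float = 0.65) -> torch.Tensor:
    """Return a boolean mask over samples whose top probability exceeds eps.

    probs: (N, C) tensor of softmax outputs.
    The mask can be used to filter a batch before computing the
    self-training (pseudo-label) loss. Illustrative sketch only.
    """
    confidence, _ = probs.max(dim=1)   # per-sample confidence
    return confidence > eps
```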
TODO:
- Add token classification version of our framework.
Citation
Please cite the following paper if you find our datasets/tools useful. Thanks!
@inproceedings{yu2021fine,
title={Fine-Tuning Pre-trained Language Model with Weak Supervision: A Contrastive-Regularized Self-Training Approach},
author={Yu, Yue and Zuo, Simiao and Jiang, Haoming and Ren, Wendi and Zhao, Tuo and Zhang, Chao},
booktitle={Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies},
pages={1063--1077},
year={2021}
}