
Questions about the paper.

Open chaochen99 opened this issue 3 years ago • 8 comments

Hello,

I was very fortunate to read your paper, and the experimental results are exciting. The paper mentions two observations:

- Observation 1: Original BERT layers fail to improve the performance.
- Observation 2: Embedding biases harm the sentence embeddings performance.

Based on your experimental results, these two phenomena do exist and can be mitigated. But I don't see the connection between these two observations and prompts.

How does prompt solve the bias problem?

Looking forward to your reply, thanks!

chaochen99 avatar Feb 24 '22 09:02 chaochen99

Thank you,

For observation 1, we find that representing sentences by the [CLS] token or by averaging token embeddings is not effective. By reformulating the sentence embedding task as a masked language task, we can efficiently use the original BERT layers and leverage their large-scale pre-trained knowledge.

For observation 2, we learn to directly represent sentences from the [MASK] token, rather than as a weighted average of token embeddings according to the probability distribution (Eq. 4 in the paper).
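As a toy illustration of the idea above, here is a minimal sketch of taking the hidden state at the [MASK] position as the sentence embedding. Note this is not the paper's implementation: the input ids are a hypothetical tokenization of `This sentence of "A cat sits." means [MASK].`, and the hidden states are random stand-ins for real BERT outputs (`model(**inputs).last_hidden_state`).

```python
import torch

# Minimal sketch: use the hidden state at the [MASK] position as the
# sentence embedding. Ids and hidden states below are dummies; in practice
# they come from the tokenizer and the pre-trained model.
MASK_ID = 103  # bert-base-uncased's [MASK] token id
input_ids = torch.tensor([[101, 2023, 6251, 1997, 1037, 4937, 7719, 2965, 103, 1012, 102]])
hidden = torch.randn(1, input_ids.size(1), 768)  # stand-in for BERT's last hidden states

# locate the [MASK] position and take its hidden state as the embedding
mask_pos = (input_ids[0] == MASK_ID).nonzero(as_tuple=True)[0]
embedding = hidden[0, mask_pos].squeeze(0)  # shape: (768,)
```

The same extraction works with a real model by replacing the dummy tensors with tokenizer outputs and `last_hidden_state`.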

kongds avatar Feb 24 '22 10:02 kongds

Thanks for your reply,

At the same time, I could only find the base version of the model. I would like to know how PromptBERT performs with bert-large and roberta-large, because we want to follow up on your work. Could you share the checkpoints?

Looking forward to your reply.

chaochen99 avatar Feb 24 '22 13:02 chaochen99

For the large model, I only trained unsupervised bert-large-uncased, but I can't find its checkpoint.
The results are as follows:

| STS12 | STS13 | STS14 | STS15 | STS16 | STSb | SICK-R | Avg. |
|-------|-------|-------|-------|-------|------|--------|------|
| 75.52 | 87.67 | 79.24 | 85.33 | 80.57 | 83.00 | 73.01 | 80.62 |

The training command is as follows:

```sh
BC=(python -m torch.distributed.launch --nproc_per_node 4 train.py)
GPU=0,1,2,3
BATCH=64
MODEL=bert-large-uncased
LR=3e-5
EXP=unsup-bert-large
EPOCH=1
TEMPLATE="*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*"
ES=125 # --eval_steps
BMETRIC=stsb_spearman # --metric_for_best_model
TRAIN_FILE=data/wiki1m_for_simcse.txt
args=(--mlp_only_train --mask_embedding_sentence\
      --mask_embedding_sentence_delta\
      --mask_embedding_sentence_template "*cls*_This_sentence_of_\"*sent_0*\"_means*mask*.*sep+*")
CHECKPOINT=result/$EXP
CUDA_VISIBLE_DEVICES=$GPU ${BC[@]}\
              --model_name_or_path $MODEL\
              --train_file $TRAIN_FILE\
              --output_dir $CHECKPOINT\
              --num_train_epochs $EPOCH\
              --per_device_train_batch_size $BATCH \
              --learning_rate $LR \
              --max_seq_length 32\
              --evaluation_strategy steps\
              --metric_for_best_model $BMETRIC\
              --load_best_model_at_end\
              --eval_steps $ES\
              --overwrite_output_dir\
              --temp 0.05\
              --do_train\
              --fp16\
              --preprocessing_num_workers 10\
              ${args[@]}
```

kongds avatar Feb 24 '22 13:02 kongds

Here are my results using these parameters:

| STS12 | STS13 | STS14 | STS15 | STS16 | STSb | SICK-R | Avg. |
|-------|-------|-------|-------|-------|------|--------|------|
| 72.98 | 86.66 | 79.24 | 85.33 | 80.48 | 82.41 | 72.09 | 79.88 |

Environment: torch==1.7.1+cu110, a 3090 GPU, seed=42.

Could you tell me what torch version and gpu you are using?

Looking forward to your reply!

chaochen99 avatar Feb 25 '22 06:02 chaochen99

I used 4 × 16 GB V100s and torch==1.6.1+cu101 with apex. By the way, the bert-large result comes from the old codebase, which may be slightly different.

But I don't have V100 cards right now to verify whether the gap comes from torch, apex, or the codebase.

kongds avatar Feb 25 '22 14:02 kongds

I found the trainer_state.json of bert-large; this might help.

{
  "best_metric": 0.8650828143261685,
  "best_model_checkpoint": "result/unsup-bert-large",
  "epoch": 1.0,
  "global_step": 3907,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 0.03,
      "eval_avg_sts": 0.8029055192313586,
      "eval_sickr_spearman": 0.7471528897173139,
      "eval_stsb_spearman": 0.8586581487454031,
      "step": 125
    },
    {
      "epoch": 0.06,
      "eval_avg_sts": 0.8109103797908428,
      "eval_sickr_spearman": 0.7589070128301474,
      "eval_stsb_spearman": 0.8629137467515381,
      "step": 250
    },
    {
      "epoch": 0.1,
      "eval_avg_sts": 0.8093174992808367,
      "eval_sickr_spearman": 0.7579198776333913,
      "eval_stsb_spearman": 0.8607151209282823,
      "step": 375
    },
    {
      "epoch": 0.13,
      "learning_rate": 2.6160737138469415e-05,
      "loss": 0.004,
      "step": 500
    },
    {
      "epoch": 0.13,
      "eval_avg_sts": 0.8045124402414239,
      "eval_sickr_spearman": 0.7583830895756604,
      "eval_stsb_spearman": 0.8506417909071873,
      "step": 500
    },
    {
      "epoch": 0.16,
      "eval_avg_sts": 0.7965214393405081,
      "eval_sickr_spearman": 0.7470930429649221,
      "eval_stsb_spearman": 0.8459498357160942,
      "step": 625
    },
    {
      "epoch": 0.19,
      "eval_avg_sts": 0.8055590122682623,
      "eval_sickr_spearman": 0.7558082382897661,
      "eval_stsb_spearman": 0.8553097862467586,
      "step": 750
    },
    {
      "epoch": 0.22,
      "eval_avg_sts": 0.8155483810926036,
      "eval_sickr_spearman": 0.7704474455726826,
      "eval_stsb_spearman": 0.8606493166125245,
      "step": 875
    },
    {
      "epoch": 0.26,
      "learning_rate": 2.232147427693883e-05,
      "loss": 0.0005,
      "step": 1000
    },
    {
      "epoch": 0.26,
      "eval_avg_sts": 0.8133897854649895,
      "eval_sickr_spearman": 0.7663550996679526,
      "eval_stsb_spearman": 0.8604244712620265,
      "step": 1000
    },
    {
      "epoch": 0.29,
      "eval_avg_sts": 0.8196710191935307,
      "eval_sickr_spearman": 0.7747364788378018,
      "eval_stsb_spearman": 0.8646055595492596,
      "step": 1125
    },
    {
      "epoch": 0.32,
      "eval_avg_sts": 0.8218381673513775,
      "eval_sickr_spearman": 0.7785935203765866,
      "eval_stsb_spearman": 0.8650828143261685,
      "step": 1250
    },
    {
      "epoch": 0.35,
      "eval_avg_sts": 0.8165785131017016,
      "eval_sickr_spearman": 0.7717390018903535,
      "eval_stsb_spearman": 0.8614180243130498,
      "step": 1375
    },
    {
      "epoch": 0.38,
      "learning_rate": 1.8482211415408245e-05,
      "loss": 0.0005,
      "step": 1500
    },
    {
      "epoch": 0.38,
      "eval_avg_sts": 0.8047818318689836,
      "eval_sickr_spearman": 0.7539911256601595,
      "eval_stsb_spearman": 0.8555725380778076,
      "step": 1500
    },
    {
      "epoch": 0.42,
      "eval_avg_sts": 0.7982487502134017,
      "eval_sickr_spearman": 0.7442794290737788,
      "eval_stsb_spearman": 0.8522180713530247,
      "step": 1625
    },
    {
      "epoch": 0.45,
      "eval_avg_sts": 0.799967337700157,
      "eval_sickr_spearman": 0.7443161248352774,
      "eval_stsb_spearman": 0.8556185505650366,
      "step": 1750
    },
    {
      "epoch": 0.48,
      "eval_avg_sts": 0.801522708223845,
      "eval_sickr_spearman": 0.7457060008175922,
      "eval_stsb_spearman": 0.8573394156300977,
      "step": 1875
    },
    {
      "epoch": 0.51,
      "learning_rate": 1.4642948553877656e-05,
      "loss": 0.0004,
      "step": 2000
    },
    {
      "epoch": 0.51,
      "eval_avg_sts": 0.7970257249600163,
      "eval_sickr_spearman": 0.7415913684817947,
      "eval_stsb_spearman": 0.8524600814382378,
      "step": 2000
    },
    {
      "epoch": 0.54,
      "eval_avg_sts": 0.8056394392946331,
      "eval_sickr_spearman": 0.7538111050919695,
      "eval_stsb_spearman": 0.8574677734972966,
      "step": 2125
    },
    {
      "epoch": 0.58,
      "eval_avg_sts": 0.8038362240185967,
      "eval_sickr_spearman": 0.7490150555200695,
      "eval_stsb_spearman": 0.8586573925171238,
      "step": 2250
    },
    {
      "epoch": 0.61,
      "eval_avg_sts": 0.7985717984461609,
      "eval_sickr_spearman": 0.7436228919482213,
      "eval_stsb_spearman": 0.8535207049441006,
      "step": 2375
    },
    {
      "epoch": 0.64,
      "learning_rate": 1.080368569234707e-05,
      "loss": 0.0004,
      "step": 2500
    },
    {
      "epoch": 0.64,
      "eval_avg_sts": 0.8017446914529242,
      "eval_sickr_spearman": 0.74755755174693,
      "eval_stsb_spearman": 0.8559318311589184,
      "step": 2500
    },
    {
      "epoch": 0.67,
      "eval_avg_sts": 0.806035852944099,
      "eval_sickr_spearman": 0.7568251047383123,
      "eval_stsb_spearman": 0.8552466011498855,
      "step": 2625
    },
    {
      "epoch": 0.7,
      "eval_avg_sts": 0.7931850193859609,
      "eval_sickr_spearman": 0.7340857884155682,
      "eval_stsb_spearman": 0.8522842503563536,
      "step": 2750
    },
    {
      "epoch": 0.74,
      "eval_avg_sts": 0.7977208601703414,
      "eval_sickr_spearman": 0.7462593671372607,
      "eval_stsb_spearman": 0.8491823532034221,
      "step": 2875
    },
    {
      "epoch": 0.77,
      "learning_rate": 6.964422830816484e-06,
      "loss": 0.0005,
      "step": 3000
    },
    {
      "epoch": 0.77,
      "eval_avg_sts": 0.8073046774483807,
      "eval_sickr_spearman": 0.7641048905966784,
      "eval_stsb_spearman": 0.8505044643000832,
      "step": 3000
    },
    {
      "epoch": 0.8,
      "eval_avg_sts": 0.8022302732886464,
      "eval_sickr_spearman": 0.7558746172719535,
      "eval_stsb_spearman": 0.8485859293053393,
      "step": 3125
    },
    {
      "epoch": 0.83,
      "eval_avg_sts": 0.8059561544995988,
      "eval_sickr_spearman": 0.7599832457200711,
      "eval_stsb_spearman": 0.8519290632791264,
      "step": 3250
    },
    {
      "epoch": 0.86,
      "eval_avg_sts": 0.7939477153857237,
      "eval_sickr_spearman": 0.7442896596983852,
      "eval_stsb_spearman": 0.8436057710730622,
      "step": 3375
    },
    {
      "epoch": 0.9,
      "learning_rate": 3.125159969285897e-06,
      "loss": 0.0004,
      "step": 3500
    },
    {
      "epoch": 0.9,
      "eval_avg_sts": 0.8054149775684041,
      "eval_sickr_spearman": 0.75679162706061,
      "eval_stsb_spearman": 0.8540383280761983,
      "step": 3500
    },
    {
      "epoch": 0.93,
      "eval_avg_sts": 0.8034251823040465,
      "eval_sickr_spearman": 0.7553765827811417,
      "eval_stsb_spearman": 0.8514737818269512,
      "step": 3625
    },
    {
      "epoch": 0.96,
      "eval_avg_sts": 0.8037347568705746,
      "eval_sickr_spearman": 0.7554017510782953,
      "eval_stsb_spearman": 0.8520677626628539,
      "step": 3750
    },
    {
      "epoch": 0.99,
      "eval_avg_sts": 0.8021041735230073,
      "eval_sickr_spearman": 0.75362008540155,
      "eval_stsb_spearman": 0.8505882616444647,
      "step": 3875
    },
    {
      "epoch": 1.0,
      "step": 3907,
      "train_runtime": 6852.5932,
      "train_samples_per_second": 0.57
    }
  ],
  "max_steps": 3907,
  "num_train_epochs": 1,
  "total_flos": 169441187714039808,
  "trial_name": null,
  "trial_params": null
}

kongds avatar Feb 25 '22 14:02 kongds

Hello, after reading your paper I also have some questions about Observation 1.

1. When Eq. (1) is used to measure the degree of anisotropy (which I understand as non-uniformity) of the sentence embeddings of 100,000 Wikipedia sentences, what is the topic distribution of those sentences? If sentences from one topic are over-represented, a good model (one satisfying both alignment and uniformity) would tend to map them to a neighborhood of the semantic space, inflating the value of Eq. (1). In that case the conclusion in Table 1 might not be rigorous. 🧐
2. Using the vector at the [MASK] position of the prompt template as the sentence embedding seems, in principle, similar to the common earlier practice of using the [CLS] position. Why does the prompt method help alleviate the bias?

Looking forward to your reply, thanks!

Yubo8Zhang avatar Feb 25 '23 16:02 Yubo8Zhang

Thank you for your interest in our paper.

  1. The 100,000 sentences we used were randomly sampled, and the value is averaged over 100,000 × 100,000 sentence pairs. Even if there were 100 sentences from the same topic, those pairs would affect the result by only about one part in a thousand.
  2. Although using [CLS] also avoids the token bias problem, its main drawback is that it cannot directly leverage the original pre-trained model to represent sentences. The key difference between the prompt and [CLS] is that the prompt directly uses a template to help the pre-trained model represent sentences.
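For reference, the anisotropy measure discussed in point 1 (the average cosine similarity over all sentence pairs, in the spirit of Eq. 1) can be sketched as follows. This is a toy sketch, not the paper's code: `average_pairwise_cosine` is a hypothetical helper, and the embeddings are random stand-ins rather than the Wikipedia sentences from the paper.

```python
import torch
import torch.nn.functional as F

def average_pairwise_cosine(embeddings: torch.Tensor) -> float:
    """Average cosine similarity over all distinct sentence pairs (toy sketch)."""
    normed = F.normalize(embeddings, dim=-1)      # unit-norm rows
    sims = normed @ normed.T                      # (n, n) cosine matrix
    n = sims.size(0)
    off_diag = sims.sum() - sims.diagonal().sum() # exclude self-similarity
    return (off_diag / (n * (n - 1))).item()

# random stand-ins for sentence embeddings
emb = torch.randn(100, 768)
score = average_pairwise_cosine(emb)
```

A higher average pairwise similarity indicates a more anisotropic (less uniform) embedding space.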

kongds avatar Feb 26 '23 01:02 kongds