DiffuSeq
Issues with decoding and evaluation
Hi!
I am trying to replicate DiffuSeq on the paraphrase task with the QQP dataset. I kept the default training config and, for MBR, ran evaluation with 20 different random seeds, but I still can't match the performance reported in Table 1 of the paper.
For reference, this is the content of the training_args.json file:
"lr": 0.0001,
"batch_size": 2048,
"microbatch": 64,
"learning_steps": 50000,
"log_interval": 20,
"save_interval": 10000,
"eval_interval": 1000,
"ema_rate": "0.9999",
"resume_checkpoint": "none",
"schedule_sampler": "lossaware",
"diffusion_steps": 2000,
"noise_schedule": "sqrt",
"timestep_respacing": "",
"vocab": "bert",
"use_plm_init": "no",
"vocab_size": 30522,
"config_name": "bert-base-uncased",
"notes": "test-qqp20231015-19:22:30",
"data_dir": "/scratch/ad6489/thesis/DiffuSeq/datasets/QQP",
"dataset": "qqp",
"checkpoint_path": "diffusion_models/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30",
"seq_len": 128,
"hidden_t_dim": 128,
"hidden_dim": 128,
"dropout": 0.1,
"use_fp16": false,
"fp16_scale_growth": 0.001,
"seed": 102,
"gradient_clipping": -1.0,
"weight_decay": 0.0,
"learn_sigma": false,
"use_kl": false,
"predict_xstart": true,
"rescale_timesteps": true,
"rescale_learned_sigmas": false,
"sigma_small": false,
"emb_scale_factor": 1.0
And this is the output of running the evaluation with python eval_seq2seq.py --folder ../{your-path-to-outputs} --mbr:
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed0_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 17.01 seconds, 146.95 sentences/sec
******************************
avg BLEU score 0.0004804176648035196
avg ROUGE-L score 0.0027964608892798422
avg berscore tensor(0.3128)
avg dist1 score 0.4627495097623505
avg len 14.924
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed100_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.93 seconds, 179.49 sentences/sec
******************************
avg BLEU score 0.000386986505770346
avg ROUGE-L score 0.002365110184252262
avg berscore tensor(0.3128)
avg dist1 score 0.46048387836817817
avg len 14.9936
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed102_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.20 seconds, 176.05 sentences/sec
******************************
avg BLEU score 0.00048498323549467804
avg ROUGE-L score 0.0026150280356407167
avg berscore tensor(0.3118)
avg dist1 score 0.46387627477460036
avg len 15.0372
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed103_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.95 seconds, 179.21 sentences/sec
******************************
avg BLEU score 0.000511365221514534
avg ROUGE-L score 0.0029585537880659103
avg berscore tensor(0.3123)
avg dist1 score 0.46248667034925856
avg len 14.8916
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed105_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.98 seconds, 178.78 sentences/sec
******************************
avg BLEU score 0.00045866246887226933
avg ROUGE-L score 0.002453571179509163
avg berscore tensor(0.3118)
avg dist1 score 0.46257998152512503
avg len 14.934
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed107_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.79 seconds, 181.33 sentences/sec
******************************
avg BLEU score 0.0005373413998520224
avg ROUGE-L score 0.00305726850181818
avg berscore tensor(0.3118)
avg dist1 score 0.4660346386579922
avg len 14.722
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed110_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.91 seconds, 179.77 sentences/sec
******************************
avg BLEU score 0.00045626781225977187
avg ROUGE-L score 0.0026337386369705202
avg berscore tensor(0.3110)
avg dist1 score 0.4587090466366367
avg len 14.8424
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed112_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.95 seconds, 179.24 sentences/sec
******************************
avg BLEU score 0.0004987962284970048
avg ROUGE-L score 0.002871700122952461
avg berscore tensor(0.3119)
avg dist1 score 0.46045974132362916
avg len 14.9172
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed115_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.00 seconds, 178.58 sentences/sec
******************************
avg BLEU score 0.0004678691109675817
avg ROUGE-L score 0.002757294401526451
avg berscore tensor(0.3127)
avg dist1 score 0.4632908610617951
avg len 14.8888
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed118_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.01 seconds, 178.41 sentences/sec
******************************
avg BLEU score 0.000402613081471179
avg ROUGE-L score 0.002492514155805111
avg berscore tensor(0.3128)
avg dist1 score 0.4586922658442217
avg len 14.9568
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed119_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.97 seconds, 179.01 sentences/sec
******************************
avg BLEU score 0.0003726691417880429
avg ROUGE-L score 0.0021837190836668015
avg berscore tensor(0.3127)
avg dist1 score 0.45847530479588566
avg len 14.8948
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed120_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.05 seconds, 177.90 sentences/sec
******************************
avg BLEU score 0.0004212587497725286
avg ROUGE-L score 0.002304764446616173
avg berscore tensor(0.3133)
avg dist1 score 0.46575454186534915
avg len 14.936
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed121_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.01 seconds, 178.47 sentences/sec
******************************
avg BLEU score 0.0004297394887193383
avg ROUGE-L score 0.0025662126049399376
avg berscore tensor(0.3116)
avg dist1 score 0.4609064204517209
avg len 14.9272
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed122_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.03 seconds, 178.15 sentences/sec
******************************
avg BLEU score 0.00045685374061348396
avg ROUGE-L score 0.002615842577815056
avg berscore tensor(0.3120)
avg dist1 score 0.46178081753269634
avg len 15.0528
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed123_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.97 seconds, 178.91 sentences/sec
******************************
avg BLEU score 0.0004793333330035246
avg ROUGE-L score 0.0029235000282526015
avg berscore tensor(0.3115)
avg dist1 score 0.46251716295186845
avg len 14.8484
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed128_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.01 seconds, 178.39 sentences/sec
******************************
avg BLEU score 0.0004840668444269398
avg ROUGE-L score 0.0028238809868693353
avg berscore tensor(0.3123)
avg dist1 score 0.461814826836216
avg len 15.056
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed132_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.04 seconds, 178.04 sentences/sec
******************************
avg BLEU score 0.0004966622597426265
avg ROUGE-L score 0.0028643640622496606
avg berscore tensor(0.3124)
avg dist1 score 0.46310948517487616
avg len 15.0036
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed156_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.03 seconds, 178.14 sentences/sec
******************************
avg BLEU score 0.0004998088099505953
avg ROUGE-L score 0.0030375902444124223
avg berscore tensor(0.3117)
avg dist1 score 0.46094208225337335
avg len 15.0072
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed46_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.05 seconds, 177.94 sentences/sec
******************************
avg BLEU score 0.0004565915330725618
avg ROUGE-L score 0.0026090403661131857
avg berscore tensor(0.3122)
avg dist1 score 0.4599867355476624
avg len 14.954
generation_outputs/diffuseq_qqp_h128_lr0.0001_t2000_sqrt_lossaware_seed102_test-qqp20231015-19:22:30/ema_0.9999_010000.pt.samples/seed90_step0.json
calculating scores...
computing bert embedding.
computing greedy matching.
done in 14.00 seconds, 178.52 sentences/sec
******************************
avg BLEU score 0.00047420615544432475
avg ROUGE-L score 0.0025192311719059945
avg berscore tensor(0.3112)
avg dist1 score 0.4604016924936326
avg len 14.9956
******************************
MBR...
******************************
calculating scores...
computing bert embedding.
computing greedy matching.
done in 13.77 seconds, 181.60 sentences/sec
******************************
avg BLEU score 0.00020953579590037182
avg ROUGE-l score 0.0013720075532794
avg berscore tensor(0.2995)
avg dist1 score 0.33858694581014626
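For readers unfamiliar with the MBR step in the log above, here is a minimal sketch of the general idea: among the outputs generated for one source sentence with different seeds, keep the one that agrees most with the others. This is an illustration only and not the code in eval_seq2seq.py. The function names `token_f1` and `mbr_select` are hypothetical, and the token-overlap F1 is a stand-in utility; the actual script may score candidates differently.

```python
from collections import Counter
from typing import Callable, Sequence


def token_f1(hyp: str, ref: str) -> float:
    """Whitespace-token F1 overlap; a stand-in utility (assumption, not DiffuSeq's metric)."""
    hyp_tokens, ref_tokens = hyp.split(), ref.split()
    if not hyp_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most as often as it appears in both.
    overlap = sum((Counter(hyp_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(hyp_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def mbr_select(
    candidates: Sequence[str],
    utility: Callable[[str, str], float] = token_f1,
) -> str:
    """Return the candidate with the highest average utility against all other candidates."""
    if not candidates:
        raise ValueError("need at least one candidate")
    if len(candidates) == 1:
        return candidates[0]

    def expected_utility(i: int) -> float:
        others = [c for j, c in enumerate(candidates) if j != i]
        return sum(utility(candidates[i], other) for other in others) / len(others)

    best_index = max(range(len(candidates)), key=expected_utility)
    return candidates[best_index]


# One candidate per seed for the same source sentence (made-up strings for illustration).
samples = [
    "how do i learn python",
    "how can i learn python fast",
    "how do i learn python fast",
]
selected = mbr_select(samples)
```

Because selection only rewards agreement between candidates, it cannot rescue outputs from an undertrained model: if every seed produces poor text, the chosen candidate is poor too.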
That's because the model behind ema_0.9999_010000.pt.samples (the checkpoint saved at step 10000) hasn't converged yet. Try ema_0.9999_050000.pt.samples instead.
My bad! I was able to reach 22 BLEU at 50000 steps. Thank you.
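To avoid evaluating an early checkpoint by accident, the training step can be read from the checkpoint name before decoding. The sketch below assumes the naming pattern seen in this thread (`ema_<rate>_<step>.pt`, optionally followed by `.samples`). The helper names `checkpoint_step` and `latest_checkpoint` are hypothetical and not part of the DiffuSeq repository.

```python
import re
from typing import Iterable, Optional

# Matches names such as "ema_0.9999_010000.pt" or "ema_0.9999_010000.pt.samples".
# The pattern is inferred from the file names quoted in this thread (assumption).
_EMA_NAME = re.compile(r"ema_(?P<rate>[0-9.]+)_(?P<step>\d+)\.pt(?:\.samples)?$")


def checkpoint_step(name: str) -> Optional[int]:
    """Return the training step encoded in an EMA checkpoint name, or None if it does not match."""
    match = _EMA_NAME.search(name)
    return int(match.group("step")) if match else None


def latest_checkpoint(names: Iterable[str]) -> Optional[str]:
    """Return the name with the highest training step, ignoring names that do not match."""
    scored = [(checkpoint_step(name), name) for name in names]
    scored = [(step, name) for step, name in scored if step is not None]
    return max(scored)[1] if scored else None


# With save_interval 10000 and learning_steps 50000, a run would be expected to
# leave checkpoints like these (illustrative list, not read from disk).
found = [
    "ema_0.9999_010000.pt",
    "ema_0.9999_030000.pt",
    "ema_0.9999_050000.pt",
    "training_args.json",
]
newest = latest_checkpoint(found)
```

In practice the list would come from something like `os.listdir` on the checkpoint directory, and the returned name would be the one passed to the decoding step.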