
Finetuning Whisper: validation loss keeps decreasing while validation WER increases

Open anderleich opened this issue 2 years ago • 8 comments

Hi, I've followed this blog post https://huggingface.co/blog/fine-tune-whisper to finetune Whisper with my own dataset.

Everything seems to be working as expected. However, I've noticed a strange behaviour: while the validation loss keeps decreasing at every checkpoint, the validation WER increases. As a consequence, the best model in terms of WER is obtained at an early stage of the finetuning process. At inference time, though, the last saved checkpoint outperforms the best-WER checkpoint.

Do you have any clues as to what is happening?

anderleich avatar Mar 14 '23 17:03 anderleich

(three screenshots of the training logs attached)

anderleich avatar Mar 20 '23 16:03 anderleich

cc @Vaibhavs10 @sanchit-gandhi

osanseviero avatar Mar 23 '23 14:03 osanseviero

Hey @anderleich! Sorry for the late reply here. Could you share your training arguments so we can get a feel for the kind of set-up you're employing? And also your compute_metrics function? The fact that the eval loss is going down suggests that we're predicting the correct tokens, but not decoding them to words properly with the tokenizer.

Even better would be a Colab link that we could look through and reproduce ourselves :)
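
In the meantime, a quick sanity check would be to decode a handful of eval predictions yourself and compare them to the references by eye. This is only a rough sketch, assuming the trainer, processor and common_voice objects from the blog post:

# run generation on a few eval samples and compare decoded predictions to references
preds = trainer.predict(common_voice["test"].select(range(5)))

pred_str = processor.tokenizer.batch_decode(preds.predictions, skip_special_tokens=True)

label_ids = preds.label_ids
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

for p, l in zip(pred_str, label_str):
	print("PRED :", p)
	print("LABEL:", l)

If the decoded text looks garbled here, the problem is in the decoding rather than in the model.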

sanchit-gandhi avatar Apr 04 '23 15:04 sanchit-gandhi

I had the same issue and my problem was enabling bf16=True, instead of fp16=True, in Seq2SeqTrainingArguments

ammaraldirawi avatar Apr 07 '23 10:04 ammaraldirawi

Hi @sanchit-gandhi ,

Thanks for your response!

Indeed, just looking at the loss it seems the model is learning. In fact, it improves on the `openai/whisper-small` baseline when tested on several custom test sets. So, the training process is working properly.

This is my compute_metrics function:

import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
	pred_ids = pred.predictions
	label_ids = pred.label_ids

	# replace -100 with the pad_token_id
	label_ids[label_ids == -100] = processors.tokenizer.pad_token_id

	# we do not want to group tokens when computing the metrics
	pred_str = processors.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
	label_str = processors.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

	wer = 100 * metric.compute(predictions=pred_str, references=label_str)

	return {"wer": wer}

And these are my training arguments:

training_args = Seq2SeqTrainingArguments(
	output_dir=OUTPUT_MODEL_DIR,
	per_device_train_batch_size=batch_size,
	gradient_accumulation_steps=gradient_accumulation,
	learning_rate=1e-5,
	warmup_steps=500,
	max_steps=20000,
	gradient_checkpointing=False,
	fp16=True,
	evaluation_strategy="steps",
	per_device_eval_batch_size=8,
	predict_with_generate=True,
	#generation_max_length=225,
	save_steps=1000,
	eval_steps=1000,
	#eval_delay=1000,
	logging_steps=25,
	report_to=["tensorboard"],
	load_best_model_at_end=True,
	metric_for_best_model="wer",
	greater_is_better=False,
	push_to_hub=False,
	dataloader_num_workers=20
)

anderleich avatar Apr 14 '23 11:04 anderleich

Thanks for sharing those additional details @anderleich! I can't see anything that looks wrong based on these, and given what you've said about your tests it certainly sounds like the model is learning correctly.

What I would suggest doing is removing these two args from the Seq2SeqTrainingArguments:

-	metric_for_best_model="wer",
-	greater_is_better=False,

What will then happen is that the Trainer will select the model with the lowest eval loss as your 'best' model at the end of training (rather than the one with the lowest eval WER).

Based on your logs, this should select the actual best model for you.
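
Roughly, that means something like this (keeping the rest of your arguments unchanged; with load_best_model_at_end=True and no metric_for_best_model set, the Trainer defaults to using the eval loss, with lower being better):

training_args = Seq2SeqTrainingArguments(
	output_dir=OUTPUT_MODEL_DIR,
	# ... all your other arguments unchanged ...
	evaluation_strategy="steps",
	eval_steps=1000,
	save_steps=1000,
	load_best_model_at_end=True,
	# metric_for_best_model="wer",  <- removed
	# greater_is_better=False,      <- removed
	push_to_hub=False,
	dataloader_num_workers=20
)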

sanchit-gandhi avatar Apr 21 '23 15:04 sanchit-gandhi

@anderleich I tried fine-tuning on LibriSpeech recently and realized that for the WER computation (in the compute_metrics function) you should normalize the text (as is done in the example on the model card):

# pred_str and label_str are lists of strings, so normalize each entry
pred_str  = [model.tokenizer._normalize(s) for s in pred_str]
label_str = [model.tokenizer._normalize(s) for s in label_str]

Hope it helps!

iamgroot42 avatar Jun 14 '23 18:06 iamgroot42

There's another example here for normalising that you can use: https://github.com/huggingface/community-events/blob/a2d9115007c7e44b4389e005ea5c6163ae5b0470/whisper-fine-tuning-event/run_speech_recognition_seq2seq_streaming.py#L514-L532

As well as some other valuable resources for fine-tuning: https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event#tips-and-tricks
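
If you also want the WER computed on normalised text, here's a minimal sketch along the lines of that script (BasicTextNormalizer comes from transformers; processors and metric are the objects from the compute_metrics above):

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

def compute_metrics(pred):
	pred_ids = pred.predictions
	label_ids = pred.label_ids

	# replace -100 with the pad_token_id so the labels can be decoded
	label_ids[label_ids == -100] = processors.tokenizer.pad_token_id

	pred_str = processors.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
	label_str = processors.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

	# orthographic WER on the raw transcriptions
	wer = 100 * metric.compute(predictions=pred_str, references=label_str)

	# WER after stripping casing and punctuation, as in the linked script
	norm_pred_str = [normalizer(s) for s in pred_str]
	norm_label_str = [normalizer(s) for s in label_str]

	# drop samples whose reference becomes empty after normalisation
	norm_pred_str = [p for p, l in zip(norm_pred_str, norm_label_str) if len(l) > 0]
	norm_label_str = [l for l in norm_label_str if len(l) > 0]

	norm_wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)

	return {"wer": wer, "normalized_wer": norm_wer}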

sanchit-gandhi avatar Apr 02 '24 14:04 sanchit-gandhi