
`fast_beam_search()` returns a non-differentiable lattice + MWER training

Open · desh2608 opened this issue on Jul 06 '23 · 10 comments

In the fast_beam_search() method, the lattice is eventually generated at https://github.com/k2-fsa/icefall/blob/ffe816e2a8314318a4ef6d5eaba34b62b842ba3f/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py#L546. We do not explicitly pass logprobs here, since they are already stored on the arcs of the FSA during the decoding process.

If we look at the definition of this method in k2 here, we can see that the arc scores are made differentiable only if logprobs is explicitly passed to this function.

As a result, the generated lattice has non-differentiable arc scores.
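
In plain PyTorch terms, the difference is just whether the arc scores keep an autograd history back to the joiner output. A minimal, k2-free illustration (the tensors and the arc-to-log-prob mapping below are made up for the example):

```python
import torch

# Illustrative only: "logprobs" stands in for the joiner output, and the index
# tensor stands in for the mapping from lattice arcs back to their log-probs.
logprobs = torch.randn(10, requires_grad=True).log_softmax(dim=-1)

detached_scores = logprobs.detach()   # what the lattice ends up with today
print(detached_scores.requires_grad)  # False: no gradient path back to the model

arc_to_logprob = torch.tensor([0, 3, 7, 9])  # hypothetical arc -> log-prob indices
attached_scores = logprobs[arc_to_logprob]   # indexing keeps the autograd graph
print(attached_scores.requires_grad)         # True: a loss on these scores backpropagates
```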

desh2608 · Jul 06 '23

The differentiable lattice is for MWER loss (https://github.com/k2-fsa/k2/blob/42e92fdd4097adcfe9937b4d2df7736d227b8e85/k2/python/k2/mwer_loss.py). For decoding, I think there is no need to return a differentiable lattice.

See also this PR https://github.com/k2-fsa/k2/pull/1094

pkufool · Jul 07 '23

Thanks, I'm indeed trying to use it for MWER loss.

desh2608 · Jul 07 '23

@pkufool I am not 100% following here, but is it possible to add an option to fast_beam_search() to make it differentiable?

danpovey · Jul 07 '23

@danpovey It makes the lattice scores differentiable if you explicitly pass in the logprobs, which can be done similarly to this code from @glynpu.
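
Roughly, the pattern is the following (a sketch only, not @glynpu's exact code; the callable producing per-frame log-probs and the logprobs keyword of format_output() are assumptions for illustration and should be checked against the k2 PR):

```python
from typing import Callable, List

import k2
import torch


def decode_with_differentiable_lattice(
    decoding_streams: k2.RnntDecodingStreams,
    frame_log_probs: Callable[[int], torch.Tensor],  # returns the joiner log-probs for frame t
    num_frames: List[int],
) -> k2.Fsa:
    """Keep a non-detached copy of every per-frame log-prob tensor fed to the
    decoder, then hand the concatenation over when forming the lattice so the
    arc scores stay in the autograd graph."""
    kept: List[torch.Tensor] = []
    for t in range(max(num_frames)):
        log_probs = frame_log_probs(t)                # (num_contexts, vocab), requires grad
        kept.append(log_probs.reshape(-1))            # keep the autograd history
        decoding_streams.advance(log_probs.detach())  # the search itself needs no gradient
    decoding_streams.terminate_and_flush_to_streams()
    return decoding_streams.format_output(
        num_frames,
        logprobs=torch.cat(kept),  # assumed keyword; this is what makes the scores differentiable
    )
```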

desh2608 · Jul 07 '23

OK, but I am wondering what is the downside of making fast_beam_search() just do this automatically if the input lattice has scores.requires_grad = True.

danpovey · Jul 07 '23

> OK, but I am wondering what is the downside of making fast_beam_search() just do this automatically if the input lattice has scores.requires_grad = True.

Will see if we can simplify the usage.

pkufool · Jul 08 '23

I am trying to fine-tune an RNN-T model with MWER loss. I created an LG decoding graph, and obtained lattices using fast_beam_search():

lattice = fast_beam_search(
    model=self,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=x_lens,
    beam=4,
    max_states=64,
    max_contexts=8,
    temperature=1.0,
    ilme_scale=0.0,
    allow_partial=True,
    blank_penalty=0.1,
    requires_grad=True,
)

(Note that I added a requires_grad argument to fast_beam_search, which just ensures that all arc scores are tracked by autograd.)
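
A quick sanity check for this, assuming the flag works as intended, is to confirm that the returned lattice's scores are still attached to the autograd graph:

```python
# The arc scores must carry gradients for the MWER loss below to train anything.
assert lattice.scores.requires_grad, "lattice arc scores are detached"
```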

I obtained ref_texts from the lexicon's symbol table as:

oov_id = word_table["<unk>"]
y = []
for text in texts:
    word_ids = []
    for word in text.split():
        if word in word_table:
            word_ids.append(word_table[word])
        else:
            word_ids.append(oov_id)
    y.append(word_ids)

I then call k2.mwer_loss with the lattice and ref_texts:

with torch.cuda.amp.autocast(enabled=False):
    mbr_loss = k2.mwer_loss(
        lattice=lattice,
        ref_texts=y,
        nbest_scale=0.5,
        num_paths=200,
        temperature=1.0,
        reduction="sum",
        use_double_scores=True,
    )

However, the MWER loss seems to be increasing during training:

[screenshot: MWER training loss curve]

Does the above strategy look okay or am I missing something?

desh2608 · Jul 10 '23

A loss that increases and then stays roughly constant is what you might expect if parameter noise were a problem (due to a too-high learning rate). Perhaps the learning rate is higher than the final learning rate of the base system? Also, the MWER loss is much noisier than the regular loss, so it might require a lower learning rate.

danpovey · Jul 10 '23

I realized that I was "continuing training" from the last checkpoint instead of initializing a new model's parameters from the pre-trained checkpoint, which meant the optimizer/scheduler states were carrying over. I fixed this and reduced the LR to 0.0004. I also set nbest_scale=0.1, temperature=0.5 in the mwer_loss to get more unique paths from the lattice (although I'm not sure whether that helps).
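
For reference, the fix amounts to loading only the model weights and leaving the optimizer/scheduler states alone, roughly like this (a sketch; the "model" key follows icefall's usual checkpoint layout but should be double-checked):

```python
import torch
from torch import nn


def init_from_pretrained(model: nn.Module, ckpt_path: str) -> None:
    # Load only the model weights; starting the optimizer/scheduler fresh is what
    # distinguishes "initialize from a pre-trained checkpoint" from "continue training".
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    # Deliberately no optimizer.load_state_dict(...) / scheduler.load_state_dict(...) here.
```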

Now the training loss is going down, at least so far (I'm only a few hundred training steps in). Training is only about 3x slower than regular RNN-T, which is not bad, I suppose.

[screenshot: MWER training loss curve]

desh2608 · Jul 10 '23

I have been working on MWER training for transducers. The training loss improves, but the validation loss gets worse, and I find that the resulting WER is also much worse than that of the model I initialized with. Here are the train/val curves: https://tensorboard.dev/experiment/L8QIjliVSQm6kRCv1Cghmw/#scalars

Here are the decoding results on TED-LIUM dev using greedy search. The first model was trained with pruned_rnnt loss, and then I used this model to initialize the MWER training (second row).

| training | ins  | del   | sub  | WER   |
|----------|------|-------|------|-------|
| rnnt     | 1.17 | 2.11  | 4.43 | 7.71  |
| mwer     | 0.69 | 10.42 | 7.21 | 18.33 |

I used a base LR of 0.0004 for MWER training. For lattice generation I used beam=4, max_states=64, max_contexts=8, blank_penalty=0.1, and for the MWER loss computation I used num_paths=200, nbest_scale=0.1, temperature=0.5. Which of these values should I change to avoid over-fitting?

(BTW, it seems most sentences have some del/sub errors at the beginning.)

desh2608 · Jul 19 '23