RNN-T Decoding - Large Number of Deletions Compared to Transformer/Conformer
I am getting a lot of deletions in my RNN-T training/decoding setup relative to Transformer/Conformer. The data is the "Malach" corpus: about 200 hours of accented English speech from Holocaust survivors. I'd appreciate any insights/suggestions anyone may have!
These are the sclite outputs from all three systems:
| Condition | # Snt | # Wrd | Corr | Sub | Del | Ins | Err | S.Err |
|---|---|---|---|---|---|---|---|---|
| Transformer | 1155 | 12256 | 82.0 | 13.6 | 4.3 | 4.8 | 22.8 | 68.1 |
| RNN-T | 1155 | 12256 | 70.4 | 13.5 | 16.1 | 2.7 | 32.3 | 73.9 |
| Conformer | 1155 | 12256 | 81.8 | 12.9 | 5.3 | 4.1 | 22.3 | 68.4 |
I also attached the training and decoding configs for the RNN-T.
This is the training config:
```yaml
# The conformer transducer training configuration from @jeon30c
# WERs for test-clean/test-other are 2.9 and 7.2, respectively.
# Trained with Tesla V100-SXM2(32GB) x 8 GPUs. It takes about 1.5 days.
batch_type: numel
batch_bins: 20000000
accum_grad: 2
max_epoch: 100
patience: none
init: none
best_model_criterion:
-   - valid
    - loss
    - min
keep_nbest_models: 10

model_conf:
    ctc_weight: 0.0
    report_cer: False
    report_wer: False

encoder: conformer
encoder_conf:
    output_size: 512
    attention_heads: 8
    linear_units: 2048
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    normalize_before: true
    macaron_style: true
    pos_enc_layer_type: "rel_pos"
    selfattention_layer_type: "rel_selfattn"
    activation_type: "swish"
    use_cnn_module: true
    cnn_module_kernel: 31

decoder: transducer
decoder_conf:
    rnn_type: lstm
    num_layers: 1
    hidden_size: 512
    dropout: 0.1
    dropout_embed: 0.1

joint_net_conf:
    joint_space_size: 640

optim: adam
optim_conf:
    lr: 0.0015
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 25000

frontend_conf:
    n_fft: 512
    hop_length: 160

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
```
This is the decoding config:
```yaml
# The conformer transducer decoding configuration from @jeon30c
beam_size: 10
transducer_conf:
    search_type: default
    score_norm: True
```
Interesting one! In general, the numbers of insertions and deletions are somewhat balanced, but several things can cause a high number of deletions.
Some initial questions:
- Did you use the pre-trained Librispeech model to finetune the Malach model (or directly perform decoding to get these results)?
- What are `Transformer` and `Conformer`? Are they equivalent to the Transformer and Conformer CTC-Att for Librispeech?
- Are you using an external language model? If yes, which one, and what about decoding without it?
- Did you try other decoding methods? I don't expect them to "fix" the issue, but they could give some hints.

Also, if possible, could you please share the `result.txt` for the Transducer model?
- No, I trained all of the models from scratch. I did not use any pre-trained models.
- The Transformer was from the AMI egs2 recipe; the Conformer I copied from Librispeech.
- Yes, in all cases, I believe. They were all from train_lm_transformer2.yaml, which was in the AMI recipe. I could redecode without it, but do you really think that would help?
- What other decoding methods do you mean? Like a kaldi model?

result.txt
Thanks, I'll take a closer look when I have time. From what I've seen, about 20% of the utterances account for most of the deletions. I'll also check the Malach dataset as I'm not familiar with it.
> Yes, in all cases, I believe. They were all from train_lm_transformer2.yaml, which was in the AMI recipe. I could redecode without it, but do you really think that would help?
I think it could cause a higher number of deletions in some cases, yes. Although it's unlikely to have such an impact.
> What other decoding methods do you mean? Like a kaldi model?
We have various decoding strategies for the Transducer model besides the default beam search; see https://espnet.github.io/espnet/tutorial.html#inference
Thanks. Here is a good reference on MALACH: https://www.isca-speech.org/archive_v0/Interspeech_2019/pdfs/1907.pdf
I tried decoding with a beam size of 60. It increased the deletion rate to 30%. Happy to try other methods, but I need some parameter recommendations. Also, is there some way with the default beam search to change the language model weight? Maybe I should lower that.
Thanks, Michael
> Happy to try other methods, but I need some parameter recommendations.
Sure, you can start with these (under `transducer_conf` in your decode config):
```yaml
transducer_conf:
    search_type: alsd   # (or maes)
    u_max: 250
    nstep: 3
    prefix_alpha: 2
    expansion_beta: 2
    expansion_gamma: 2.3
```
And try increasing/decreasing `u_max` (for ALSD) and `nstep` (for mAES). They control the expansion along the label axis and the time axis, respectively.
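For instance, a decode config using only mAES might look like the following. This is just a sketch: the parameter names follow the BeamSearchTransducer options above, and the values are generic starting points, not tuned for Malach.
```yaml
# Sketch: mAES decoding config (illustrative values, not tuned)
beam_size: 10
transducer_conf:
    search_type: maes
    nstep: 2            # label expansions per time step
    prefix_alpha: 2     # maximum prefix length difference for prefix search
    expansion_beta: 2   # number of additional candidates per expansion
    expansion_gamma: 2.3  # allowed log-prob difference for pruning
```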
> Also, is there some way with the default beam search to change the language model weight?
You can set `lm_weight: x.x` in your decode config. I'm not sure what the default value is in this version.
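For example, a decode config that lowers the external LM's influence might look like this (a sketch; the `0.1` weight is the value tried later in this thread):
```yaml
# Sketch: default beam search with a reduced external LM weight
beam_size: 10
lm_weight: 0.1
transducer_conf:
    search_type: default
    score_norm: True
```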
Sorry for the long delay in responding. It looks like Transformer LMs are not supported for RNN-Ts:
```
Traceback (most recent call last):
  File "/ext3/miniconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ext3/miniconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/vast/map22/espnet/espnet2/bin/asr_inference.py", line 707, in <module>
    main()
  File "/vast/map22/espnet/espnet2/bin/asr_inference.py", line 703, in main
    inference(**kwargs)
  File "/vast/map22/espnet/espnet2/bin/asr_inference.py", line 456, in inference
    speech2text = Speech2Text.from_pretrained(
  File "/vast/map22/espnet/espnet2/bin/asr_inference.py", line 368, in from_pretrained
    return Speech2Text(**kwargs)
  File "/vast/map22/espnet/espnet2/bin/asr_inference.py", line 163, in __init__
    beam_search_transducer = BeamSearchTransducer(
  File "/vast/map22/espnet/espnet2/asr/transducer/beam_search_transducer.py", line 109, in __init__
    raise NotImplementedError
NotImplementedError
```
Code:
```python
elif search_type == "alsd":
    if isinstance(lm, TransformerLM):
        raise NotImplementedError
```
Hi,
Yes, we removed support for Transformer LM because, in its current form (or at least when we made this choice), it was not worth supporting in terms of error rate versus RTF/latency compared to the RNN-LM. I don't recall the technical reasons though; let me check.
I don't mind re-enabling Transformer LM, but IMO we should rethink ASR-LM fusion for Transducer and/or how we include external linguistic information. I have some work in progress on that subject, but it's without the use of an external LM.
Thanks. Can you remind me which config is used for just the RNN LM? Is this train_lm.yaml?
> Thanks. Can you remind me which config is used for just the RNN LM? Is this train_lm.yaml?
I'm not sure which recipe config file you're referring to, but this one is the LM config for Librispeech. You can use it as a reference, I guess.
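For reference, a minimal RNN LM training config might look like the following. This is a sketch only: the field names follow espnet2's `seq_rnn` LM defaults, and the values are illustrative rather than the recipe's actual train_lm.yaml.
```yaml
# Sketch: minimal RNN LM training config (illustrative values;
# check against the recipe's own train_lm.yaml)
lm: seq_rnn
lm_conf:
    rnn_type: lstm
    unit: 650
    nlayers: 2
optim: adam
optim_conf:
    lr: 0.001
batch_type: folded
batch_size: 64
max_epoch: 20
```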
LM training runs for a few iterations and then bombs out with the message below. Any idea what this means?
```
[ga014] 2023-01-03 23:17:19,407 (trainer:704) INFO: 1epoch:train:155-176batch: iter_time=0.027, forward_time=0.039, loss=3.487, backward_time=0.041, optim_step_time=0.003, optim0_lr0=0.001, train_time=0.119
Traceback (most recent call last):
  File "/ext3/miniconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ext3/miniconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/vast/map22/espnet/espnet2/bin/lm_train.py", line 22, in <module>
    main()
  File "/vast/map22/espnet/espnet2/bin/lm_train.py", line 18, in main
    LMTask.main(cmd=cmd)
  File "/vast/map22/espnet/espnet2/tasks/abs_task.py", line 1013, in main
    cls.main_worker(args)
  File "/vast/map22/espnet/espnet2/tasks/abs_task.py", line 1309, in main_worker
    cls.trainer.run(
  File "/vast/map22/espnet/espnet2/train/trainer.py", line 281, in run
    all_steps_are_invalid = cls.train_one_epoch(
  File "/vast/map22/espnet/espnet2/train/trainer.py", line 615, in train_one_epoch
    loss.backward()
  File "/ext3/miniconda3/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/ext3/miniconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: unique_by_key: failed to synchronize: cudaErrorIllegalAddress: an illegal memory access was encountered
# Accounting: time=58 threads=1
# Ended (code 1) at Tue Jan 3 23:17:24 EST 2023, elapsed time 58 seconds
```
Hmm, I'm not sure about this one. Could you try with another (downgraded) CUDA version, please?
cc'ing @kamo-naoyuki as he may know what the issue is here.
I had a 4-GPU machine but only specified `ngpu=1` (I thought it detected this automatically). When I changed this to `ngpu=4` it seemed to work, and it is now training the transducer.
I notice that CER and WER are not reported during training. Not sure why. When I use the same config to train a transformer or conformer it all works fine (i.e., it produces the two CER/WER plots).
> I had a 4-GPU machine but only specified `ngpu=1` (I thought it detected this automatically). When I changed this to `ngpu=4` it seemed to work, and it is now training the transducer.
Good to know, thanks for the feedback!
> I notice that CER and WER are not reported during training. Not sure why. When I use the same config to train a transformer or conformer it all works fine (i.e., it produces the two CER/WER plots).
`report_cer` and `report_wer` are set to `False` by default. You need to enable them in your training config:
```yaml
model_conf:
    ...
    report_cer: True
    report_wer: True
```
Btw, just to be sure I don't give wrong information: you're using the Transducer version under `asr`, right?
Yes, thanks!
I trained the RNN LM and used the "alsd" method for decoding. The results were worse, with a lot of repeated words (no no no, yes yes yes, etc.). The training results look "reasonable". Any ideas or suggestions? Happy to send you any logs, etc., if that would be helpful. I'm just trying to write a paper and include some RNN-T results to compare to the conformer and transformer.
When I switched to an LM weight of 0.1 and switched back to the original search, it brought the WER down to 24.5 with the RNN LM, which is much better.
> When I switched to an LM weight of 0.1 and switched back to the original search, it brought the WER down to 24.5 with the RNN LM, which is much better.
Is it a token-level LM or a word-level LM?
Also, what's the WER with ALSD? It should yield performance similar to the default beam search, even with an LM. I haven't used this Transducer version in a while; it may be due to a bug.
Word-level. The ALSD results with the original LM weight were so bad (45% WER) that I did not try them with the new LM weight.