Implement CTC loss function
I need a CTC loss function for a CRNN model. I tried to implement it based on the PyTorch implementation, but the results returned by `forward()` are somewhat different from PyTorch's.
I don't know what went wrong; I'd appreciate it if someone could tell me.
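For context, here is a minimal plain-Rust sketch of the forward (alpha) recursion in log space, as described in the Graves et al. paper referenced below. It is only an illustration of the algorithm (class 0 assumed to be the blank, all names made up), not the tensorized code in this PR:

```rust
// ln(e^a + e^b) in log space, guarding the case where both inputs are
// -inf (otherwise `a - m` would be (-inf) - (-inf) = NaN).
fn log_add(a: f64, b: f64) -> f64 {
    if a == f64::NEG_INFINITY && b == f64::NEG_INFINITY {
        return f64::NEG_INFINITY;
    }
    let m = a.max(b);
    m + ((a - m).exp() + (b - m).exp()).ln()
}

/// CTC forward (alpha) recursion for a single sequence.
/// `log_probs[t][k]` is the log-probability of class `k` at frame `t`;
/// class 0 is the blank. Returns the negative log-likelihood of `target`.
fn ctc_neg_log_likelihood(log_probs: &[Vec<f64>], target: &[usize]) -> f64 {
    let t_len = log_probs.len();
    // Extended target: a blank interleaved around every label, length 2L+1.
    let s_len = 2 * target.len() + 1;
    let ext = |s: usize| if s % 2 == 0 { 0 } else { target[s / 2] };

    let mut alpha = vec![vec![f64::NEG_INFINITY; s_len]; t_len];
    alpha[0][0] = log_probs[0][ext(0)];
    if s_len > 1 {
        alpha[0][1] = log_probs[0][ext(1)];
    }
    for t in 1..t_len {
        for s in 0..s_len {
            let mut a = alpha[t - 1][s];
            if s >= 1 {
                a = log_add(a, alpha[t - 1][s - 1]);
            }
            // Skipping the in-between blank is allowed only when the
            // current label differs from the label two positions back.
            if s >= 2 && ext(s) != 0 && ext(s) != ext(s - 2) {
                a = log_add(a, alpha[t - 1][s - 2]);
            }
            alpha[t][s] = a + log_probs[t][ext(s)];
        }
    }
    // Valid paths end on the final label or on the trailing blank.
    let mut total = alpha[t_len - 1][s_len - 1];
    if s_len > 1 {
        total = log_add(total, alpha[t_len - 1][s_len - 2]);
    }
    -total
}
```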
References
- PyTorch implementation
- Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks
Checklist
- [x] Confirmed that
run-checks all
script has been executed.
Hi, I can take a look at it later today
I believe the result is now the same as PyTorch's, but the performance of this implementation seems to be less than ideal.
The implementation doesn't work on the NdArray backend because of #1053. It also doesn't work on the LibTorch backend because of #1055.
I believe the current performance bottleneck lies in creating the one-hot encoding, because the `repeat()` method is very slow on the Wgpu backend.
https://github.com/tracel-ai/burn/blob/712185292c8fb79a6fbc73d55ed719bbf7859a92/burn-core/src/nn/loss/ctc.rs#L330-L341
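If the one-hot is built the usual tensor way — broadcasting each target index across the class axis and comparing it against 0..num_classes — then `repeat()` sits directly on the hot path. A plain-Rust analogy of what the result amounts to (illustrative names, not the linked Burn code):

```rust
/// Plain-Rust analogy of a tensor-style one-hot (names are illustrative).
/// The tensor formulation materializes each index `num_classes` copies
/// wide via repeat() before comparing, which is why repeat() dominates.
fn one_hot(targets: &[usize], num_classes: usize) -> Vec<Vec<f32>> {
    targets
        .iter()
        .map(|&t| (0..num_classes).map(|c| if c == t { 1.0 } else { 0.0 }).collect())
        .collect()
}
```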
Hi @wcshds, I haven't had the time I thought I would last week, and then I was abroad for several days. I'm sorry I said I was gonna look at it last week, but I certainly haven't forgotten you! Glad to see you've continued working on it since then. I will definitely take a look real soon.
@louisfd Thank you! The current implementation still consumes a significant amount of GPU memory. I believe that calculating the alpha values for blanks and letters separately could reduce memory usage considerably, but I don't know how to implement it.
@wcshds I took your word that `repeat` was the bottleneck on the Wgpu backend. This made a lot of sense, because we relied on the default implementation, which launches as many `slice_assign` kernels as there are repetitions. For a large `times` argument this is awful.
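For intuition, here is the difference as a CPU analogy in plain Rust (function names are made up):

```rust
/// Default-style repeat: one copy per repetition. On a GPU backend each
/// iteration corresponds to one slice_assign kernel launch.
fn repeat_naive(input: &[f32], times: usize) -> Vec<f32> {
    let mut out = vec![0.0; input.len() * times];
    for r in 0..times {
        out[r * input.len()..(r + 1) * input.len()].copy_from_slice(input);
    }
    out
}

/// Fused repeat: every output element computes its source index directly,
/// which is what a single kernel launch can do.
fn repeat_fused(input: &[f32], times: usize) -> Vec<f32> {
    (0..input.len() * times).map(|i| input[i % input.len()]).collect()
}
```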
I wrote a repeat kernel so that only one kernel is launched instead of `times`: #1068
Please tell me if this is better now.
@louisfd Thank you! `repeat()` is much faster now.
I tried to use this implementation of CTC loss in the CRNN model, but after the first iteration the loss became NaN. I don't know what went wrong: wcshds/crnn-cjk
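A classic source of NaN in log-space CTC — offered as a possible lead, not a diagnosis of this PR — is a naive log-sum-exp over two unreachable states:

```rust
// Naive log-sum-exp: m + ln(exp(a - m) + exp(b - m)) with m = max(a, b).
// When both states are unreachable (both -inf), `a - m` is
// (-inf) - (-inf) = NaN, and the NaN then propagates through every later
// frame and into the final loss.
fn naive_log_add(a: f64, b: f64) -> f64 {
    let m = a.max(b);
    m + ((a - m).exp() + (b - m).exp()).ln()
}

fn main() {
    let unreachable = f64::NEG_INFINITY;
    println!("{}", naive_log_add(unreachable, unreachable)); // prints NaN
}
```

Returning `-inf` directly when both inputs are `-inf` avoids this.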
Just noticed the 1e-15 magic number. Please refactor it into a constant and explain how the number is derived. It would also be preferable if it were independent of float precision (we use both half and full precision).
@antimora I just need a small value to prevent log(0), so I now think 1e-15 may not be necessary; 1e-5 should be small enough. However, CTC loss may not be suitable for half precision: I previously attempted mixed-precision training in PyTorch, but PyTorch's CTC loss does not support fp16. [CTC Loss] CTC Loss not support float16? Perhaps I need to explore half-precision training in future experiments to see whether CTC loss can work with it.
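A sketch of what a precision-aware constant could look like (hypothetical helper, shown for f32; a half-precision backend would substitute that type's own epsilon):

```rust
/// Clamp a probability away from zero before taking its log, so that
/// log(0) never yields -inf. Hypothetical helper, not the PR's code.
fn safe_ln(p: f32) -> f32 {
    // Deriving the floor from the type's machine epsilon ties the guard
    // to the precision in use, instead of a hard-coded 1e-15 or 1e-5.
    const PROB_FLOOR: f32 = f32::EPSILON; // ≈ 1.19e-7 for f32
    p.max(PROB_FLOOR).ln()
}

fn main() {
    assert!(safe_ln(0.0).is_finite()); // ln(EPSILON) ≈ -15.9, not -inf
}
```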
Closing this ticket and linking it to an issue so someone else can pick it up: https://github.com/tracel-ai/burn/issues/1536