
Implement CTC loss function

wcshds opened this issue Dec 04 '23 · 10 comments

I need a CTC loss function for a CRNN model. I tried to implement one based on the PyTorch implementation, but the results obtained after calling forward() are somewhat different from PyTorch's.

I don't know what went wrong; I'd appreciate it if someone could tell me. (For reference, the recursion being computed is sketched after the checklist below.)

Reference

Checklist

  • [x] Confirmed that the run-checks all script has been executed.
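
For reference, the quantity a CTC loss computes is the forward (alpha) recursion from Graves et al. over the blank-augmented target $l' = (\varnothing, l_1, \varnothing, l_2, \ldots, l_S, \varnothing)$; PyTorch evaluates it in log space with log-sum-exp. In probability space, with $y_t(k)$ the network's probability of symbol $k$ at time $t$:

$$
\alpha_t(s) = y_t(l'_s)\Big(\alpha_{t-1}(s) + \alpha_{t-1}(s-1) + \big[\,l'_s \neq \varnothing \wedge l'_s \neq l'_{s-2}\,\big]\,\alpha_{t-1}(s-2)\Big),
\qquad
\mathcal{L} = -\log\big(\alpha_T(2S{+}1) + \alpha_T(2S)\big)
$$

Small differences in how each step's log-sum-exp is stabilized can plausibly account for results that are close to PyTorch's but not identical.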

wcshds avatar Dec 04 '23 16:12 wcshds

Hi, I can take a look at it later today

louisfd avatar Dec 04 '23 17:12 louisfd

I believe the result is now the same as PyTorch's, but the performance of this implementation seems less than ideal.

wcshds avatar Dec 05 '23 05:12 wcshds

The implementation doesn't work on the NdArray backend because of #1053. It also doesn't work on the LibTorch backend because of #1055.

I believe the current performance bottleneck lies in creating the one-hot encoding, because the repeat() method is very slow on the Wgpu backend. https://github.com/tracel-ai/burn/blob/712185292c8fb79a6fbc73d55ed719bbf7859a92/burn-core/src/nn/loss/ctc.rs#L330-L341
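
For context, the construction is essentially the following (a plain-Rust sketch with illustrative names, not the exact tensor code from ctc.rs). A tensor backend does the same thing with repeat() along the class axis followed by an equality comparison, so a slow repeat() dominates the cost:

```rust
// Hypothetical sketch: build a one-hot mask of shape
// [batch, target_len, num_classes] by comparing each target index
// against the class ids 0..num_classes. On a tensor backend the inner
// loop becomes `repeat` + `equal`, one repetition per class-axis copy.
fn one_hot(targets: &[Vec<usize>], num_classes: usize) -> Vec<Vec<Vec<f32>>> {
    targets
        .iter()
        .map(|seq| {
            seq.iter()
                .map(|&t| {
                    (0..num_classes)
                        .map(|c| if c == t { 1.0 } else { 0.0 })
                        .collect()
                })
                .collect()
        })
        .collect()
}
```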

wcshds avatar Dec 08 '23 10:12 wcshds

Hi @wcshds, I haven't had the time I thought I would last week, and then I was abroad for several days. I'm sorry I said I was going to look at it last week, but I certainly haven't forgotten you! Glad to see you've continued working on it since then. I will definitely take a look real soon.

louisfd avatar Dec 12 '23 18:12 louisfd

@louisfd Thank you! The current implementation still consumes a significant amount of graphics memory. I believe that calculating the alpha values for blanks and letters separately could reduce the memory usage considerably, but I don't know how to implement it.
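
If it helps, the saving can be seen from the structure of the blank-augmented sequence: every odd position (1-indexed) is the blank, so the recursion never needs a (batch, T, 2S+1, vocab) one-hot product; per timestep it only needs the S gathered letter probabilities plus the blank column, which is roughly how PyTorch's native kernel works. In the notation from above:

$$
\log y_t(l'_s) =
\begin{cases}
\log y_t(\varnothing) & s \text{ odd (blank position)} \\
\log y_t\big(l_{s/2}\big) & s \text{ even (letter position)}
\end{cases}
$$

so the inputs to the recursion shrink to a (batch, T, S) gather plus a (batch, T) blank column.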

wcshds avatar Dec 12 '23 19:12 wcshds

@wcshds I took your word that repeat was the bottleneck on wgpu. This made a lot of sense because we relied on the default implementation, which launches as many slice_assign kernels as there are repetitions. For a large times argument this is awful. I wrote a repeat kernel so that only one kernel is launched instead of times: #1068
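
For anyone following along, the old default boils down to this (a simplified plain-Rust sketch of the logic, not the actual trait code):

```rust
// Simplified sketch of the old default `repeat`: the output is filled
// one repetition at a time. On wgpu each loop iteration was a separate
// slice_assign kernel launch, so cost grew linearly with `times`;
// #1068 replaces the loop with a single kernel that computes every
// output element directly.
fn repeat_dim0(input: &[f32], times: usize) -> Vec<f32> {
    let mut output = vec![0.0; input.len() * times];
    for i in 0..times {
        // equivalent of: output.slice_assign([i*len..(i+1)*len], input)
        output[i * input.len()..(i + 1) * input.len()].copy_from_slice(input);
    }
    output
}
```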

Please tell me if this is better now.

louisfd avatar Dec 14 '23 15:12 louisfd

@louisfd Thank you! Now repeat() is much faster.

wcshds avatar Dec 14 '23 18:12 wcshds

I tried to use this implementation of CTC loss in the CRNN model, but after the first iteration the loss became NaN. I don't know what went wrong. wcshds/crnn-cjk
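
For what it's worth, one classic source of NaN in log-space CTC is a log-sum-exp that isn't guarded against all-(-inf) inputs: subtracting the max then produces -inf - (-inf) = NaN, which poisons every later timestep. A generic sketch of the guard (not claiming this is where the bug is):

```rust
// Numerically stable log(sum(exp(x))) that tolerates f32::NEG_INFINITY
// inputs. If every term is log(0) = -inf, the sum is log(0), so we
// return -inf explicitly instead of computing -inf - (-inf) = NaN.
fn log_sum_exp(xs: &[f32]) -> f32 {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    if max == f32::NEG_INFINITY {
        return f32::NEG_INFINITY;
    }
    max + xs.iter().map(|x| (x - max).exp()).sum::<f32>().ln()
}
```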

wcshds avatar Dec 14 '23 18:12 wcshds

Just noticed the 1e-15 magic number. Please refactor it into a named constant and explain how this number is derived. It would also be preferable if it were independent of float precision (we use both half and full precision).

antimora avatar Dec 27 '23 13:12 antimora

@antimora I just need a small value to prevent log(0), so 1e-15 may not actually be necessary; 1e-5 should be small enough. However, I think CTC loss may not be suitable for half precision: I previously attempted mixed-precision training in PyTorch, and PyTorch's CTC loss does not support fp16. [CTC Loss] CTC Loss not support float16? Perhaps I need to explore half-precision training in future experiments to see whether CTC loss can work with it.
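
A hypothetical sketch of what that refactor could look like, deriving the floor from the element type's machine epsilon instead of hard-coding it (names are illustrative, not from ctc.rs):

```rust
// Hypothetical sketch: clamp probabilities to a precision-dependent
// floor before taking the log, instead of a hard-coded 1e-15. Note
// that f16 cannot even represent 1e-15: its smallest positive normal
// value is ~6.1e-5 and its machine epsilon is ~9.8e-4.
fn safe_ln(p: f64, eps: f64) -> f64 {
    p.max(eps).ln()
}

fn main() {
    let eps_f32 = f32::EPSILON as f64; // ~1.2e-7
    let eps_f64 = f64::EPSILON;        // ~2.2e-16
    // for f16 one would use half::f16::EPSILON from the `half` crate
    println!("{}", safe_ln(0.0, eps_f32)); // ≈ -15.9
    println!("{}", safe_ln(0.0, eps_f64)); // ≈ -36.0
}
```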

wcshds avatar Dec 27 '23 14:12 wcshds

Closing this ticket and linking it to a new issue ticket so someone else can pick it up: https://github.com/tracel-ai/burn/issues/1536

antimora avatar Mar 26 '24 22:03 antimora