whisper.cpp

Reduce memory usage during Whisper inference

Status: Open. Opened by ggerganov.

The idea is to avoid keeping every intermediate tensor of the computation graph in memory by introducing "scratch" buffers to ggml: https://github.com/ggerganov/whisper.cpp/issues/272#issuecomment-1353739709
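A scratch buffer can be thought of as a preallocated region that intermediate tensors are bump-allocated from, and that is reset (not freed) once its contents are no longer needed. Peak memory is then bounded by the buffer size instead of by the sum of all intermediates. A minimal C sketch, assuming a plain byte arena; `scratch_t`, `scratch_init`, `scratch_alloc`, `scratch_reset` and `scratch_free` are hypothetical names, not the ggml API:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

/* Hypothetical scratch buffer: a fixed region with a bump pointer. */
typedef struct {
    uint8_t *data;
    size_t   size; /* capacity in bytes */
    size_t   offs; /* next free byte */
    size_t   peak; /* high-water mark, to measure actual usage */
} scratch_t;

/* Returns nonzero on success. */
int scratch_init(scratch_t *s, size_t size) {
    s->data = malloc(size);
    s->size = size;
    s->offs = 0;
    s->peak = 0;
    return s->data != NULL;
}

void scratch_free(scratch_t *s) {
    free(s->data);
    s->data = NULL;
    s->size = s->offs = s->peak = 0;
}

/* Bump-allocate n bytes at an offset aligned to 'align' (a power of two),
 * relative to the malloc'd base. Returns NULL if the buffer is too small. */
void *scratch_alloc(scratch_t *s, size_t n, size_t align) {
    assert(align != 0 && (align & (align - 1)) == 0);
    size_t start = (s->offs + align - 1) & ~(align - 1);
    if (start > s->size || n > s->size - start) {
        return NULL;
    }
    s->offs = start + n;
    if (s->offs > s->peak) {
        s->peak = s->offs;
    }
    return s->data + start;
}

/* Reuse the whole region; every pointer handed out before is now invalid. */
void scratch_reset(scratch_t *s) {
    s->offs = 0;
}
```

Reusing one region like this is what bounds the memory: the same bytes hold different intermediates at different times. That is only safe once nothing will read the old contents again, and deciding when that happens is the hard part.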

I initially thought it would be enough to keep only the last 2 intermediate tensors at each point. However, this is not the case, since there are operations like this:

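// broadcast the conv bias to the shape of cur, then add it to cur:
// cur is referenced twice, by ggml_repeat and by ggml_add
cur = ggml_add(ctx0,
        ggml_repeat(ctx0,
            model.e_conv_2_b,
            cur),
        cur);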

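The tensor cur is an input to 2 new intermediate tensors (the result of ggml_repeat and the result of ggml_add), so it must stay alive until both have been computed. When ggml_add runs, cur, the repeated bias and the sum all have to exist at the same time, so the "scratch" buffer must be able to hold more than 2 tensors.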
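The problem can be made concrete by recording, for each node, the last step at which it is read, and counting how many tensors are resident at once when the graph is evaluated in order. A small C sketch for the three operations of the snippet above; the node layout and names are hypothetical, not ggml's graph structures, and the bias e_conv_2_b is left out on the assumption that it lives with the model weights rather than in a scratch buffer. For the snippet it finds a peak of 3 live tensors, while a plain chain of ops peaks at 2:

```c
#include <assert.h>

#define MAX_NODES  64
#define MAX_INPUTS 2

/* Hypothetical graph node: only the dependency structure matters here. */
typedef struct {
    const char *name;
    int n_inputs;
    int inputs[MAX_INPUTS]; /* indices of earlier nodes */
} node_t;

/* Largest number of intermediate tensors that must be resident at the same
 * time when nodes are evaluated in order. While node i is computed, every
 * node j <= i that is still read at step i or later is alive. */
int peak_live(const node_t *nodes, int n) {
    int last_use[MAX_NODES];
    int peak = 0;

    assert(n <= MAX_NODES);
    for (int j = 0; j < n; j++) {
        last_use[j] = j; /* never read: alive only while it is computed */
    }
    for (int i = 0; i < n; i++) {
        for (int k = 0; k < nodes[i].n_inputs; k++) {
            /* i increases, so this ends up as the last reader */
            last_use[nodes[i].inputs[k]] = i;
        }
    }
    for (int i = 0; i < n; i++) {
        int live = 0;
        for (int j = 0; j <= i; j++) {
            if (last_use[j] >= i) {
                live++;
            }
        }
        if (live > peak) {
            peak = live;
        }
    }
    return peak;
}
```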

Initial results

Using scratch buffers during inference reduces the total memory usage of the base model from about 500 MB to just 213 MB. As a bonus, the decoder appears to be about 30% faster on an M1 Pro than on master, without any loss of precision.

The main drawback is that the scratch buffers are currently selected manually in whisper.cpp, which makes the code hard to read and very error-prone. I think the selection could be automated by analysing the nodes of the constructed compute graphs and assigning each of them to an appropriate scratch buffer, but the assignment algorithm is not trivial to implement and would require some major refactoring in ggml. For now, I think it is better to clean up the code a bit and wait to see whether a better idea comes along.
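One possible shape for such an automatic assignment borrows linear-scan register allocation (a swapped-in technique, not something ggml implements): walk the nodes in execution order, give each new tensor a free scratch slot, and return a slot to the free pool right after the last node that reads it. This C sketch assumes all slots have the same size and uses hypothetical names (`op_t`, `assign_slots`):

```c
#include <assert.h>

#define MAX_NODES  64
#define MAX_INPUTS 2

/* Hypothetical op: indices of the earlier nodes it reads. */
typedef struct {
    int n_inputs;
    int inputs[MAX_INPUTS];
} op_t;

/* Greedy, linear-scan style assignment of equally sized scratch slots.
 * Fills slot[i] for every node and returns the number of slots needed. */
int assign_slots(const op_t *ops, int n, int *slot) {
    int last_use[MAX_NODES];
    int free_slots[MAX_NODES];
    int n_free = 0;
    int n_slots = 0;

    assert(n <= MAX_NODES);
    for (int j = 0; j < n; j++) {
        last_use[j] = n; /* never read: treat as a graph output and keep it */
    }
    for (int i = 0; i < n; i++) {
        for (int k = 0; k < ops[i].n_inputs; k++) {
            last_use[ops[i].inputs[k]] = i;
        }
    }

    for (int i = 0; i < n; i++) {
        /* Pick the output slot BEFORE releasing inputs, so the result never
         * aliases a tensor that node i is still reading. */
        slot[i] = n_free > 0 ? free_slots[--n_free] : n_slots++;

        /* Inputs read for the last time by node i give their slot back. */
        for (int k = 0; k < ops[i].n_inputs; k++) {
            int j = ops[i].inputs[k];
            if (last_use[j] == i) {
                free_slots[n_free++] = slot[j];
                last_use[j] = -1; /* guards against ops that read j twice */
            }
        }
    }
    return n_slots;
}
```

On a straight chain of ops this degenerates into two alternating slots, which is the "keep the last 2 tensors" scheme; the ggml_repeat/ggml_add snippet needs a third slot. Real tensors differ in size, so a practical version would have to pack variable-sized blocks into the buffers rather than hand out uniform slots, which is part of what makes the real problem non-trivial.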

Memory usage change ("old" = master, "new" = with scratch buffers):

Model    Disk      Memory (old)   Memory (new)
tiny     75 MB     ~390 MB        ~125 MB
base     142 MB    ~500 MB        ~210 MB
small    466 MB    ~1.0 GB        ~600 MB
medium   1.5 GB    ~2.6 GB        ~1.7 GB
large    2.9 GB    ~4.7 GB        ~3.3 GB
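The relative savings implied by the approximate figures in the table can be computed directly. A throwaway C helper, assuming 1 GB = 1000 MB to match the table's rounding and, for the overhead figure, assuming the loaded weights take about as much memory as the model file on disk; `mem_row_t`, `saving` and `overhead_mb` are hypothetical names:

```c
#include <assert.h>

/* One row of the memory table; sizes in MB (1 GB taken as 1000 MB). */
typedef struct {
    const char *model;
    double disk_mb;
    double old_mb;
    double new_mb;
} mem_row_t;

const mem_row_t MEM_TABLE[] = {
    {"tiny",     75.0,  390.0,  125.0},
    {"base",    142.0,  500.0,  210.0},
    {"small",   466.0, 1000.0,  600.0},
    {"medium", 1500.0, 2600.0, 1700.0},
    {"large",  2900.0, 4700.0, 3300.0},
};

/* Fraction of the old memory usage that is saved (0.58 means 58%). */
double saving(const mem_row_t *r) {
    assert(r->old_mb > 0.0);
    return (r->old_mb - r->new_mb) / r->old_mb;
}

/* Memory beyond the model file size: a rough proxy for everything that is
 * not weights (activations, intermediate buffers). */
double overhead_mb(const mem_row_t *r, int use_new) {
    return (use_new ? r->new_mb : r->old_mb) - r->disk_mb;
}
```

With these rounded figures the relative saving goes from about 68% for tiny down to about 30% for large, while the memory beyond the file size for tiny drops from roughly 315 MB to 50 MB.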

Development notes:

  • Cannot use ggml_cpy with scratch tensors (resolved)
  • Special-cased constant ggml tensors; need a better fix (resolved)
  • We now only compute the logits for the last token in whisper_decode()
  • Use different scratch buffers for every other layer? (resolved)
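When sampling the next token, only the logits at the last position are needed, so the final vocabulary projection can be applied to a single row of hidden states instead of all n_tokens rows. This cuts that step's work and its output buffer by a factor of n_tokens. A plain-C sketch of the idea, not whisper.cpp's actual code (which does this with ggml tensor operations); `last_token_logits` and its parameters are hypothetical:

```c
#include <assert.h>
#include <stddef.h>

/* hidden:  n_tokens x n_embd, row-major (final decoder hidden states)
 * w_vocab: n_vocab  x n_embd, row-major (output projection)
 * logits:  n_vocab floats, computed for the LAST token only */
void last_token_logits(const float *hidden, size_t n_tokens, size_t n_embd,
                       const float *w_vocab, size_t n_vocab, float *logits) {
    assert(n_tokens > 0);
    const float *h = hidden + (n_tokens - 1) * n_embd; /* last row only */
    for (size_t v = 0; v < n_vocab; v++) {
        float acc = 0.0f;
        for (size_t e = 0; e < n_embd; e++) {
            acc += w_vocab[v * n_embd + e] * h[e];
        }
        logits[v] = acc;
    }
}
```

Computing all positions would need an n_tokens x n_vocab output and n_tokens times as many multiply-adds, even though every row but the last is discarded.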

ggerganov, Jan 19 '23 19:01