Batched & chunked prefill
Similar to what was described in https://github.com/huggingface/candle/issues/2108:
"When prompts get longer than trivial sizes, the memory usage spikes as the prompt is thrown into one Tensor and sent off to a forward pass in the model at whatever length it comes in as. These spikes can be reduced by processing the batch in chunks."
There's a candle implementation here: https://github.com/huggingface/candle/pull/2111
Let's say we configure a setting batch_size = 512.
The scheduler would need to be aware of it and only schedule 2 prompts together if they're fewer than 512 tokens combined.
And the engine should be aware of it too: if a single sequence is longer than 512 tokens, it should split it into chunks (a rough sketch of both halves follows below).
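A minimal sketch of what those two halves might look like. This is illustrative only, assuming hypothetical names (schedule, prefill_chunks) rather than the actual mistral.rs scheduler/engine API:

const BATCH_SIZE: usize = 512;

/// Scheduler side: greedily pick waiting prompts whose combined length
/// stays within the per-step token budget.
fn schedule(waiting: &[Vec<u32>], batch_size: usize) -> Vec<usize> {
    let mut picked = Vec::new();
    let mut total = 0;
    for (i, prompt) in waiting.iter().enumerate() {
        if total + prompt.len() <= batch_size {
            total += prompt.len();
            picked.push(i);
        }
    }
    picked
}

/// Engine side: split one long prompt into chunks of at most `batch_size`
/// tokens; each chunk gets its own forward pass while the KV cache carries
/// state between chunks.
fn prefill_chunks<'a>(prompt: &'a [u32], batch_size: usize) -> impl Iterator<Item = &'a [u32]> + 'a {
    prompt.chunks(batch_size)
}

fn main() {
    // A 2048-token prompt becomes four 512-token forward passes.
    let prompt: Vec<u32> = (0..2048).collect();
    for (i, chunk) in prefill_chunks(&prompt, BATCH_SIZE).enumerate() {
        // In the real engine this would be a model forward pass with a
        // KV-cache offset of i * BATCH_SIZE.
        println!("forward pass {i}: {} tokens", chunk.len());
    }
}

This caps the activation memory of any single forward pass at roughly what a 512-token prompt needs, regardless of total prompt length.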
To reproduce it locally, run the benchmark with a high enough -p and you get an OOM:
./mistralrs-bench -p 2048 -g 0 -r 1 -c 1 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf
2024-04-26T20:17:25.483829Z ERROR mistralrs_core::engine: prompt - Model failed with error: Cuda(Cuda(DriverError(CUDA_ERROR_OUT_OF_MEMORY, "out of memory")))
But generating the same number of tokens works, since decoding processes one token per forward pass rather than the whole prompt at once:
./mistralrs-bench -p 0 -g 2048 -r 1 -c 1 gguf -t mistralai/Mistral-7B-Instruct-v0.1 -m TheBloke/Mistral-7B-Instruct-v0.1-GGUF -f mistral-7b-instruct-v0.1.Q4_K_M.gguf
+------------------------------------+---------+---------+--------------+--------------+-------------+--------------+
| model | backend | test | t/s | ms/t | concurrency | throughput/s |
+------------------------------------+---------+---------+--------------+--------------+-------------+--------------+
| mistralai/Mistral-7B-Instruct-v0.1 | CUDA | tg 2048 | 26.297±0.000 | 38.027±0.000 | 1 | 26.296867 |
+------------------------------------+---------+---------+--------------+--------------+-------------+--------------+
@lucasavila00, this looks great. It'll require modifying the attention mask calculation of every model, so it may be helpful to factor those out into a layers.rs in mistralrs-core.
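A sketch of what such a shared helper might look like, assuming a candle-style API. The name causal_mask_with_past and its signature are hypothetical, not the actual mistralrs-core code; the past_len parameter is what chunked prefill would need so a later chunk can attend to tokens already in the KV cache:

use candle_core::{Device, Result, Tensor};

/// Build the causal mask for a prompt chunk of `tgt_len` tokens when
/// `past_len` tokens are already in the KV cache. Chunk position i may
/// attend to every cached token and to chunk positions <= i, so entry
/// (i, j) is 1 (masked out) only when j > past_len + i.
fn causal_mask_with_past(tgt_len: usize, past_len: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<u8> = (0..tgt_len)
        .flat_map(|i| (0..past_len + tgt_len).map(move |j| u8::from(j > past_len + i)))
        .collect();
    Tensor::from_slice(&mask, (tgt_len, past_len + tgt_len), device)
}

With past_len = 0 this reduces to the ordinary square causal mask, so a single shared helper could serve both the chunked and unchunked prefill paths across models.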
@lucasavila00, I am actually going to end up adding this in #242.