llama : simplify Mamba with advanced batch splits
As promised in https://github.com/ggerganov/llama.cpp/pull/7531#issuecomment-2209585346, I've been extracting the advanced batch splits out of the Jamba PR (#7531).
I've also backported the contiguous allocation of recurrent state slots, which makes it possible to also include the changes from #7531 which simplify the ggml operators used specifically for Mamba. Hopefully this isn't too much at once.
See https://github.com/ggerganov/llama.cpp/pull/7531#discussion_r1620997020 for an explanation of the batch splits.
Summary
- `ggml.c`
  - Simplify `ggml_ssm_conv` and `ggml_ssm_scan` by assuming batched sequences have the same number of new tokens, and that the states are contiguous and ordered correctly.
  - Allow `ggml_concat` to work with a non-contiguous second argument.
    - The CPU implementation already supported this, but it was guarded with an assertion. Meanwhile, I think the CUDA implementation already supports this too, and does not prevent its usage (not totally sure), so I did not change it.
- `llama.cpp`
  - Advanced batch splits handled with `lctx.sbatch` for persistent buffers (see the sketch after this list)
    - Refactor "helpers for smoother batch API transition" by handling them in `llama_sbatch`, which allows avoiding repeated allocations by re-using the same buffers.
    - Simple batch splits should be equivalent to the previous behavior and are made with `lctx.sbatch.split_simple(n_tokens)` to build a `llama_ubatch` with a max size of `n_tokens`.
    - Equal-sequence-length splits are made with `lctx.sbatch.split_equal(n_tokens)`, and are used to simplify the operators of recurrent models.
    - Add `llama_ubatch`. Similar to `llama_batch`, but aware of equal-length sequences.
      - Make `llama_set_inputs` (and others) use `llama_ubatch` instead of `llama_batch`.
  - Make recurrent state slot allocation contiguous in `llama_kv_cache_find_slot`
  - Add `llm_build_mamba` to build a Mamba block; used for Mamba, and will be used for Jamba
  - Add `llm_build_copy_mask_state` (maybe not a good name) to abstract away the shuffling and masking of recurrent states. Used for Mamba, and it should be usable for other recurrent architectures too.
  - Simplify the sanity checks for `qs.n_attention_wv` in `llama_model_quantize_internal` to make them future-proof for hybrid models.
  - Reorder the outputs when using advanced batch splits like `split_equal` in conjunction with `llama_get_logits`, because the API requires the outputs to be in the same order as in the user-provided batch, not in an order based on batch split rules.
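For a concrete picture of how these pieces fit together, below is a minimal, self-contained sketch of the decode loop described above. The `sketch_*` structs and their fields are stand-ins of mine and not the real `llama_sbatch`/`llama_ubatch` types, whose actual signatures differ.

```cpp
// Rough sketch of the new decode flow: a persistent sbatch wraps the user
// batch, and each iteration takes either a simple or an equal-length split.
// These are NOT the real llama.cpp structs; names and fields are assumptions
// chosen only to mirror the summary above.
struct sketch_ubatch {
    int  n_tokens   = 0;     // tokens in this ubatch
    bool equal_seqs = false; // true when produced by an equal-length split
};

struct sketch_sbatch {
    int n_tokens = 0; // tokens of the user batch left to split

    // simple split: up to n_ubatch tokens, same behavior as before this PR
    sketch_ubatch split_simple(int n_ubatch) {
        sketch_ubatch ub;
        ub.n_tokens = n_ubatch < n_tokens ? n_ubatch : n_tokens;
        n_tokens   -= ub.n_tokens;
        return ub;
    }

    // equal split: every sequence contributes the same number of new tokens,
    // which is what the simplified ggml_ssm_conv/ggml_ssm_scan assume
    sketch_ubatch split_equal(int n_ubatch) {
        sketch_ubatch ub = split_simple(n_ubatch);
        ub.equal_seqs = true;
        return ub;
    }
};

// Shape of the decode loop: recurrent models take equal-length splits,
// everything else keeps the simple split.
void decode_sketch(sketch_sbatch & sbatch, bool recurrent, int n_ubatch) {
    while (sbatch.n_tokens > 0) {
        sketch_ubatch ubatch = recurrent ? sbatch.split_equal(n_ubatch)
                                         : sbatch.split_simple(n_ubatch);
        // ... set the inputs from the ubatch, build and run the graph, collect
        //     the outputs, then reorder them to the user batch order at the end ...
        (void) ubatch;
    }
}

int main() {
    sketch_sbatch sb;
    sb.n_tokens = 10;
    decode_sketch(sb, /*recurrent=*/true, /*n_ubatch=*/4); // 3 ubatches: 4, 4 and 2 tokens
    return 0;
}
```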
For simplicity, this does not include the separation of the KV cache and the recurrent state cache. Both still use the same buffers (lctx.kv_self.k_l, and lctx.kv_self.v_l, as on master). The separation (necessary for hybrid models) will be introduced at the same time as Jamba.
TODO
- [x] Test the slot allocation of `llama_kv_cache_find_slot` with the `--hellaswag` benchmark in `llama-perplexity` with a Mamba model
  - This uses lots of parallel sequences in an unusual way, and so I think it's a great stress test.
- [x] Session file saving and reloading
  - [x] Reloading needs to rebuild the `tail` metadata for recurrent states (i.e. which cell is the end of which sequence)
  - [x] The server tests need to pass
- [x] Make sure T5 still works
- [x] Make sure the pooled embeddings still work
  - tested `bge-small` with `llama-embeddings` with parallel prompts with `--pooling cls`, `--pooling last` and `--pooling mean`; results exactly match `master`.
- [x] Make sure Gemma's sliding window mask still works
- [x] Decide whether to rename `llama_reorder_outputs` to `llama_output_reorder` and move it close to `llama_output_reserve`
  - renamed and moved
Future ideas
- whole-sequence splits for embeddings
- handle pooling types like `cls` and `last` within the `ubatch.outputs` when splitting a batch; `inp_cls` is redundant with `inp_out_ids`.
- [x] I have read the contributing guidelines
- Self-reported review complexity:
  - [x] Medium
> Make sure Gemma's sliding window mask still works
The following command produces identical perplexity on `master` and this branch:

```sh
./llama-perplexity \
    -m models/gemma-2-9b/ggml-model-f16.gguf \
    -f build/wikitext-2-raw/wiki.test.raw \
    -ngl 99 -c 8192
```
Is this enough to confirm the SWA functionality?
> Is this enough to confirm the SWA functionality?
I think so. It might also be relevant to test SWA with parallel sequences (I think this is what using a bigger `-b` (and `-ub`?) than `-c` does with `llama-perplexity`).
Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?
> Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?
Still waiting on some upstream changes (see https://huggingface.co/mistralai/mamba-codestral-7B-v0.1/discussions/1), but otherwise I'm beginning to investigate the conversion for Mamba2 models, at least to have some GGUFs (even with no inference support) to experiment with implementing it.
First thing I'm noticing is the lack of metadata in the `config.json` of Mamba2 models. No state size, no convolution kernel size, no time step rank, and in the case of `mamba-codestral-7B-v0.1`, no indication that it's a Mamba2 model, except from the tensor names and sizes.
For the state sizes, I guess these are hardcoded in the `state-spaces/mamba` implementation, in which case I'll hardcode them too and/or find what is used to calculate them.
I've also recently started to simplify the session file save & restore code in llama.cpp (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.
I also encountered difficulties running mamba-codestral. I tried to run this model on https://github.com/state-spaces/mamba. But there is no config.json in the model repository. mamba-codestral includes a new tokenizer v3. Although Mistral writes that the model can be run on state-spaces/mamba, nothing worked for me.
Please see the discussion here: https://github.com/NVIDIA/TensorRT-LLM/issues/1968 and a few hours ago an example for running mamba appeared: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/mamba
Maybe this will help development.
> For the state sizes, I guess these are hardcoded in the `state-spaces/mamba` implementation, in which case I'll hardcode them too and/or find what is used to calculate them.
Yes, we can hardcode initially
> I've also recently started to simplify the session file save & restore code in `llama.cpp` (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.
Sounds good - a separate PR would be easier to review
Regarding Codestral - want to highlight again the comment by the Mistral team about `ngroups = 8`: https://github.com/ggerganov/llama.cpp/issues/8519#issuecomment-2233110921. Seems important.
https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py
This includes some details that may be interesting for you.
I'll be re-running a few tests before merging this in hopefully less than 2 days. There are now both Mamba-2 and RWKV v6, which kind of need this to simplify the implementation.
Still, I don't want to accidentally have broken something with the batch splits, so I'll try to convince myself that there is no problem by running more tests.
I've run some tests, and there's a problem: pooled embeddings with Mamba can't work with multiple sequences anymore.
This is because `lctx.embd_seq` is overwritten at each ubatch, which makes it only work if everything fits in a single ubatch, which is not the case when sequences don't have the same length and are split to make them all equal in Mamba's ubatches.
This could be fixed by letting causal embeddings be split over multiple ubatches. I'll try to find a way to do this cleanly.
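To make the failure mode concrete, here is a minimal standalone illustration (not actual llama.cpp code; the `std::map` is only a stand-in for `lctx.embd_seq`):

```cpp
#include <cstdio>
#include <map>
#include <vector>

int main() {
    std::map<int, std::vector<float>> embd_seq; // stand-in for lctx.embd_seq (seq_id -> pooled embedding)

    // two ubatches, each containing a different sequence (the split_equal case)
    for (int ub = 0; ub < 2; ++ub) {
        embd_seq.clear();           // overwritten at each ubatch, as described above
        embd_seq[ub] = {0.1f * ub}; // only sequences of the current ubatch keep a result
    }

    // only the last ubatch's sequence survives
    std::printf("sequences with embeddings: %zu (expected 2)\n", embd_seq.size());
    return 0;
}
```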
Where the checkbox is checked, it means the behavior is the same as on `master` or better.

- [x] perplexity
  - [x] v0-mamba-100k
    - [x] 1 chunk per batch
    - [x] 4 chunks per batch
    - [x] 4 batches per chunk
    - [x] 4 ubatches per batch
  - [x] v0-llama2-100k
    - [x] 1 chunk per batch
    - [x] 4 chunks per batch
    - [x] 4 batches per chunk
    - [x] 4 ubatches per batch
- [ ] `llama-embedding`
  - [x] t5-small
    - fails with `GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor")` as on `master`.
  - [ ] mamba-130M
    - ❌ Does NOT work with more than one sequence anymore
  - [x] bge-small
- [x] `llama-parallel -c 1024 -np 5 -ns 7 --seed 42 --temp 1`
  - [x] t5-small
    - fails to decode, as on `master` (fixed segfault in https://github.com/ggerganov/llama.cpp/pull/8526/commits/652e9b0d6145a3e97d14e2a10efa531f1d76dc31)
  - [x] stories-MoE
    - works with `-c 1024`, but segfaults otherwise (as on `master`)
  - [x] mamba-130M
    - works without problem
- [x] `perplexity --hellaswag` (parallel sequences of uneven length, also the only test with batches having more than one `seq_id` per token)
  - [x] v0-mamba-100k
  - [x] v0-llama2-100k
- [x] save load state
  - [x] Mamba-130M
    - works (unlike on `master`), but needs `-np 2` for the sequence load test.
  - [x] v0-llama2-100k
  - [x] t5-small
    - fails, as on `master`
- [x] quantization (because an assertion was changed over there)
  - [x] OpenELM-270M
  - [x] t5-small
  - [x] Mamba-370M
- [x] Gemma2 sliding window
  - [x] Gemma2-2B-it perplexity with `-c 5120` (its sliding window is 4096), first chunk generates the same perplexity.
  - [ ] parallel sliding windows (on more than one sequence)
    - not sure how to test that
I've fixed the pooled embeddings problem with Mamba in https://github.com/ggerganov/llama.cpp/pull/8526/commits/b264eddbb26c695d50d04c37a5b9bb91181bc551 by making it only process a single sequence per ubatch. When the sequences are short, this is slightly slower than processing them all at once, unfortunately.
In the future, the pooled embeddings will be refactored to allow causal embeddings to be split across ubatches. It should also be possible to remove `inp_cls`, because it's redundant with `inp_out_ids`. `LLAMA_POOLING_TYPE_CLS` and `LLAMA_POOLING_TYPE_LAST` could be handled directly when splitting batches, since they only affect which tokens get their output selected. `LLAMA_POOLING_TYPE_MEAN` will be a bit harder to allow splitting, but since the total number of tokens per sequence per batch is known in advance, there might still be a way.
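As a very rough sketch of how the `cls`/`last` part could look when handled at split time (an assumption on my part, not code from this PR), the splitter would only need to mark which token of each sequence slice produces an output:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical selector; it mirrors the pooling types mentioned above but is
// not the llama.cpp API.
enum class pooling_sel { cls, last };

// Mark which token of one contiguous sequence slice [seq_start, seq_start + seq_len)
// of a ubatch produces an output; this is the same information that inp_cls and
// inp_out_ids encode.
static void mark_pooled_output(std::vector<int8_t> & output,
                               size_t seq_start, size_t seq_len, pooling_sel sel) {
    if (seq_len == 0) {
        return;
    }
    const size_t idx = (sel == pooling_sel::cls) ? seq_start
                                                 : seq_start + seq_len - 1;
    output[idx] = 1;
}

int main() {
    std::vector<int8_t> output(8, 0); // 8 tokens in the ubatch, two sequences
    mark_pooled_output(output, 0, 5, pooling_sel::last); // seq 0: tokens [0,5) -> marks token 4
    mark_pooled_output(output, 5, 3, pooling_sel::last); // seq 1: tokens [5,8) -> marks token 7
    return 0;
}
```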
I'm postponing that pooled embeddings refactor to another PR. I consider this ready.
@compilade btw, I have the SSD implementation on CPU, more or less, if it's interesting for you.
@compilade
I get this error quantizing deepseek2 since the merge of this PR: https://github.com/ggerganov/llama.cpp/issues/9155