
llama : simplify Mamba with advanced batch splits

compilade opened this pull request on Jul 17, 2024 • 7 comments

As promised in https://github.com/ggerganov/llama.cpp/pull/7531#issuecomment-2209585346, I've been extracting the advanced batch splits out of the Jamba PR (#7531).

I've also backported the contiguous allocation of recurrent state slots, which makes it possible to also include the changes from #7531 that simplify the ggml operators used specifically for Mamba. Hopefully this isn't too much at once.

See https://github.com/ggerganov/llama.cpp/pull/7531#discussion_r1620997020 for an explanation of the batch splits.

Summary

  • ggml.c
    • Simplify ggml_ssm_conv and ggml_ssm_scan by assuming batched sequences have the same number of new tokens, and that the states are contiguous and ordered correctly.
    • Allow ggml_concat to work with a non-contiguous second argument.
      • The CPU implementation already supported this, but it was guarded with an assertion. Meanwhile, I think the CUDA implementation already supports this too, and does not prevent its usage (not totally sure), so I did not change it.
  • llama.cpp
    • Advanced batch splits handled with lctx.sbatch for persistent buffers
      • Refactor "helpers for smoother batch API transition", by handling them in llama_sbatch, which allows avoiding repeated allocations by re-using the same buffers.
      • Simple batch splits should be equivalent to the previous behavior and are made with lctx.sbatch.split_simple(n_tokens) to build a llama_ubatch with a max size of n_tokens.
      • Equal-sequence-length splits are made with lctx.sbatch.split_equal(n_tokens) and are used to simplify the operators of recurrent models (see the sketch after this summary).
      • Add llama_ubatch. Similar to llama_batch, but aware of equal-length sequences.
        • Make llama_set_inputs (and others) use llama_ubatch instead of llama_batch.
    • Make recurrent state slot allocation contiguous in llama_kv_cache_find_slot
    • Add llm_build_mamba to build a Mamba block; it is used for Mamba and will also be used for Jamba.
    • Add llm_build_copy_mask_state (maybe not a good name) to abstract away the shuffling and masking of recurrent states. Used for Mamba, and it should be usable for other recurrent architectures too.
    • Simplify the sanity checks for qs.n_attention_wv in llama_model_quantize_internal to make it future-proof for hybrid models.
    • Reorder the outputs when using advanced batch splits like split_equal in conjunction with llama_get_logits, because the API requires the outputs to be in the same order as in the user-provided batch, not in an order determined by the batch-split rules.

For simplicity, this does not include the separation of the KV cache and the recurrent state cache. Both still use the same buffers (lctx.kv_self.k_l, and lctx.kv_self.v_l, as on master). The separation (necessary for hybrid models) will be introduced at the same time as Jamba.
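
To illustrate the idea behind the equal-length splits, here is a self-contained sketch of the concept. The names toy_ubatch and split_equal_toy are made up for this example; the actual llama_sbatch code also tracks seq_ids, positions and output ids, re-uses persistent buffers, and respects the size limit more strictly.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <vector>

using toy_seq_id = int32_t;

// A simplified "ubatch": every included sequence contributes the same number of new tokens.
struct toy_ubatch {
    size_t n_seq_tokens = 0;                  // new tokens per sequence (equal for all)
    std::vector<toy_seq_id> seqs;             // which sequences are included
    std::vector<std::vector<int32_t>> tokens; // tokens[i] holds the n_seq_tokens tokens of seqs[i]
};

// Split per-sequence token lists into ubatches of (roughly) at most n_max tokens,
// taking an equal number of tokens from every remaining sequence each time.
static std::vector<toy_ubatch> split_equal_toy(std::map<toy_seq_id, std::vector<int32_t>> pending, size_t n_max) {
    std::vector<toy_ubatch> out;
    while (!pending.empty()) {
        // bounded by the shortest remaining sequence and by the ubatch size limit
        size_t n = SIZE_MAX;
        for (const auto & kv : pending) {
            n = std::min(n, kv.second.size());
        }
        n = std::min(n, std::max<size_t>(n_max / pending.size(), 1));

        toy_ubatch ub;
        ub.n_seq_tokens = n;
        for (auto it = pending.begin(); it != pending.end(); ) {
            std::vector<int32_t> & toks = it->second;
            ub.seqs.push_back(it->first);
            ub.tokens.emplace_back(toks.begin(), toks.begin() + (ptrdiff_t) n);
            toks.erase(toks.begin(), toks.begin() + (ptrdiff_t) n);
            it = toks.empty() ? pending.erase(it) : std::next(it);
        }
        out.push_back(std::move(ub));
    }
    return out;
}

With such splits, the recurrent operators like ggml_ssm_conv and ggml_ssm_scan can assume a rectangular layout of n_seqs × n_seq_tokens new tokens, with contiguous and correctly-ordered states.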

TODO

  • [x] Test the slot allocation of llama_kv_cache_find_slot with the --hellaswag benchmark in llama-perplexity with a Mamba model
    • This uses lots of parallel sequences in an unusual way, and so I think it's a great stress test.
  • [x] Session file saving and reloading
    • [x] Reloading needs to rebuild the tail metadata for recurrent states (i.e. which cell is the end of which sequence).
    • [x] The server tests need to pass
  • [x] Make sure T5 still works
  • [x] Make sure the pooled embeddings still work
    • tested bge-small with llama-embedding using parallel prompts with --pooling cls, --pooling last, and --pooling mean; results exactly match master.
  • [x] Make sure Gemma's sliding window mask still works
  • [x] Decide whether to rename llama_reorder_outputs to llama_output_reorder and move it close to llama_output_reserve.
    • renamed and moved

Future ideas

  • whole-sequence splits for embeddings
  • handle pooling types like cls and last within the ubatch.outputs when splitting a batch; inp_cls is redundant with inp_out_ids.

compilade avatar Jul 17 '24 01:07 compilade

Make sure Gemma's sliding window mask still works

The following command produces identical perplexity on master and this branch:

./llama-perplexity \
    -m models/gemma-2-9b/ggml-model-f16.gguf \
    -f build/wikitext-2-raw/wiki.test.raw \
    -ngl 99 -c 8192

Is this enough to confirm the SWA functionality?

ggerganov avatar Jul 18 '24 15:07 ggerganov

Is this enough to confirm the SWA functionality?

I think so. It might also be relevant to test SWA with parallel sequences (I think this is what using a bigger -b (and -ub?) than -c does with llama-perplexity).
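
For example, something along these lines (untested; the -b value is only an example, and -ub might also need to be raised) should make llama-perplexity evaluate two 8192-token chunks as parallel sequences in the same batch:

./llama-perplexity \
    -m models/gemma-2-9b/ggml-model-f16.gguf \
    -f build/wikitext-2-raw/wiki.test.raw \
    -ngl 99 -c 8192 -b 16384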

compilade avatar Jul 19 '24 06:07 compilade

Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?

hackey avatar Jul 23 '24 14:07 hackey

Guys, is there any progress in supporting Mamba2 (I'm interested in the new mamba-codestral)?

Still waiting on some upstream changes (see https://huggingface.co/mistralai/mamba-codestral-7B-v0.1/discussions/1), but otherwise I'm beginning to investigate the conversion for Mamba2 models, at least to have some GGUFs (even with no inference support) to experiment with implementing it.

First thing I'm noticing is the lack of metadata in the config.json of Mamba2 models. No state size, no convolution kernel size, no time step rank, and in the case of mamba-codestral-7B-v0.1, no indication that it's a Mamba2 model, except from the tensor names and sizes. For the state sizes, I guess these are hardcoded in the state-spaces/mamba implementation, in which case I'll hardcode them too and/or find what is used to calculate them.

I've also recently started to simplify the session file save & restore code in llama.cpp (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.

compilade avatar Jul 24 '24 04:07 compilade

I also encountered difficulties running mamba-codestral. I tried to run this model with https://github.com/state-spaces/mamba, but there is no config.json in the model repository, and mamba-codestral uses a new v3 tokenizer. Although Mistral writes that the model can be run with state-spaces/mamba, nothing worked for me.

Please see the discussion here: https://github.com/NVIDIA/TensorRT-LLM/issues/1968. A few hours ago, an example for running Mamba appeared: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/mamba

Maybe this will help development.

hackey avatar Jul 24 '24 06:07 hackey

For the state sizes, I guess these are hardcoded in the state-spaces/mamba implementation, in which case I'll hardcode them too and/or find what is used to calculate them.

Yes, we can hardcode initially

I've also recently started to simplify the session file save & restore code in llama.cpp (but I'll likely open a separate PR, since I think that refactor is best tested on its own), because I'm noticing that it's often causing me problems to adapt it to changes to the KV cache structure, due to there being at least 4 places needing to be updated and/or considered for each change (read/write + seq read/write). So I'll be unifying these code paths to make them easier to maintain.

Sounds good - a separate PR would be easier to review

Regarding Codestral - I want to highlight again the comment by the Mistral team about ngroups = 8: https://github.com/ggerganov/llama.cpp/issues/8519#issuecomment-2233110921. It seems important.

ggerganov avatar Jul 24 '24 10:07 ggerganov

This includes some details that may be interesting for you: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2_simple.py

awgr avatar Jul 29 '24 18:07 awgr

I'll be re-running a few tests before merging this, hopefully in less than 2 days. There are now both Mamba-2 and RWKV v6, which kind of need this to simplify their implementations.

Still, I don't want to have accidentally broken something with the batch splits, so I'll try to convince myself that there is no problem by running more tests.

compilade avatar Aug 19 '24 04:08 compilade

I've run some tests, and there's a problem: pooled embeddings with Mamba can't work with multiple sequences anymore.

This is because lctx.embd_seq is overwritten at each ubatch, which means it only works when everything fits in a single ubatch; that's not the case when sequences have different lengths and are split to make them all equal in Mamba's ubatches.

This could be fixed by letting causal embeddings be split over multiple ubatches. I'll try to find a way to do this cleanly.
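
One possible direction is to accumulate partial per-sequence results across ubatches instead of overwriting them. A minimal sketch for mean pooling (the toy_* names are made up; this is not the actual lctx.embd_seq handling):

#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

using toy_seq_id = int32_t;

// Running state for mean pooling of one sequence, carried across ubatches
// instead of being overwritten at each ubatch.
struct toy_pool_acc {
    std::vector<float> sum;   // running sum of this sequence's token embeddings
    size_t             n = 0; // number of tokens accumulated so far
};

// Called once per ubatch for each sequence it contains.
// embd points at n_tokens rows of n_embd floats (token-major).
static void accumulate_mean(std::map<toy_seq_id, toy_pool_acc> & acc, toy_seq_id seq,
                            const float * embd, size_t n_tokens, size_t n_embd) {
    toy_pool_acc & a = acc[seq];
    a.sum.resize(n_embd, 0.0f);
    for (size_t t = 0; t < n_tokens; ++t) {
        for (size_t i = 0; i < n_embd; ++i) {
            a.sum[i] += embd[t*n_embd + i];
        }
    }
    a.n += n_tokens;
}

// Called once after the last ubatch of the batch, to get the final mean per sequence.
static std::vector<float> finalize_mean(const toy_pool_acc & a) {
    std::vector<float> out(a.sum.size());
    for (size_t i = 0; i < out.size(); ++i) {
        out[i] = a.sum[i] / (float) a.n;
    }
    return out;
}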

Where a checkbox is checked, the behavior is the same as on master or better.

  • [x] perplexity
    • [x] v0-mamba-100k
      • [x] 1 chunk per batch
      • [x] 4 chunks per batch
      • [x] 4 batches per chunk
      • [x] 4 ubatches per batch
    • [x] v0-llama2-100k
      • [x] 1 chunk per batch
      • [x] 4 chunks per batch
      • [x] 4 batches per chunk
      • [x] 4 ubatches per batch
  • [ ] llama-embedding
    • [x] t5-small
      • fails with GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor") as on master.
    • [ ] mamba-130M
      • ❌ Does NOT work with more than one sequence anymore
    • [x] bge-small
  • [x] llama-parallel -c 1024 -np 5 -ns 7 --seed 42 --temp 1
  • [x] perplexity --hellaswag (parallel sequences of uneven length, also the only test with batches having more than one seq_id per token)
    • [x] v0-mamba-100k
    • [x] v0-llama2-100k
  • [x] save load state
    • [x] Mamba-130M
      • works (unlike on master), but needs -np 2 for the sequence load test.
    • [x] v0-llama2-100k
    • [x] t5-small
      • fails, as on master
  • [x] quantization (because an assertion was changed over there)
    • [x] OpenELM-270M
    • [x] t5-small
    • [x] Mamba-370M
  • [x] Gemma2 sliding window
    • [x] Gemma2-2B-it perplexity with -c 5120 (its sliding window is 4096); the first chunk gives the same perplexity as on master.
    • [ ] parallel sliding windows (on more than one sequence)
      • not sure how to test that

compilade avatar Aug 21 '24 02:08 compilade

I've fixed the pooled embeddings problem with Mamba in https://github.com/ggerganov/llama.cpp/pull/8526/commits/b264eddbb26c695d50d04c37a5b9bb91181bc551 by making it only process a single sequence per ubatch. When the sequences are short, this is slightly slower than processing them all at once, unfortunately.

In the future, the pooled embeddings will be refactored to allow causal embeddings to be split across ubatches. It should also be possible to remove inp_cls, because it's redundant with inp_out_ids. LLAMA_POOLING_TYPE_CLS and LLAMA_POOLING_TYPE_LAST could be handled directly when splitting batches, since they only affect which tokens get their output selected. LLAMA_POOLING_TYPE_MEAN will be a bit harder to split, but since the total number of tokens per sequence per batch is known in advance, there might still be a way.
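
As a rough sketch of what the split-time selection could look like (the toy_* names are made up for illustration; this is not the actual llama.cpp code):

#include <cstdint>
#include <vector>

enum toy_pooling_type { TOY_POOLING_NONE, TOY_POOLING_CLS, TOY_POOLING_LAST };

// One sequence's tokens within a batch, before it is split into ubatches.
struct toy_seq_tokens {
    std::vector<int32_t> tokens;
    std::vector<bool>    output; // per-token flag: keep this token's output?
};

// For CLS/LAST pooling, only one token per sequence needs an output, so the
// selection can happen at split time and go through the generic output-ids
// mechanism, with no need for a separate inp_cls input.
static void mark_pooled_outputs(std::vector<toy_seq_tokens> & seqs, toy_pooling_type pool) {
    if (pool == TOY_POOLING_NONE) {
        return; // keep whatever per-token output flags the user already set
    }
    for (auto & s : seqs) {
        if (s.tokens.empty()) { continue; }
        s.output.assign(s.tokens.size(), false);
        if (pool == TOY_POOLING_CLS)  { s.output.front() = true; } // first token of the sequence
        if (pool == TOY_POOLING_LAST) { s.output.back()  = true; } // last token of the sequence
    }
}

Mean pooling can't be reduced to a single output token like this, which is why it would instead need the per-sequence token counts mentioned above.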

I'm postponing that pooled embeddings refactor to another PR. I consider this ready.

compilade avatar Aug 21 '24 04:08 compilade

@compilade btw, I have the SSD implementation on CPU, more or less, if it's interesting for you.

On Wed, Aug 21, 2024 at 2:58 PM compilade wrote:

Merged #8526 https://github.com/ggerganov/llama.cpp/pull/8526 into master.


awgr avatar Aug 22 '24 04:08 awgr

@compilade

I get this error when quantizing deepseek2 since the merge of this PR: https://github.com/ggerganov/llama.cpp/issues/9155

mann1x avatar Aug 24 '24 13:08 mann1x