llama : support Mamba Selective State Space Models

compilade opened this issue 1 year ago • 6 comments

> [!NOTE]
> Some changes made between now and merging could possibly require re-converting Mamba models to GGUF. I'll announce it in this note if/when it happens.

This should fix #4353

Implementing Mamba in llama.cpp is taking more time than I thought. But it's getting close. See the TODO section below.

I started working on this as an experiment and because I wanted to try Mamba models with llama.cpp (also, there have been quite a few finetunes already). Turns out that implementing support for a novel model architecture is quite fun (well, at least when it finally works). The most powerful machine on which I try LLMs is a low-power laptop with 8 GB of RAM and an Intel CPU (no discrete GPU), so I can't try Mamba-3B in its full f32 glory (the full weights take 11 GB), but at least now it's possible to use it quantized.

Constant memory usage is a big advantage of Mamba models, but it also means that previous states are not all kept in memory (in the current implementation, only the last one is kept). This might cause more prompt re-processing than necessary in the server example, which won't work until llama_kv_cache_seq_rm can signal when it fails to partially delete a sequence. The main example has no such problem. Currently, the initial text generation speed for Mamba is a bit slower than for Transformer-based models (with an empty context), but unlike them, Mamba's speed does not degrade with the number of tokens processed. Also note that quantization may make the state unstable (making the output gibberish). This needs more testing to figure out how often it happens: I have only seen it with very small models (130M), and not yet with bigger ones (3B).

For testing, I recommend converting from https://huggingface.co/Q-bert/Mamba-130M since it's small, the config.json doesn't require modification, and the tokenizer is already next to the model files.

(see below near the official weights reference for instructions on how to convert the official weights)

Design decisions

I'd like to discuss a few things before this can be merged:

  • Tensor names
    • I added new tensor types for Mamba, because the weights don't have exactly the same roles as in Transformers, but some of them might be redundant. For example, instead of ssm_out, I could probably have re-used attn_output, but then its relationship with ssm_in would have been less obvious. Conversely, Mamba doesn't really have attention, but I still re-used the attn_norm type for the layer norms.
  • Comments
    • I might have put too many comments. Some of them are explanations of what's going on, others are to help me refactor more easily (like the comments describing the changes in tensor dimensions).
  • Metadata and convert script
    • (ab)use of the KV cache metadata
      • Currently, the metadata keys HEAD_COUNT, KEY_LENGTH and VALUE_LENGTH are used purely to make the KV cache the right size, depending on the d_conv and d_state sizes (usually 4 and 16, respectively). This is probably wrong, since changing anything about the cache size (like I did in 7016fe53331309b525f0a0230a12574971aaef41) breaks existing converted-to-GGUF Mamba models.
    • What context length should be stored in the model? The one they trained with is not even in config.json next to their model weights, and the effective context length is bigger than that anyway. Should I put a huge number like $2^{20}$ (1048576), or should I put the sequence length with which they say they trained their models in the paper (2048)?
    • Speaking of config.json, the official Mamba models don't have an architectures field, which makes the model type hard to detect. For now, I've resorted to expecting "architectures": ["MambaForCausalLM"] in there, since the Q-bert/Mamba-* models are the only ones I've found which have an actual architecture defined in their config.json. Another architecture name I've come across is MambaLMHeadModel, but it has not been used in the config.json of any Mamba model I've looked at (I might have missed some). It seems to be the class name of the official Mamba implementation, and I first saw it in the description of their 3B model trained on SlimPajama.
    • The official Mamba models use the GPT-NeoX tokenizer, but don't include the tokenizer at all along with their weights. There might be a way to detect the absence of tokenizer.json and use llama.cpp's models/ggml-vocab-gpt-neox.gguf, but I did not do this yet.
  • Quantization
    • Currently, the _S, _M and _L variants of k-quants are the same, because I don't yet know which weights are more (or less) important. This will require experimentation.
    • A thing I've noticed is that the more a Mamba model is quantized, the more likely the model will output gibberish (especially with smaller models like mamba-130m at Q4_K). It would be nice to find a quant mix which alleviates this.
  • Performance
    • I don't have a GPU, so I can't test on one; all of my numbers are for CPU inference.
      • I did not implement GPU kernels for the new operators I added (I could not test them anyway).
    • In the paper, they compare the performance with Transformers for 128 tokens after 2048 tokens have been put in the context, so my speed comparisons with empty-context Transformers are not directly comparable with theirs.
    • Most of the CPU time is spent on matrix multiplications from the linear projections (Mamba models have many layers: 64 in Mamba-3B).
    • I fused operations together in ggml_ssm_scan (and got a 25% perf boost on Mamba-3B compared to not fusing the operations), so I also removed my addition of the ggml_exp and ggml_soft_plus operators, since they are now unused.
    • I also fused the operations of the conv step in ggml_ssm_conv because managing the states of simultaneous sequences was easier that way.
    • Memory usage depends not on context length but on batch size and on the number of parallel sequences. If memory is precious, use a smaller batch size and/or fewer parallel sequences.
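To make the point about fusing concrete, here is a scalar Python sketch of one selective-scan step for a single channel. It is not the code of ggml_ssm_scan; it only assumes the usual Mamba recurrence (softplus on the time step, $\overline{A} = \exp(\Delta A)$ with a diagonal $A$, and the simplified $\overline{B}x = \Delta B x$), and it leaves out the D skip connection and the z gate. The names softplus and ssm_scan_step are hypothetical. Computing softplus and exp inline, per element, is what removes the need for separate ggml_soft_plus and ggml_exp passes over whole tensors.

```python
import math


def softplus(v: float) -> float:
    """log(1 + e^v), computed without overflowing for large v."""
    if v > 0:
        return v + math.log1p(math.exp(-v))
    return math.log1p(math.exp(v))


def ssm_scan_step(h, x, dt, A, B, C):
    """Advance one channel's ssm_state by one token and return its output.

    h:       list of d_state floats, updated in place (the recurrent state)
    x:       this channel's input for the token
    dt:      raw time step, before softplus
    A, B, C: lists of d_state floats (A holds the diagonal of the state matrix)
    """
    delta = softplus(dt)                      # inline, instead of a separate soft_plus op
    y = 0.0
    for n in range(len(h)):
        dA = math.exp(delta * A[n])           # inline, instead of a separate exp op
        h[n] = h[n] * dA + delta * B[n] * x   # h_t = Abar * h_{t-1} + Bbar * x_t
        y += h[n] * C[n]                      # y_t = C . h_t
    return y


# Usage: three tokens through one channel with d_state = 2
state = [0.0, 0.0]
ys = [ssm_scan_step(state, x, 0.5, [-1.0, -2.0], [1.0, 0.5], [1.0, 1.0])
      for x in (1.0, 0.0, -1.0)]
print(ys)
```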

TODO

Things that should (probably) be done before merging, from more important to less important:

  • [x] Better support the KV cache API: should work correctly when the functions are used on whole sequences.
    • [ ] Replace usages of llama_kv_cache_seq_rm on parts of sequences with an equivalent approach that uses whole sequences (required for at least the server and parallel examples)
  • [x] Simultaneous sequence processing (required for the parallel example, the HellaSwag benchmark in the perplexity example, and probably also the server example)
  • ~~GPU kernels for ggml_ssm_scan~~ (out of scope for this PR)
  • [ ] Quantization mixes (find out which weights are more important)
  • [ ] Find a better way to handle the KV cache size than by misusing metadata
  • [ ] Detect lack of tokenizer.json and use models/ggml-vocab-gpt-neox.gguf when converting a Mamba model to GGUF
  • [ ] Remove redundant comments

References

  • The Mamba paper
    • https://arxiv.org/abs/2312.00752
  • Huge inspiration for initial implementation
    • https://github.com/kroggen/mamba.c
  • Another minimal Mamba implementation
    • https://github.com/johnma2006/mamba-minimal/blob/master/model.py
  • The official Mamba implementation
    • https://github.com/state-spaces/mamba
  • Official Mamba model weights
    • https://huggingface.co/state-spaces
    • If you want to convert them, edit their config.json to add "architectures": ["MambaForCausalLM"], then also download the GPT-NeoX tokenizer files and put them in the model directory. (I think only tokenizer.json and tokenizer_config.json are needed, but I'm not sure; special_tokens_map.json and vocab.json might be needed too.)
  • Mamba models that seem to have been re-exported (the layer names are slightly different), with a more complete config.json and a tokenizer.json, so they are much easier to convert out of the box.
    • https://huggingface.co/collections/Q-bert/mamba-65869481595e25821853d20d

compilade, Feb 05 '24 01:02

@compilade Just out of curiosity, is any convolution operation performed? I see some tensors with the name conv, but I never see ggml_conv_1d or ggml_conv_2d being used at any point.

FSSRepo, Feb 05 '24 03:02

> @compilade Just out of curiosity, is any convolution operation performed? I see some tensors with the name conv, but I never see ggml_conv_1d or ggml_conv_2d being used at any point.

@FSSRepo Well, at some point I tried to use ggml_conv_1d for this, but Mamba uses a number of groups equal to the number of in_channels, and ggml_conv_1d does not support setting the number of groups (at least, from what I understood when trying to make it work).

But it turns out that the desired operation in this case is exactly equivalent to making a self-overlapping view which shifts by one column at each stride in the 3rd dimension (which corresponds here to the number of tokens in the batch), and then doing a matrix multiplication with the conv1d weight of Mamba over the d_conv dimension (the kernel size of the convolution, which is 4). That matrix multiplication is done with ggml_mul and ggml_sum_rows because that way each row of the x tensor is still contiguous after permuting away the 1-sized dimension.

Not sure if I'm explaining this clearly, because I did not really know anything about convolutions before working on this.

(Here are the relevant lines for the "conv" step in my implementation.)

I figured this out when thinking about how to process multiple tokens at a time in the "conv" step when starting from how the next conv_state is built one token at a time in mamba.c and the corresponding lines in the official "simple" implementation. What I ended up with is much simpler than what I initially thought would have been necessary for batch processing.
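Here is a small Python sketch of that equivalence, assuming plain nested lists instead of ggml tensors (one row per d_inner channel, with the conv_state keeping d_conv - 1 columns). Both function names are made up for illustration: the first builds the state one token at a time, in the spirit of mamba.c; the second takes all overlapping windows at once and does an element-wise multiply followed by row sums, which is the role ggml_mul and ggml_sum_rows play in the description above.

```python
def conv_one_token_at_a_time(conv_state, xs, w):
    """Token by token: append the input, dot with the kernel, shift the state.

    conv_state: per channel, the previous d_conv - 1 inputs
    xs:         per channel, the inputs of the new tokens
    w:          per channel, the d_conv kernel weights (one group per channel)
    """
    state = [row[:] for row in conv_state]   # don't mutate the caller's state
    out = [[] for _ in w]
    for t in range(len(xs[0])):
        for c in range(len(w)):
            window = state[c] + [xs[c][t]]   # d_conv values, oldest first
            out[c].append(sum(wk * v for wk, v in zip(w[c], window)))
            state[c] = window[1:]            # shift out the oldest column
    return out, state


def conv_batched(conv_state, xs, w):
    """All tokens at once: stride-1 overlapping windows, then multiply and sum rows."""
    d_conv = len(w[0])
    out, new_state = [], []
    for c in range(len(w)):
        padded = conv_state[c] + xs[c]                                # d_conv - 1 + n_tokens values
        windows = [padded[t:t + d_conv] for t in range(len(xs[c]))]  # self-overlapping "view"
        out.append([sum(wk * v for wk, v in zip(w[c], win)) for win in windows])
        new_state.append(padded[len(padded) - (d_conv - 1):])       # keep the last d_conv - 1 inputs
    return out, new_state


w = [[0.1, 0.2, 0.3, 0.4], [1.0, -1.0, 0.5, 0.0]]
state0 = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
xs = [[1.0, 2.0, 3.0, 4.0, 5.0], [0.5, -0.5, 1.5, 2.5, -1.0]]
print(conv_one_token_at_a_time(state0, xs, w) == conv_batched(state0, xs, w))  # True
```

The batched form needs no per-token loop over the state, which is why it was simpler for processing a whole batch than it first seemed.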

compilade, Feb 05 '24 04:02

Regarding the KV questions: IIUC one slot is needed per sequence, so in that sense the KV cache size could be interpreted as the maximum number of distinct sequences that can be processed simultaneously.

(What follows are some thoughts about where the number of distinct sequences should be taken from. TL;DR at the end.)

For Mamba 3B, each KV slot takes 23.75 MiB. If the value is taken from n_ctx, since the default value is 512, the KV cache would take 512 times 23.75 MiB, which is 11.875 GiB, an unacceptably large amount of used memory (especially since most people won't use anywhere near 512 distinct sequences at the same time). Also, even if somehow a very small n_ctx is used with Mamba, almost everything currently seems to expect the input prompt(s) to never be bigger than the context size, which complicates this solution. But at least, since quite a lot of memory calculations are based on n_ctx, less initialization code has to be changed.
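As a sanity check of those figures, here is the arithmetic as a small Python sketch. It assumes f32 states, d_inner = 2 * n_embd, d_conv = 4, d_state = 16, and the dimensions of the mamba-2.8b checkpoint (n_embd = 2560, n_layer = 64); the per-layer size (d_conv - 1 + d_state) * d_inner is the one given later in this thread. The helper name is hypothetical.

```python
MIB = 1024 ** 2
GIB = 1024 ** 3


def mamba_state_bytes_per_seq(n_embd, n_layer, d_state=16, d_conv=4, expand=2, bytes_per_elem=4):
    """Bytes of recurrent state for one sequence (conv_state + ssm_state, f32 assumed)."""
    d_inner = expand * n_embd
    per_layer = (d_conv - 1 + d_state) * d_inner   # elements per layer
    return per_layer * n_layer * bytes_per_elem


per_seq = mamba_state_bytes_per_seq(2560, 64)
print(per_seq / MIB)          # 23.75 (MiB per sequence)
print(512 * per_seq / GIB)    # 11.875 (GiB if 512 sequences were allocated from n_ctx)
```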

Suppose I instead take the max number of distinct sequences from the value of n_parallel. Then that value would need to be available at KV cache initialization (easy: this means adding it to llama_context_params to access it from llama_new_context_with_model and then in llama_kv_cache_init). The current default value of 1 is reasonable here, but for servers, n_parallel should be at least as big as the number of users, or it won't work properly (is it already like this for Transformer-based models?).

But then I have to change how some things are initialized, replacing cparams.n_ctx with kv_self.size in a bunch of places where the assumption was that they are the same (but they aren't with Mamba). I think that's the better way, since it would also make it easier to use the KV cache differently than even what Mamba and Transformers do. If there's ever a non-linear and non-constant way to fill the KV cache, it should be easier to implement after this change.

Another thing regarding n_parallel and the KV cache size: even for Transformers, it could be useful to make the KV cache size a multiple of n_ctx, which would make the n_ctx per client slot in the server example easier to reason about (each would simply be equal to the global n_ctx). Though it would make it harder at a glance to see the total context size. I'm not sure what users expect when setting the --ctx-size in conjunction with --parallel regarding memory usage (currently, each client slot gets a fraction of the specified --ctx-size). I presume it was done this way because the KV cache size had always been set from n_ctx. In any case, it's probably best to avoid making user-facing breaking changes like this for Transformer-based models in this PR, though, so I'll leave this idea unimplemented for now.

TL;DR:

I'll try to make Mamba's KV cache size proportional to n_parallel as it seems to be the appropriate parameter to get the max number of distinct sequences processed at once.

compilade, Feb 05 '24 21:02

I've been thinking about what parts of the KV cache API can and cannot be supported for Mamba.

In general, functions which operate on whole sequences or the whole KV cache can be relatively easily supported.

But a lot of KV cache API functions take a range of token positions, and this cannot easily work with Mamba (too many states would need to be kept unnecessarily).

| Function | Can be supported | Acceptable |
| --- | --- | --- |
| llama_kv_cache_clear | Yes | Yes |
| llama_kv_cache_seq_rm | Partially | No |
| llama_kv_cache_seq_cp | Partially | Yes |
| llama_kv_cache_seq_keep | Yes | Yes |
| llama_kv_cache_seq_shift | No | Yes |
| llama_kv_cache_seq_div | No | Yes |

Here, "Partially" means "Only on entire sequences" (all tokens of a sequence, regardless of their position).

The most problematic function is llama_kv_cache_seq_rm, which is used in the server example to clear tokens after the system prompt. This could be worked around by dedicating a seq_id to the system prompt, and then using llama_kv_cache_seq_cp to copy over the system prompt to the other sequences when it's needed. The seq_ids for the client slots would need to be offset, though.

I think that most of what is currently done with position ranges (when using llama_kv_cache_seq_cp and llama_kv_cache_seq_rm) could be done with better sequence management.

This is kind of a blocker for Mamba support in llama.cpp, but it can wait. I need to finish trying to make multiple independent sequences work with Mamba before this can be useful to fix.

compilade, Feb 09 '24 00:02

> This could be worked around by dedicating a seq_id to the system prompt, and then using llama_kv_cache_seq_cp to copy over the system prompt to the other sequences when it's needed. The seq_ids for the client slots would need to be offset, though.

Yes, that sounds like the right way to do it

> I think that most of what is currently done with position ranges (when using llama_kv_cache_seq_cp and llama_kv_cache_seq_rm) could be done with better sequence management.

More thoughts on this are welcome

ggerganov, Feb 09 '24 12:02

Now that multiple sequences can be processed at once, I've been trying to make the server example work with Mamba.

> I think that most of what is currently done with position ranges (when using llama_kv_cache_seq_cp and llama_kv_cache_seq_rm) could be done with better sequence management.

> More thoughts on this are welcome

I think I was wrong. Some uses of position ranges do seem necessary. The server example currently uses llama_kv_cache_seq_rm with position ranges to keep the common part between the cached tokens and the requested prompt. The default web example seems to trim one or two tokens at the end of each response, which means the "correct" way to handle this with Mamba is to reprocess the whole prompt whenever it's not exactly the same as the cached tokens (which is not ideal, and could give the bad first impression that Mamba replies slower over time because of the re-processing time).

I'm wondering if it's okay to make llama_kv_cache_seq_rm return a bool (instead of void, changing the public API for that function) to let the caller know whether the removal succeeded or not (it would only be fallible for Mamba for now). This way, the server can still try to trim the prompt when it can, and fall back to re-processing everything after the system prompt when it can't.
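A minimal sketch of the proposed behavior, using a hypothetical toy cache instead of the real llama.cpp API: seq_rm returns False when asked to remove only part of a recurrent sequence, and the caller then falls back to clearing the sequence and re-processing its prompt. The range convention (p0 < 0 means 0, p1 < 0 means up to the end) mirrors llama_kv_cache_seq_rm; everything else (ToyRecurrentCache, reuse_prompt) is made up for illustration.

```python
class ToyRecurrentCache:
    """Hypothetical stand-in for a recurrent model's cache: one state per sequence,
    plus how many tokens have been folded into it."""

    def __init__(self):
        self.n_pos = {}  # seq_id -> number of tokens processed into the state

    def seq_rm(self, seq_id, p0, p1):
        """Remove positions [p0, p1). Returns False for a partial removal."""
        n = self.n_pos.get(seq_id, 0)
        p0 = max(p0, 0)
        p1 = n if p1 < 0 else min(p1, n)
        if p0 >= p1:
            return True                    # nothing to remove
        if p0 == 0 and p1 == n:
            self.n_pos.pop(seq_id, None)   # whole sequence: just drop the state
            return True
        return False                       # partial: the state can't be rewound


def reuse_prompt(cache, seq_id, cached_tokens, prompt_tokens):
    """Return the index from which the prompt must be (re)processed."""
    n_common = 0
    for a, b in zip(cached_tokens, prompt_tokens):
        if a != b:
            break
        n_common += 1
    if cache.seq_rm(seq_id, n_common, -1):
        return n_common                    # common prefix kept
    cache.seq_rm(seq_id, -1, -1)           # fall back: clear the whole sequence
    return 0                               # ... and re-process from the start
```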

But I've been thinking of a way to calculate previous states from more recent ones. From equation (2a) of the paper, which looks like what is done in ggml_ssm_scan, the next ssm_states is calculated as follows:

$$h_t = \overline{\mathbf{A}}\,h_{t-1} + \overline{\mathbf{B}}\,x_t$$

Solving for $h_{t-1}$, it should be possible to get the previous ssm_states. In Mamba, $\overline{\mathbf{A}} = \exp(\Delta\mathbf{A})$ is diagonal with strictly positive entries, so it is invertible and applying its inverse amounts to an element-wise division:

$$h_{t-1} = \overline{\mathbf{A}}^{-1}\left(h_t - \overline{\mathbf{B}}\,x_t\right)$$
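A quick numerical check of that inversion, as a Python sketch for a single channel. It assumes $\overline{A} = \exp(\Delta A)$ element-wise, the simplified $\overline{B}x_t = \Delta B x_t$, and that $x_t$ and $\Delta$ for the token are already known, which is exactly what the conv_states problem described next makes hard. step and unstep are hypothetical helpers. When $\exp(\Delta A)$ is very small, the division amplifies rounding errors, so a real roll-back could drift.

```python
import math


def step(h, x, delta, A, B):
    """Forward: h_t = Abar * h_{t-1} + Bbar * x_t, element-wise (diagonal Abar)."""
    return [hn * math.exp(delta * an) + delta * bn * x for hn, an, bn in zip(h, A, B)]


def unstep(h, x, delta, A, B):
    """Backward: h_{t-1} = (h_t - Bbar * x_t) / Abar, element-wise."""
    return [(hn - delta * bn * x) / math.exp(delta * an) for hn, an, bn in zip(h, A, B)]


h0 = [0.5, -1.0, 2.0]
A = [-1.0, -0.5, -2.0]
B = [0.3, 0.1, -0.2]
h1 = step(h0, 1.5, 0.7, A, B)
print(unstep(h1, 1.5, 0.7, A, B))   # approximately h0
```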

But getting the previous conv_states (which is also necessary to get that $x_t$) requires the fourth last processed token (4 comes from the value of d_conv in the Mamba models). So a list of previous tokens would need to be kept to make the states go further back (still much lighter on memory use than actually keeping the states).

~~But I'm not sure how to integrate that with the forward pass. Should the roll-back be done at the next llama_decode after llama_kv_cache_seq_rm, or right away? (making llama_kv_cache_seq_rm sometimes slow)~~ This could be handled similarly to KV cache defrag. Where could the token list be stored which won't unnecessarily reserve memory for models which won't need it? What about needing to go further back than the batch size? Which time would "unprocessing" tokens count towards? Is rolling-back the states even possible? (probably not, the z tensor (the right branch of the Mamba architecture in Figure 3 of the Mamba paper) might make this impossible, or at least more complicated than the above)

These are questions I'll ponder during the next month (so probably in another PR), after I make the parallel and server examples work out-of-the-box with Mamba in the coming weekend (right now they work, but not as-is).

compilade, Feb 22 '24 03:02

Okay, I think this is finally ready for review. Pretty much everything works (on CPU, at least), and I've updated the first post with an "Out of scope for this PR" section, because this is getting big.

I'd be happy to answer any question or criticism regarding this Mamba implementation.

compilade, Mar 03 '24 19:03

> Are you familiar with RWKV?

I've read about it, but I'm not as familiar with RWKV as I'd like, unfortunately. :sweat_smile:

> I'm wondering how well the proposed changes fit with the RWKV architecture. Asking for any insights, so I can do some planning for future developments

Forward compatibility means making sure the added GGUF metadata keys for state sizes are general enough. I'll at least start by listing the main differences I noticed when looking at RWKV-v5 right now.

  • The operators are different; ggml_ssm_conv and ggml_ssm_scan don't fit with RWKV's architecture (at least not from a cursory glance). RWKV also seems to use functions like tanh, unlike Mamba.
    • Maybe the ggml_ssm_* operators could be renamed to ggml_mamba_* if they are not general enough.
  • The state size seems different in RWKV. The total state size per layer (in RWKV-v5) is (n_embd/n_head) * n_embd (which is also head_size * head_size * n_head), and it's for both time-mixing and channel-mixing, while in Mamba, the total state size per layer is (d_conv - 1 + d_state) * d_inner, where d_inner = 2*n_embd. (the -1 is because the first column of the conv_state is shifted out each time, so it's not actually needed. conv_state is a rolling state.)
    • Mamba has 2 state types per layer, one rolling (conv_state), and one recurrent (ssm_state), while RWKV seems to only have a recurrent state per layer (but I might be wrong).

Ideally, the required state size could be directly stored in GGUF metadata. Mamba's conv_state is a rolling state; this needs to be considered when allocating the states, because if keeping past rolling states is desirable for backtracking, only one column per previous state needs to be kept instead of the whole thing each time. Perhaps the shift (1) should be stored too alongside its size? Though I don't see why a rolling state would use any other shift than 1 column (are there counter-examples out there?). The recurrent state of RWKV could use the {arch}.ssm.d_state and {arch}.ssm.d_inner keys for its size, and keep {arch}.ssm.d_conv to 0 (the -1 won't occur when d_conv is zero) if it doesn't have a rolling state. Still thinking about better names for those metadata keys, but (to me) they look reasonable even if they keep these names.
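The rolling-state idea (and the earlier remark that going back requires the fourth-last processed input when d_conv is 4) can be sketched in Python like this. RollingState is a hypothetical class for a single channel with width = d_conv - 1: each push shifts out one column, and only that column needs to be remembered to undo the push later, instead of a copy of the whole state.

```python
from collections import deque


class RollingState:
    """Hypothetical rolling (conv-like) state for one channel, shift of 1 column."""

    def __init__(self, width):
        self.cols = deque([0.0] * width, maxlen=width)  # width = d_conv - 1
        self.shifted_out = []                           # one column per processed token

    def push(self, x):
        self.shifted_out.append(self.cols[0])  # the column about to fall off
        self.cols.append(x)                    # maxlen drops the leftmost column

    def pop(self):
        """Undo the last push: drop the newest column, restore the oldest one."""
        self.cols.pop()
        self.cols.appendleft(self.shifted_out.pop())


s = RollingState(3)
for x in [1.0, 2.0, 3.0, 4.0, 5.0]:
    s.push(x)
s.pop()
print(list(s.cols))   # [2.0, 3.0, 4.0]
```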

Thanks for bringing this up, by the way, because comparing Mamba with RWKV made me realize that the conv_state of Mamba could be explained with the words "rolling state".

compilade, Mar 04 '24 20:03

> Will be reviewing this PR in the following days, thanks

Since master continues to change (and that's good), I hope it's okay if I resolve conflicts with ~~git rebase -i master and git push --force-with-lease, and not~~ git merge master. (EDIT: see note at the end)

If you're curious about what I'm changing, I've found git range-diff master old-branch new-branch to work very well to know what changed in the diffs after rebasing. Though, I realize that this only works if you conveniently already have the old and new commits fetched locally.

If you prefer that I use merge commits instead of rebasing, you can react to this comment with :-1: (I know it has a bad connotation, but it's the correct reaction for disagreement) Otherwise if it's fine that I rebase, I'd like to know too (with :+1:, possibly).

Apologies in advance if I rebase before you read this, but know that I'll use merge commits in the future if that's what you prefer. (EDIT: I'll use merge commits from now on, because rebases are harder to review on GitHub. I've restored this branch before the shameless rebase I did, then merged master while resolving conflicts in the same way. This restored the correct commit timeline instead of putting them all after this comment. Also, the conflict resolution is now easier to audit (both on GitHub and with git log --remerge-diff or git log --cc))

compilade, Mar 04 '24 22:03

> The implementation is pretty good

Thanks!

> I'm still not convinced we need to introduce n_parallel and llama_n_max_seq().

Imagine the following case: A user wants to use Mamba 3B to process a prompt with a length of... 1337 tokens. This user is only using a single sequence. Out of habit with how other models work, the user passes --ctx-size 2048.

Now, the two ways to do this:

  • With the number of sequences coming from n_parallel (how it's done in this PR for Mamba)
    • A single sequence is allocated (this uses 23.75 MiB of RAM) :white_check_mark:
    • llama_batch_init() creates a buffer big enough for 2048 tokens :white_check_mark:
    • Everything goes well since 1337 tokens can fit in the 2048-tokens buffer. :white_check_mark:
  • With the number of sequences coming from n_ctx
    • 2048 sequences are allocated (this uses 47.5 GiB of RAM) :exclamation:
    • llama_batch_init() creates a buffer big enough for 2048 tokens. :white_check_mark:
    • Everything goes well except if the user doesn't have at least 64 GiB of RAM. :yellow_circle:

Okay that was unfair. Let's say the user is better-informed and passes --ctx-size 1 or that n_ctx is somehow 1 (e.g. from different defaults).

  • With the number of sequences coming from n_parallel or from n_ctx (both are 1 here so the result is the same):
    • A single sequence is allocated (this uses 23.75 MiB of RAM) :white_check_mark:
    • llama_batch_init() creates a buffer big enough for 1 token. :exclamation:
    • Buffer overflow when the 1337 tokens are added to the 1-token batch buffer. :x:
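The scenarios above boil down to a bit of arithmetic, sketched here in Python. plan is a hypothetical helper; it uses the 23.75 MiB per-sequence figure for Mamba-3B quoted in this thread and assumes the batch buffer from llama_batch_init() holds n_ctx tokens.

```python
MIB_PER_SEQ = 23.75   # Mamba-3B state per sequence, as quoted in this thread


def plan(n_ctx, n_seq_max, n_prompt):
    """Return (state memory in MiB, whether the prompt fits in the batch buffer)."""
    state_mib = n_seq_max * MIB_PER_SEQ
    batch_capacity = n_ctx   # batch buffer assumed to be sized from n_ctx
    return state_mib, n_prompt <= batch_capacity


print(plan(n_ctx=2048, n_seq_max=1, n_prompt=1337))     # sequences from n_parallel
print(plan(n_ctx=2048, n_seq_max=2048, n_prompt=1337))  # sequences from n_ctx
print(plan(n_ctx=1, n_seq_max=1, n_prompt=1337))        # tiny n_ctx: buffer too small
```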

I don't really see from where else than n_ctx the size of the batch buffer (allocated with llama_batch_init()) could come from (especially since this is about the batch buffer before it's split-into-n_batch-sized-parts).

Using n_parallel for the number of sequences was overall easier than trying to change the meaning of n_ctx.

The same reasoning also applies for examples like perplexity and parallel.

I hope this better explains why the context size and the number of sequences were made orthogonal for Mamba. (note that unlike with Mamba, llama_n_ctx() and llama_n_max_seq() are equivalent for Transformer-based models)

> If in some places we expect the input to not be bigger than n_ctx (such as the context shift logic), we can try to fix these (simply disable context shift for Mamba models).

These checks are also used to avoid overflowing the buffer allocated with llama_batch_init().

Currently, context shifting is faked for recurrent models to let n_past be made smaller than n_ctx, while still getting consecutive token positions. (Ideally, though, n_ctx should be big enough for this to never happen, hence the very large default context length of 2**20 (1048576) currently stored in the model's metadata for Mamba by convert-hf-to-gguf.py, even though maybe only LongMamba could make use of it.)

> Either way, we can merge it as it is since the API change is quite small

I agree. Note that I've adapted my changes for Mamba to the recent refactor of the server example. The only "weird" things I'm (still) doing there are

  • using slot.id + 1 as the seq_id of each slot (since the system prompt uses the KV cache's seq_id 0)
    • this might be confusing, but at least the external behavior stays the same (i.e. slot id 0 exists)
  • a little dance with n_parallel to add 1 to it to reserve a sequence id for the system prompt, initialize the model, then remove 1 from n_parallel so that it can be used with the same meaning as before as the number of client slots.
  • checking for the failure of llama_kv_cache_seq_rm() when removing the part of the cache which is not common with the prompt, because recurrent models (currently) can't have their states partially reset.
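The sequence bookkeeping described in this list could look roughly like the following Python sketch. The helper names are hypothetical, and copy_seq stands in for llama_kv_cache_seq_cp applied to whole sequences.

```python
SYSTEM_SEQ_ID = 0   # the system prompt lives in its own sequence


def slot_seq_id(slot_id: int) -> int:
    """Client slot k uses cache sequence k + 1, so slot ids still start at 0."""
    return slot_id + 1


def n_seq_to_allocate(n_parallel: int) -> int:
    """Reserve one extra sequence for the system prompt when creating the context."""
    return n_parallel + 1


def start_slot(copy_seq, slot_id: int) -> None:
    """Reset a slot by copying the whole system-prompt sequence into it."""
    copy_seq(SYSTEM_SEQ_ID, slot_seq_id(slot_id))


# Usage: record the copies instead of touching a real cache
copies = []
start_slot(lambda src, dst: copies.append((src, dst)), 2)
print(copies)
```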

compilade, Mar 07 '24 21:03

Since the transformers library is getting support for Mamba (https://github.com/huggingface/transformers/pull/28094), the official Mamba models have been re-released with more metadata.

See https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406

I think I should rename the GGUF key-value pairs I added for Mamba to make them more similar to their transformers counterpart.

| Current GGUF name | transformers name | Possible new GGUF name |
| --- | --- | --- |
| {arch}.ssm.d_conv | conv_kernel | {arch}.ssm.conv_kernel |
| {arch}.ssm.d_state | state_size | {arch}.ssm.state_size |
| {arch}.ssm.d_inner | intermediate_size | {arch}.ssm.inner_size |
| {arch}.ssm.dt_rank | time_step_rank | {arch}.ssm.time_step_rank |

This would break existing GGUF-converted Mamba models, though. (none have been published yet it seems, so those affected should easily be able to reconvert) If I rename them, it needs to happen before merging.

(EDIT: the above change has been done. If there are any objections, I'd like to know)
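For reference, the rename could be expressed as a small mapping in a convert script, sketched here in Python. ssm_metadata and SSM_KEYS are hypothetical names, "mamba" is assumed as the {arch} value, and the fallback values are assumptions based on Mamba's usual defaults (conv kernel 4, state size 16, d_inner = 2 * n_embd, and dt_rank = ceil(n_embd / 16)).

```python
import math

# transformers config.json field -> GGUF key suffix under "{arch}.ssm."
SSM_KEYS = {
    "conv_kernel": "conv_kernel",
    "state_size": "state_size",
    "intermediate_size": "inner_size",
    "time_step_rank": "time_step_rank",
}


def ssm_metadata(arch: str, config: dict) -> dict:
    """Build the renamed SSM key-value pairs from a transformers-style config (sketch)."""
    n_embd = config["hidden_size"]
    defaults = {   # assumed Mamba defaults, used only when a field is missing
        "conv_kernel": 4,
        "state_size": 16,
        "intermediate_size": 2 * n_embd,
        "time_step_rank": math.ceil(n_embd / 16),
    }
    return {
        f"{arch}.ssm.{gguf_name}": config.get(hf_name, defaults[hf_name])
        for hf_name, gguf_name in SSM_KEYS.items()
    }


print(ssm_metadata("mamba", {"hidden_size": 768}))
```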

compilade, Mar 08 '24 01:03

> I don't really see from where else than n_ctx the size of the batch buffer (allocated with llama_batch_init()) could come from (especially since this is about the batch buffer before it's split-into-n_batch-sized-parts).

Thanks, I agree now. We should actually start using llama_n_max_seq() instead of n_ctx to init batches in the examples to make it more semantically clear. We can do this in another PR

Feel free to merge this (squash in single commit) when you think it is ready. Maybe add a short notice in the "Recent API changes" section in the README.md to help 3rd party devs and consider updating the GGUF spec with the new keys

ggerganov, Mar 08 '24 09:03

> We should actually start using llama_n_max_seq() instead of n_ctx to init batches in the examples to make it more semantically clear.

There might be a misunderstanding here. To be clear, llama_n_max_seq() returns the upper limit of acceptable seq_id in batches. This is only relevant when dealing with multiple sequences.

What caused llama_n_max_seq() to exist is the perplexity example, which creates a lot of sequences, especially in the HellaSwag benchmark. I needed this limit to make it avoid using sequence ids that could not fit in Mamba's KV cache.

Unless an example really uses ALL available sequences on any single token in a batch, llama_n_max_seq() should not be used when initializing batches. n_ctx is not replaced by this.
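To illustrate the distinction, here is a hypothetical Python validation helper: the number of tokens in a batch is bounded by the buffer sized from n_ctx, while each token's seq_ids must stay below the value reported by llama_n_max_seq(). It assumes valid seq_ids lie in [0, n_max_seq); a token may belong to several sequences at once, as with shared prompt tokens in HellaSwag.

```python
def check_batch(seq_ids_per_token, n_max_seq, n_ctx):
    """Raise ValueError if the batch is too big or uses an out-of-range seq_id."""
    if len(seq_ids_per_token) > n_ctx:
        raise ValueError("batch larger than the buffer sized from n_ctx")
    for ids in seq_ids_per_token:
        for s in ids:
            if not 0 <= s < n_max_seq:
                raise ValueError(f"seq_id {s} outside [0, {n_max_seq})")


# A shared token in 4 sequences, then tokens of individual sequences
check_batch([[0, 1, 2, 3], [0], [1]], n_max_seq=4, n_ctx=512)
```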

> Feel free to merge this (squash in single commit) when you think it is ready.

Noted.

compilade, Mar 08 '24 16:03