mistral.rs
Initial KV RingAttention code
This is the start of the RingAttention code. The changes so far create multiple KV caches (when num_devices > 1) and attempt to split the input into separate sequence chunks.
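For illustration, a minimal sketch of the per-device KV cache idea (the type and method names here are hypothetical, not the actual mistral.rs cache structs): each device participating in the ring holds its own KV cache for every layer.

```rust
use candle_core::{Device, Tensor};

/// Hypothetical per-layer KV entry: the keys/values accumulated so far.
#[derive(Default, Clone)]
struct LayerKvCache {
    kv: Option<(Tensor, Tensor)>,
}

/// One KV cache per device, each covering every layer of the model.
struct RingKvCaches {
    per_device: Vec<Vec<LayerKvCache>>,
}

impl RingKvCaches {
    fn new(devices: &[Device], num_layers: usize) -> Self {
        Self {
            per_device: devices
                .iter()
                .map(|_| vec![LayerKvCache::default(); num_layers])
                .collect(),
        }
    }

    /// Access the cache for one layer on one device.
    fn layer_mut(&mut self, device_idx: usize, layer_idx: usize) -> &mut LayerKvCache {
        &mut self.per_device[device_idx][layer_idx]
    }
}
```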
Code Metrics Report
```
===============================================================================
 Language             Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                 2           35           28            0            7
 Dockerfile               1           34           25            0            9
 Happy                    1          442          369            0           73
 JSON                    12          105          104            0            1
 Python                  46         2018         1718           62          238
 TOML                    20          596          536            2           58
 YAML                     2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks        4            0            0            0            0
 |- Markdown              2           77           32           31           14
 |- Python                2          196          169            1           26
 (Total)                             273          201           32           40
-------------------------------------------------------------------------------
 Markdown                30         2080            0         1580          500
 |- BASH                  5          101           98            0            3
 |- JSON                  1           12           12            0            0
 |- Python                5           92           82            0           10
 |- Rust                  7          441          395           22           24
 |- TOML                  2           75           63            0           12
 (Total)                            2801          650         1602          549
-------------------------------------------------------------------------------
 Rust                   202        62743        56960         1148         4635
 |- Markdown            103          950           13          885           52
 (Total)                           63693        56973         2033         4687
===============================================================================
 Total                  321        68074        59759         2794         5521
===============================================================================
```
I hadn't planned to use the chunk method; I plan to use IndexOp instead. I think I ran into some issues with the chunk method, but I don't remember what they were.
I kept the forward block mostly the same; I just moved the feedforward (MLP) layer so it runs after attention has been computed on all chunks. I'm not 100% sure it will work, but I'm trying to follow the algorithm as closely as possible.
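As a rough sketch of that ordering (the closures stand in for the real attention and MLP modules, and the residual adds are just the standard transformer-block pattern; this is not the actual llama.rs code):

```rust
use candle_core::{Result, Tensor};

/// Run one decoder block over pre-split sequence chunks: attention is computed
/// for every chunk first, and only then is the feed-forward (MLP) applied.
fn block_forward(
    chunks: &[Tensor],
    attn: impl Fn(&Tensor) -> Result<Tensor>,
    mlp: impl Fn(&Tensor) -> Result<Tensor>,
) -> Result<Vec<Tensor>> {
    // Phase 1: attention over all chunks. In a full RingAttention step this is
    // where the per-device KV blocks would be rotated around the ring.
    let mut attn_out = Vec::with_capacity(chunks.len());
    for x in chunks {
        attn_out.push(attn(x)?.add(x)?); // residual connection around attention
    }
    // Phase 2: feed-forward on each chunk's attention output.
    let mut out = Vec::with_capacity(attn_out.len());
    for x in &attn_out {
        out.push(mlp(x)?.add(x)?); // residual connection around the MLP
    }
    Ok(out)
}
```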
To use Tensor::chunk, maybe we can split the input ids along dim = D::Minus1 into num_devices chunks.
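For illustration, here is what that could look like with candle's Tensor::chunk (a minimal sketch; the helper name is made up and the integration into the model is not shown):

```rust
use candle_core::{D, Device, Result, Tensor};

/// Split [batch, seq_len] input ids into `num_devices` chunks along the last
/// (sequence) dimension. Chunk lengths may differ by one if seq_len is not
/// divisible by num_devices.
fn split_input_ids(input_ids: &Tensor, num_devices: usize) -> Result<Vec<Tensor>> {
    input_ids.chunk(num_devices, D::Minus1)
}

fn main() -> Result<()> {
    let input_ids = Tensor::arange(0u32, 16, &Device::Cpu)?.reshape((2, 8))?;
    for (i, c) in split_input_ids(&input_ids, 4)?.iter().enumerate() {
        println!("chunk {i}: {:?}", c.shape()); // each chunk is [2, 2]
    }
    Ok(())
}
```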
After that, we can just follow the algorithm for the rest. I would also recommend you split this into another model implementation, perhaps llama_ring_attention.rs and an associated NormalModelLoader, as this seems to be quite invasive. Essentially you need to split Block into a few steps.
> Essentially you need to split Block into a few steps.
Can you give me more details on the steps? That's where I'm confused.
> I would also recommend you split this into another model implementation, perhaps llama_ring_attention.rs and an associated NormalModelLoader, as this seems to be quite invasive.
I can do this.
I ran a simple test on the mapper function that you pointed out above and confirmed that it is the problem.
Here is the code snippet of the test:
The model now gives the correct output.
When you have the time, please let me know what should be done to create a more robust solution. I'm not sure how to proceed with it.
I don't want this to fall by the wayside, as I'll need it for my long-context use case.
- The latest commit has the sequence parallelism code using IndexOp. I wasn't sure how to use the dimension argument of the chunk method, so I figured IndexOp was the best route to go (a rough sketch of that approach is shown after this list).
- I also added the llama_ring_attention code, which is currently identical to the llama code. I figured it would be easier to see the sequence parallelism changes in the llama.rs file; I'll revert llama.rs to master in the next commit.
- Finally, I looked into the NormalLoader, but I wasn't sure what changes are needed to use llama_ring_attention. I'll need some help with that.
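For reference, a rough sketch of the IndexOp-based splitting mentioned in the first bullet (a hypothetical helper, not the code from the commit):

```rust
use candle_core::{D, IndexOp, Result, Tensor};

/// Split a [batch, seq_len] tensor into `num_devices` contiguous slices along
/// the sequence dimension using IndexOp instead of Tensor::chunk.
fn split_with_index_op(input_ids: &Tensor, num_devices: usize) -> Result<Vec<Tensor>> {
    let seq_len = input_ids.dim(D::Minus1)?;
    let chunk_len = (seq_len + num_devices - 1) / num_devices;
    let mut chunks = Vec::with_capacity(num_devices);
    for i in 0..num_devices {
        let start = i * chunk_len;
        let end = (start + chunk_len).min(seq_len);
        if start >= end {
            break;
        }
        // `.i((.., start..end))` slices only the last (sequence) dimension.
        chunks.push(input_ids.i((.., start..end))?);
    }
    Ok(chunks)
}
```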
I'm also still stuck on how to refactor the forward pass that you highlighted as an issue last time. I still need help with that :(