
Initial KV RingAttention code

Open · joshpopelka20 opened this issue 1 year ago · 6 comments

This is the start of the RingAttention code. The changes so far create multiple KV caches (one per device when there is more than one device) and attempt to split the input into separate chunks.
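For context, here is a minimal sketch of what "one KV cache per device" could look like, using the `Vec<Option<(Tensor, Tensor)>>` cache layout common in candle-based Llama implementations. The `RingKvCache` name and the assumed `[batch, heads, seq, head_dim]` layout are illustrative only, not the code in this PR:

```rust
use candle_core::{Result, Tensor};

/// Hypothetical container: one (k, v) cache slot per layer, per device.
/// This mirrors the common `Vec<Option<(Tensor, Tensor)>>` cache layout,
/// duplicated across `num_devices` for ring attention.
struct RingKvCache {
    // caches[device_idx][layer_idx] = Some((k, v)) once populated
    caches: Vec<Vec<Option<(Tensor, Tensor)>>>,
}

impl RingKvCache {
    fn new(num_devices: usize, num_layers: usize) -> Self {
        Self {
            caches: vec![vec![None; num_layers]; num_devices],
        }
    }

    /// Append to (or initialize) the cache for one layer on one device.
    fn update(&mut self, device_idx: usize, layer_idx: usize, k: Tensor, v: Tensor) -> Result<()> {
        let slot = &mut self.caches[device_idx][layer_idx];
        *slot = match slot.take() {
            // Concatenate along the sequence dimension (assumed to be dim 2
            // of a [batch, heads, seq, head_dim] tensor).
            Some((k_prev, v_prev)) => Some((
                Tensor::cat(&[&k_prev, &k], 2)?,
                Tensor::cat(&[&v_prev, &v], 2)?,
            )),
            None => Some((k, v)),
        };
        Ok(())
    }
}
```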

joshpopelka20 · Aug 14 '24 20:08

Code Metrics Report
===============================================================================
 Language            Files        Lines         Code     Comments       Blanks
===============================================================================
 C Header                2           35           28            0            7
 Dockerfile              1           34           25            0            9
 Happy                   1          442          369            0           73
 JSON                   12          105          104            0            1
 Python                 46         2018         1718           62          238
 TOML                   20          596          536            2           58
 YAML                    2           21           19            2            0
-------------------------------------------------------------------------------
 Jupyter Notebooks       4            0            0            0            0
 |- Markdown             2           77           32           31           14
 |- Python               2          196          169            1           26
 (Total)                            273          201           32           40
-------------------------------------------------------------------------------
 Markdown               30         2080            0         1580          500
 |- BASH                 5          101           98            0            3
 |- JSON                 1           12           12            0            0
 |- Python               5           92           82            0           10
 |- Rust                 7          441          395           22           24
 |- TOML                 2           75           63            0           12
 (Total)                           2801          650         1602          549
-------------------------------------------------------------------------------
 Rust                  202        62743        56960         1148         4635
 |- Markdown           103          950           13          885           52
 (Total)                          63693        56973         2033         4687
===============================================================================
 Total                 321        68074        59759         2794         5521
===============================================================================

github-actions[bot] · Aug 14 '24 20:08

I didn't plan to use the chunk method; I plan on using IndexOp instead. I think I had some issues with the chunk method, but I don't remember why.
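For reference, here is a minimal sketch of an IndexOp-based split along the sequence dimension in candle. The `split_sequence` helper, the `[batch, seq_len]` shape, and the even-chunking policy are assumptions for illustration, not code from this branch:

```rust
use candle_core::{Device, IndexOp, Result, Tensor};

/// Split `input_ids` (assumed shape [batch, seq_len]) into `num_devices`
/// contiguous slices along the sequence dimension using IndexOp.
fn split_sequence(input_ids: &Tensor, num_devices: usize) -> Result<Vec<Tensor>> {
    let seq_len = input_ids.dim(1)?;
    let chunk_len = (seq_len + num_devices - 1) / num_devices;
    let mut chunks = Vec::with_capacity(num_devices);
    for i in 0..num_devices {
        let start = i * chunk_len;
        let end = (start + chunk_len).min(seq_len);
        if start < end {
            // `.i((.., start..end))` keeps the batch dim and slices the sequence dim.
            chunks.push(input_ids.i((.., start..end))?);
        }
    }
    Ok(chunks)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy input: batch = 2, seq_len = 8.
    let input_ids = Tensor::arange(0u32, 16u32, &device)?.reshape((2, 8))?;
    for (i, c) in split_sequence(&input_ids, 4)?.iter().enumerate() {
        println!("chunk {i}: {:?}", c.shape());
    }
    Ok(())
}
```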

I kept the forward block mostly the same; I just moved the feedforward (MLP) layer to after attention has been run on all chunks. I'm not 100% sure it'll work, but I'm trying to follow the algorithm as closely as possible.

joshpopelka20 · Aug 14 '24 21:08

To use Tensor::chunk, maybe we can split the input IDs along dim = D::Minus1 into as many chunks as there are devices.
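For comparison with the IndexOp route above, a short sketch of the Tensor::chunk approach, assuming `input_ids` has shape `[batch, seq_len]`:

```rust
use candle_core::{D, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let num_devices = 4;
    // Toy input: batch = 2, seq_len = 8.
    let input_ids = Tensor::arange(0u32, 16u32, &device)?.reshape((2, 8))?;
    // Split the last (sequence) dimension into `num_devices` chunks.
    let chunks = input_ids.chunk(num_devices, D::Minus1)?;
    for (i, c) in chunks.iter().enumerate() {
        println!("chunk {i}: {:?}", c.shape());
    }
    Ok(())
}
```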

After that, we can just follow the algorithm for the rest. I would also recommend you split this into another model implementation, perhaps llama_ring_attention.rs and an associated NormalModelLoader, as this seems to be quite invasive. Essentially, you need to split Block into a few steps.
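To make the "split Block into a few steps" idea concrete, here is a rough, hypothetical sketch of a block forward that first loops attention over the ring of KV chunks and only then applies the MLP. The `AttnStep`/`MlpStep` traits and the accumulator are placeholders, not the actual llama.rs types:

```rust
use candle_core::{Result, Tensor};

/// Hypothetical stand-ins for a Block's attention and MLP sub-modules; the
/// real types live in llama.rs and carry more state (RoPE, masks, caches).
trait AttnStep {
    /// Attend the local query chunk against one (k, v) chunk from the ring,
    /// folding the result into a running accumulator (e.g. via the
    /// log-sum-exp rescaling used by ring attention).
    fn attend(
        &self,
        q_chunk: &Tensor,
        kv_chunk: &(Tensor, Tensor),
        acc: Option<Tensor>,
    ) -> Result<Tensor>;
}

trait MlpStep {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

/// Sketch of a block forward split into two phases:
/// 1) loop over the ring of KV chunks, accumulating the attention output,
/// 2) only then apply the feed-forward (MLP) once.
/// Residual connections and normalization are omitted for brevity.
fn block_forward(
    attn: &impl AttnStep,
    mlp: &impl MlpStep,
    q_chunk: &Tensor,
    kv_ring: &[(Tensor, Tensor)], // one (k, v) chunk per device, rotated around the ring
) -> Result<Tensor> {
    let mut acc = None;
    for kv_chunk in kv_ring {
        acc = Some(attn.attend(q_chunk, kv_chunk, acc)?);
    }
    let attn_out = acc.expect("the ring must contain at least one KV chunk");
    mlp.forward(&attn_out)
}
```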

EricLBuehler · Aug 14 '24 21:08

> Essentially, you need to split Block into a few steps.

Can you give me more details on the steps? That's where I'm confused.

> I would also recommend you split this into another model implementation, perhaps llama_ring_attention.rs and an associated NormalModelLoader, as this seems to be quite invasive.

I can do this.

joshpopelka20 · Aug 14 '24 21:08

I did a simple test of the mapper function that you pointed out above, and I've confirmed that it is the problem.

Here is the code snippet of the test (posted as a screenshot in the original issue).

The model now gives the correct output.

When you have the time, please let me know what should be done to create a more robust solution. I'm not sure how to proceed with it.

joshpopelka20 · Aug 15 '24 15:08

I don't want this to fall by the wayside, as I'll need it for my long-context use case.

  1. The latest commit has the sequence parallelism code using IndexOp. I wasn't sure how to use the dimension argument of the chunk method, so I figured IndexOp was the best route to go.
  2. I also added the llama_ring_attention code, which for now is the same as the llama code. I figured it would be easier to see the sequence parallelism changes in the llama.rs file; I'll revert llama.rs to master in the next commit.
  3. Finally, I looked into the NormalLoader, but I wasn't sure of the changes needed to make it use llama_ring_attention. I'll need some help with that (a rough sketch of the general pattern is included below this list).
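On point 3, the general shape of such a change is usually "add a new architecture variant and dispatch on it when constructing the model". The sketch below is purely illustrative; the `Architecture` enum and `ModelBuilder` trait are hypothetical stand-ins, not mistral.rs's actual NormalLoaderType / NormalModelLoader API:

```rust
// Purely illustrative: the real mistral.rs loader plumbing has its own trait
// signatures; this only sketches the "new enum variant + dispatch" pattern.

enum Architecture {
    Llama,
    LlamaRingAttention, // new variant backed by llama_ring_attention.rs
}

trait ModelBuilder {
    fn name(&self) -> &'static str;
}

struct LlamaBuilder;
struct LlamaRingAttentionBuilder;

impl ModelBuilder for LlamaBuilder {
    fn name(&self) -> &'static str {
        "llama"
    }
}

impl ModelBuilder for LlamaRingAttentionBuilder {
    fn name(&self) -> &'static str {
        "llama_ring_attention"
    }
}

fn builder_for(arch: Architecture) -> Box<dyn ModelBuilder> {
    match arch {
        Architecture::Llama => Box::new(LlamaBuilder),
        Architecture::LlamaRingAttention => Box::new(LlamaRingAttentionBuilder),
    }
}

fn main() {
    let builder = builder_for(Architecture::LlamaRingAttention);
    println!("selected model implementation: {}", builder.name());
}
```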

I'm also still stuck on how to refactor the forward pass that you highlighted as an issue last time. I still need help with that :(

joshpopelka20 · Aug 22 '24 15:08