
[Feature] Implementation of multi-gpu KV cache (RingAttention)

Open joshpopelka20 opened this issue 1 year ago • 19 comments

I'll work through adding it to quantized llama first, as I know that architecture the most. Link to the paper: https://arxiv.org/abs/2310.01889

joshpopelka20 avatar Jul 22 '24 19:07 joshpopelka20

I've been researching the algorithm further, and I think I'm going to have a problem implementing this in Rust. To start, based on my understanding, I'd need to split the input into different tensors, with each split tensor on a different GPU. PyTorch has select (https://pytorch.org/docs/stable/generated/torch.select.html), but it doesn't look like candle has an equivalent. Is that correct?

joshpopelka20 avatar Jul 26 '24 18:07 joshpopelka20

To split the input into different tensors, I would use the narrow method to split along different ranges. If you want something a bit more flexible, IndexOp is probably what you want.

To implement torch.select, I'd use IndexOp (.i) to extract at the given dimension.
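
For example, here's a rough sketch of both with candle (a toy CPU tensor; shapes are just for illustration):

use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A toy (batch, seq_len, hidden) tensor.
    let x = Tensor::arange(0f32, 24f32, &dev)?.reshape((2, 4, 3))?;

    // narrow(dim, start, len): take seq positions 1..3 along dim 1.
    let chunk = x.narrow(1, 1, 2)?;
    assert_eq!(chunk.dims(), &[2, 2, 3]);

    // IndexOp (.i) as the torch.select equivalent: indexing dim 1 with a
    // plain usize drops that dimension, just like torch.select does.
    let selected = x.i((.., 2, ..))?;
    assert_eq!(selected.dims(), &[2, 3]);

    Ok(())
}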

Are there any other problems? Happy to help!

EricLBuehler avatar Jul 27 '24 01:07 EricLBuehler

Thanks for the help! That's it at the moment.

I'm just trying to implement the first part of the algorithm, which is the sequence parallelism. As I work through it, I'll definitely have more questions; it's just going to be slow as I'm really digging deep into this for the first time.

joshpopelka20gmail avatar Jul 27 '24 14:07 joshpopelka20gmail

I've been trying to implement this algorithm from the paper https://arxiv.org/pdf/2310.01889, and it really isn't working.

[image]

The KV cache isn't being split, so that's a big problem, but I'm more concerned about the garbage output I'm getting. If I open a PR, can you give it a look and let me know what looks incorrect?

joshpopelka20 avatar Jul 31 '24 19:07 joshpopelka20

Just adding a little more info.

I think the biggest problem I'm facing is that the KV cache needs to cycle between GPUs. I'm trying to do this with a for loop to get something working, but I don't think it is correct.

The device that the KV cache is using is set based on the Key and Value tensor devices. It's almost like I need a method to move the Tensors to new Devices. Here is my implementation so far:

let num_caches = self.kv_caches.len();
let mut accumulated_attention: Option<Tensor> = None;

for cache_rotation in 0..num_caches {
    let cache_idx = (chunk_idx + cache_rotation) % num_caches;
    let kv_cache = &self.kv_caches[cache_idx];
    let mut cache = kv_cache.lock();

    let device_chunk = chunk.device();

    // Determine the device of the cache
    let cache_device = cache.iter().find_map(|opt| {
        opt.as_ref().map(|(k, _)| k.device().clone())
    }).unwrap_or_else(|| device_chunk.clone());

    let mask = CausalMasker.make_causal_mask_as_attn_bias(
        input_ids,
        metadata
            .as_ref()
            .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
            .unwrap_or(&*cache as &dyn PastKvLenCache),
        chunk.dtype(),
        self.blocks[0].attn.num_attention_heads,
    )?;

    let mut x = self.mapper.map(chunk.to_device(&cache_device)?, block_idx)?;

    x = block.forward(
        &x,
        &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
        seqlen_offsets,
        start_offsets_kernel.clone(),
        block_idx,
        &mut cache,
        metadata
            .as_mut()
            .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)),
    )?;

    // Accumulate attention results
    if let Some(ref mut acc) = accumulated_attention {
        *acc = acc.add(&x.to_device(acc.device())?)?;
    } else {
        accumulated_attention = Some(x);
    }
}

Let me know if you'd like to see the full code so far in a PR. I've made most of the changes in mistralrs-core/src/models/llama.rs.

joshpopelka20 avatar Aug 01 '24 15:08 joshpopelka20

@joshpopelka20 yes if you do a PR, I can absolutely take a look.

It's almost like I need a method to move the Tensors to new Devices.

Could you use Tensor::to_device? For the KV cache to cycle between 2 devices, I would check out the Cache struct. This struct has no idea of device mapping, though, so it can't innately handle a KV cache split across multiple GPUs.
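
For example, moving one layer's cached pair is just two to_device calls. A rough sketch (the Option<(Tensor, Tensor)> entry type matches what your snippet iterates over; the helper name is just illustrative):

use candle_core::{Device, Result, Tensor};

// Illustrative helper: materialize one layer's cached (K, V) pair on `dst`.
fn kv_to_device(
    kv: Option<(Tensor, Tensor)>,
    dst: &Device,
) -> Result<Option<(Tensor, Tensor)>> {
    kv.map(|(k, v)| Ok((k.to_device(dst)?, v.to_device(dst)?)))
        .transpose()
}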

EricLBuehler avatar Aug 05 '24 13:08 EricLBuehler

@EricLBuehler thanks for the reply.

For:

Could you use Tensor::to_device?

I'm already using this for the chunks. From this code: let mut x = self.mapper.map(chunk.to_device(&cache_device)?, block_idx)?;, you can see I'm moving the chunk to the kv_cache device (the chunk is just the seq_len split into num_devices pieces; in my case, four).

for j in 0..num_devices {
    let start = j * chunk_size;
    let end = if j == num_devices - 1 {
        seq_len
    } else {
        (j + 1) * chunk_size
    };

    let chunk = x.i((.., start..end, ..))?; // use IndexOp
    let device = &self.cuda_devices[j];
    chunks.push(chunk.to_device(device)?);
}

I would check out the Cache struct.

I asked about a method to move the KV cache because that's what the algorithm does. I'm trying to do it by moving the chunks device, but I'd think if I could have a way to move the KV cache device, I'd get better results. The only method that looks close is:

fn clone_in_cache(
    num_hidden_layers: usize,
    cache: &mut LayerCaches,
    seqs: &mut [&mut crate::sequence::Sequence],
    src: SeqCache,
)

I think I need a method that takes in two kv caches and the new device, then clones the first kv cache data and moves it to the second kv cache on the new device. So basically, is there a method to clone the key and value tensors? I can use to_device for changing devices.

joshpopelka20 avatar Aug 05 '24 15:08 joshpopelka20

I think I have enough info now. Currently, I'm working on some instruction fine-tuning, which is taking all my time. I should be able to work on this again starting next week.

joshpopelka20 avatar Aug 08 '24 14:08 joshpopelka20

@joshpopelka20 sounds good - please feel free to open a PR so I can take a look!

I think I need a method that takes in two kv caches and the new device, then clones the first kv cache data and moves it to the second kv cache on the new device. So basically, is there a method to clone the key and value tensors? I can use to_device for changing devices.

Modifications like this would be done in the Cache struct. You could add a method to it which tells it to map the devices for its KV cache; would that work?
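
Roughly something like this, sketched against a stand-in for the real Cache (which keeps the per-layer Vec<Option<(Tensor, Tensor)>> behind a mutex and exposes lock()); the method name and exact signature are hypothetical:

use std::sync::{Arc, Mutex, MutexGuard};

use candle_core::{Device, Result, Tensor};

type LayerCaches = Vec<Option<(Tensor, Tensor)>>;

// Stand-in for the real Cache struct, just enough to show the idea.
struct Cache(Arc<Mutex<LayerCaches>>);

impl Cache {
    fn lock(&self) -> MutexGuard<'_, LayerCaches> {
        self.0.lock().expect("cache poisoned")
    }

    // Move every layer's cached K/V tensors onto `dst`.
    fn move_to_device(&self, dst: &Device) -> Result<()> {
        let mut layers = self.lock();
        for entry in layers.iter_mut() {
            if let Some((k, v)) = entry.take() {
                *entry = Some((k.to_device(dst)?, v.to_device(dst)?));
            }
        }
        Ok(())
    }
}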

EricLBuehler avatar Aug 08 '24 14:08 EricLBuehler

Modifications like this would be done in the Cache struct. You could add a method to it which tells it to map the devices for its KV cache

That's the plan. After I add that, I'll create the PR.

joshpopelka20 avatar Aug 08 '24 14:08 joshpopelka20

I've found out why I'm getting garbage output. My problem is with the Tensor operations. I have to clone 'x' to insert into the chunks Vec and clone 'chunk' in the map function. When I do this, the output isn't coherent. For some reason, that operation isn't copying the Tensor properly. See this code sample:

let mut chunks: Vec<Tensor> = Vec::with_capacity(num_devices);
chunks.push(x.clone());

let mut cache = self.kv_caches[0].lock();
let mask = CausalMasker.make_causal_mask_as_attn_bias(
    input_ids,
    metadata
        .as_ref()
        .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
        .unwrap_or(&*cache as &dyn PastKvLenCache),
    chunks[0].dtype(),
    self.blocks[0].attn.num_attention_heads,
)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
    x = self.mapper.map(chunks[0].clone(), block_idx)?;
    // ...rest of the per-block forward pass
}

When I remove the clone calls, I get these errors: [error screenshots]

This is the garbled output: [screenshot]

Any help would be appreciated.

joshpopelka20 avatar Aug 14 '24 18:08 joshpopelka20

The issue seems to be how clone handles the device. Here, you can see that the 'x' tensor is split across devices, but the 'chunk' stays on the same device: [screenshot]

I don't think I can use that method as is. It needs something to clone to the correct device. Here is the code snippet for reference:

for (block_idx, block) in self.blocks.iter().enumerate() {
    // x = self.mapper.map(x, block_idx)?;
    // x = self.mapper.map(&chunks[0], block_idx)?;
    println!("x device {:?}", x.device());
    println!("chunk device {:?}", chunks[0].device());
    x = self.mapper.map(chunks[0].clone(), block_idx)?;
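
(Looking at it more, Tensor::clone in candle seems to just bump a reference count and stay on the original device, while to_device is what actually materializes a copy somewhere else. A quick standalone check, assuming a CUDA device is available:)

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let cpu = Device::Cpu;
    let gpu = Device::new_cuda(0)?; // assumes a CUDA build and device 0

    let x = Tensor::zeros((2, 8, 16), DType::F32, &cpu)?;

    // clone() shares the same storage on the same device...
    let y = x.clone();
    assert!(y.device().is_cpu());

    // ...while to_device() produces a copy living on the target device.
    let z = x.to_device(&gpu)?;
    assert!(z.device().is_cuda());

    Ok(())
}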

joshpopelka20 avatar Aug 14 '24 19:08 joshpopelka20

Can you please open a PR so I can take a look at this? It's OK if it's unfinished for now; I just want to see what's going on. I'll be able to help better that way!

EricLBuehler avatar Aug 14 '24 20:08 EricLBuehler

I've added PR #684.

I restarted everything to try to see where the errors are, so it's very minimal. The biggest issue is the garbled output at the moment. After I figure that out, I can start adding the IndexOp and multiple KV caches.

joshpopelka20 avatar Aug 14 '24 20:08 joshpopelka20

Ok great! I'll take a look.

EricLBuehler avatar Aug 14 '24 20:08 EricLBuehler

I've added a commit with the sequence parallelism code, and documented a few things in the PR comments. Please review when you have time.

joshpopelka20 avatar Aug 22 '24 15:08 joshpopelka20

I'm stuck trying to implement this and need some help. My current issue is that the blocks and chunks are on different devices. Here is the error: [error screenshot]

I'm not sure how to proceed: per the algorithm, the input sequence should be split across devices, but candle won't allow operations across devices. Any help would be appreciated.
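
(For reference, here's a tiny standalone illustration of the constraint I'm hitting: candle rejects binary ops whose operands live on different devices, so one side has to be moved first. Assuming a CUDA device:)

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let cpu = Device::Cpu;
    let gpu = Device::new_cuda(0)?;

    let a = Tensor::ones((2, 2), DType::F32, &cpu)?;
    let b = Tensor::ones((2, 2), DType::F32, &gpu)?;

    // Mixing devices in a single op is rejected by candle...
    assert!(a.matmul(&b).is_err());

    // ...so one operand has to be moved over explicitly first.
    let _ok = a.matmul(&b.to_device(&cpu)?)?;

    Ok(())
}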

joshpopelka20 avatar Sep 03 '24 17:09 joshpopelka20

Researching further, I think Ring Attention is incompatible with the pipeline parallelism implementation, since the layers are split across devices; I most likely need to have the model shared across devices so I can run the operations that are currently failing. Maybe I'll need to wait for tensor parallelism #617.

joshpopelka20 avatar Sep 04 '24 20:09 joshpopelka20

@joshpopelka20 that's probably true. The timeline for that is a bit undefined; it will take some time to get a clean implementation here, but I definitely want to get #617 and #684 implemented.

I've been focusing on things around quantization recently; tensor parallelism is going to build off some of that infra.

EricLBuehler avatar Sep 12 '24 23:09 EricLBuehler