mistral.rs
[Feature] Implementation of multi-gpu KV cache (RingAttention)
I'll work through adding it to quantized llama first, since that's the architecture I know best. Link to the paper: https://arxiv.org/abs/2310.01889
I've been researching the algorithm further, and I think I'm going to have a problem implementing this in Rust. To start, based on my understanding, I'd need to split the input into different Tensors, with each split tensor on a different GPU. PyTorch has select (https://pytorch.org/docs/stable/generated/torch.select.html), but it doesn't look like candle has an equivalent. Is that correct?
To split the input into different tensors, I would use the narrow method to split along different ranges. If you want something a bit more flexible, IndexOp is probably what you want.
To implement torch.select, I'd use IndexOp (.i) to extract at the given dimension.
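For concreteness, here's a minimal sketch of both approaches (assuming candle-core and a toy (batch, seq_len, hidden) input; the shapes here are just placeholders):

use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A toy (batch, seq_len, hidden) input.
    let x = Tensor::zeros((2, 8, 4), DType::F32, &device)?;

    // narrow(dim, start, len): a contiguous slice along the sequence dimension (dim 1).
    let first_half = x.narrow(1, 0, 4)?; // shape (2, 4, 4)

    // IndexOp (.i): the same slice written as an index expression.
    let second_half = x.i((.., 4..8, ..))?; // shape (2, 4, 4)

    // Equivalent of torch.select(x, 1, 3): an integer index drops that dimension.
    let token_three = x.i((.., 3, ..))?; // shape (2, 4)

    println!("{:?} {:?} {:?}", first_half.shape(), second_half.shape(), token_three.shape());
    Ok(())
}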
Are there any other problems? Happy to help!
Thanks for the help! That's it at the moment.
I'm just trying to implement the first part of the algorithm, which is the sequence parallelism. As I work through it, I'll definitely have more questions; it's just going to be slow as I'm really digging deep into this for the first time.
I've been trying to implement this algorithm from the paper https://arxiv.org/pdf/2310.01889, and it really isn't working.
The KV cache isn't being split, so that's a big problem, but I'm more concerned with the garbage output that I'm getting. If I do a PR, can you give it a look and let me know what looks incorrect?
Just adding a little more info.
I think the biggest problem I'm facing is that the KV cache needs to cycle between GPUs. I'm trying to do this with a for loop to get something working, but I don't think it is correct.
The device that the KV cache is using is set based on the Key and Value tensor devices. It's almost like I need a method to move the Tensors to new Devices. Here is my implementation so far:
let num_caches = self.kv_caches.len();
let mut accumulated_attention: Option<Tensor> = None;
for cache_rotation in 0..num_caches {
    let cache_idx = (chunk_idx + cache_rotation) % num_caches;
    let kv_cache = &self.kv_caches[cache_idx];
    let mut cache = kv_cache.lock();
    let device_chunk = chunk.device();
    // Determine the device of the cache, falling back to the chunk's device if the cache is empty
    let cache_device = cache
        .iter()
        .find_map(|opt| opt.as_ref().map(|(k, _)| k.device().clone()))
        .unwrap_or_else(|| device_chunk.clone());
    let mask = CausalMasker.make_causal_mask_as_attn_bias(
        input_ids,
        metadata
            .as_ref()
            .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
            .unwrap_or(&*cache as &dyn PastKvLenCache),
        chunk.dtype(),
        self.blocks[0].attn.num_attention_heads,
    )?;
    // Move the chunk onto the cache's device before running the block
    let mut x = self.mapper.map(chunk.to_device(&cache_device)?, block_idx)?;
    x = block.forward(
        &x,
        &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
        seqlen_offsets,
        start_offsets_kernel.clone(),
        block_idx,
        &mut cache,
        metadata
            .as_mut()
            .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), &mut **metadata)),
    )?;
    // Accumulate attention results across cache rotations
    if let Some(ref mut acc) = accumulated_attention {
        *acc = acc.add(&x.to_device(acc.device())?)?;
    } else {
        accumulated_attention = Some(x);
    }
}
Let me know if you'd like to see the full code so far in a PR. I've made most of the changes in mistralrs-core/src/models/llama.rs.
@joshpopelka20 yes if you do a PR, I can absolutely take a look.
It's almost like I need a method to move the Tensors to new Devices.
Could you use Tensor::to_device? For the KV cache to cycle between 2 devices, I would check out the Cache struct. This struct has no idea of device mapping though, so it cannot innately handle a KV cache split across multiple GPUs.
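For example, a minimal sketch of moving a key/value pair with Tensor::to_device (move_kv is a hypothetical helper, not an existing method):

use candle_core::{Device, Result, Tensor};

// Hypothetical helper: clone a (key, value) pair onto another device.
// Tensor::to_device copies across devices and is effectively a no-op
// when the tensor is already on the target device.
fn move_kv(k: &Tensor, v: &Tensor, target: &Device) -> Result<(Tensor, Tensor)> {
    Ok((k.to_device(target)?, v.to_device(target)?))
}

// Usage (the device ordinal is just a placeholder):
// let target = Device::new_cuda(1)?;
// let (k, v) = move_kv(&k, &v, &target)?;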
@EricLBuehler thanks for the reply.
For:
Could you use Tensor::to_device?
I'm already using this for the chunks. In this line:
let mut x = self.mapper.map(chunk.to_device(&cache_device)?, block_idx)?;
you can see I'm moving the chunk to the kv_cache device (the chunk is just the seq_len split across num_devices; in my case, four):
for j in 0..num_devices {
    let start = j * chunk_size;
    let end = if j == num_devices - 1 {
        seq_len
    } else {
        (j + 1) * chunk_size
    };
    let chunk = x.i((.., start..end, ..))?; // use IndexOp
    let device = &self.cuda_devices[j];
    chunks.push(chunk.to_device(device)?);
}
I would check out the Cache struct.
I asked about a method to move the KV cache because that's what the algorithm does. I'm trying to do it by moving the chunks' device, but I think that if I had a way to move the KV cache's device, I'd get better results. The only method that looks close is:
fn clone_in_cache(
    num_hidden_layers: usize,
    cache: &mut LayerCaches,
    seqs: &mut [&mut crate::sequence::Sequence],
    src: SeqCache,
)
I think I need a method that takes in two kv caches and the new device, then clones the first kv cache data and moves it to the second kv cache on the new device. So basically, is there a method to clone the key and value tensors? I can use to_device for changing devices.
I think I have enough info now. Currently, I'm working through doing some instruction fine-tuning, which is taking all my time. I should be able to work on this again starting next week.
@joshpopelka20 sounds good - please feel free to open a PR so I can take a look!
I think I need a method that takes in two kv caches and the new device, then clones the first kv cache data and moves it to the second kv cache on the new device. So basically, is there a method to clone the key and value tensors? I can use to_device for changing devices.
Modifications like this would be done in the Cache struct. You could add a method to it which tells it to map the devices for its KV cache; would that work?
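A rough sketch of what such a method could look like, assuming LayerCaches is a Vec<Option<(Tensor, Tensor)>> as the snippets above suggest (map_to_device is a hypothetical name):

use candle_core::{Device, Result, Tensor};

// Assumed to match the layout used above: one optional (key, value) pair per layer.
type LayerCaches = Vec<Option<(Tensor, Tensor)>>;

// Hypothetical method: clone every (key, value) pair in the cache onto the
// given device, leaving layers that have no cache entry untouched.
fn map_to_device(cache: &mut LayerCaches, device: &Device) -> Result<()> {
    for layer in cache.iter_mut() {
        if let Some((k, v)) = layer.take() {
            *layer = Some((k.to_device(device)?, v.to_device(device)?));
        }
    }
    Ok(())
}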
Modifications like this would be done in the Cache struct. You could add a method to it which tells it to map the devices for its KV cache
That's the plan. After I add that, I'll create the PR.
I've found out why I'm getting garbage output. My problem is with the Tensor operations. I have to clone 'x' to insert into the chunks Vec and clone 'chunk' in the map function. When I do this, the output isn't coherent. For some reason, that operation isn't copying the Tensor properly. See this code sample:
let mut chunks: Vec<Tensor> = Vec::with_capacity(num_devices);
chunks.push(x.clone());
let mut cache = self.kv_caches[0].lock();
let mask = CausalMasker.make_causal_mask_as_attn_bias(
    input_ids,
    metadata
        .as_ref()
        .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
        .unwrap_or(&*cache as &dyn PastKvLenCache),
    chunks[0].dtype(),
    self.blocks[0].attn.num_attention_heads,
)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
    x = self.mapper.map(chunks[0].clone(), block_idx)?;
When I remove the clone function, I get these errors:
This is the image of the garbled output:
Any help would be appreciated.
The issue seems to be which device the cloned tensor ends up on. Here, you can see that the 'x' tensor is split across devices, but the 'chunk' is on the same device:
I don't think I can use that method as is. It needs something to clone to the correct device. Here is the code snippet for reference:
for (block_idx, block) in self.blocks.iter().enumerate() {
    // x = self.mapper.map(x, block_idx)?;
    // x = self.mapper.map(&chunks[0], block_idx)?;
    println!("x device {:?}", x.device());
    println!("chunk device {:?}", chunks[0].device());
    x = self.mapper.map(chunks[0].clone(), block_idx)?;
Can you please open a PR so I can take a look at this? It's OK if it's unfinished for now; I just want to see what's going on. I'll be able to help better that way!
I've added PR #684.
I restarted everything to try to see where the errors are, so it's very minimal. The biggest issue is the garbled output at the moment. After I figure that out, I can start adding the IndexOp and multiple KV caches.
Ok great! I'll take a look.
I've added a commit with the Sequence Parallelism code, and documented a few things in the comments of the PR. Please review when you have the time.
I'm stuck trying to implement this and need some help. My current issue is that the blocks and chunks are on different devices. Here is the error:
I'm not sure how to proceed; per the algorithm, the input sequence should be split across devices, but candle won't allow operations across devices. Any help would be appreciated.
Researching further, I'm thinking Ring Attention is incompatible with the pipeline parallelism implementation, since the layers are split across devices; I'd most likely need to have the model shared across devices so I can run the operations that are failing. Maybe I'll need to wait for tensor parallelism (#617).
@joshpopelka20 that's probably true. The timeline for that is a bit undefined; it needs some time to get a clean implementation here, but I want to get #617 and #684 implemented for sure.
I've been focusing on things around quantization recently; tensor parallelism is going to build off some of that infra.