Support for packed layout with paged attention
Hi all,
I've been playing with cudnn-frontend to test the Flash Attention kernel. Overall, it's easy to use and fast, but I've come across a limitation that I don't really understand.
It seems that the kernel can't be used in paged mode with packed tensors. This is something that other paged attention kernels support (and it makes a big difference in terms of performance, since tokens can be batched per sequence).
So two questions about that:
- Is it a limitation only in cudnn-frontend? I couldn't find such a limitation in the backend documentation.
- Are there plans to add that feature in the future?
Hi @Corendos ,
Thanks for your question. We are actually enabling it in the upcoming cuDNN frontend release. It's roughly a week out, so if you'd like to enable it earlier, you can just comment out L336 in scaled_dot_product_flash_attention.h and manually verify that you have cuDNN v9.7 as the backend cuDNN version.
Please note that in this case only Q will have packed support. In theory packing could be combined with paged K and V caches: it would be the page tables themselves that would be packed then. However the amount of compression you get from this is minimal. But please let us know if you disagree or if you see other good use cases for that!
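To make the paged K/V layout concrete, here is a rough NumPy sketch (made-up shapes and names, not the cuDNN API): K tokens live in fixed-size pages of a physical container, and a per-sequence page table maps logical blocks to physical pages.

```python
# Illustrative sketch only (plain NumPy, not the cuDNN API). All shapes and
# names are hypothetical: a paged K cache split into fixed-size pages, plus a
# per-sequence page table mapping logical blocks to physical pages.
import numpy as np

B, H_kv, D = 2, 8, 128             # batch size, KV heads, head dim
block_size, num_blocks = 16, 64    # tokens per page, pages in the container
blocks_per_seq = 4                 # logical blocks reserved per sequence

# Physical K container: [num_blocks, H_kv, block_size, D]
container_k = np.random.randn(num_blocks, H_kv, block_size, D).astype(np.float16)

# Page table: for each sequence, which physical page holds each logical block
page_table_k = np.random.randint(0, num_blocks, size=(B, blocks_per_seq), dtype=np.int32)

def gather_k(seq_idx: int, seq_len: int) -> np.ndarray:
    """Reassemble the first seq_len K tokens of one sequence from the paged cache."""
    n_pages = (seq_len + block_size - 1) // block_size
    pages = page_table_k[seq_idx, :n_pages]
    k = container_k[pages]                        # [n_pages, H_kv, block_size, D]
    k = k.transpose(1, 0, 2, 3).reshape(H_kv, -1, D)
    return k[:, :seq_len]                         # [H_kv, seq_len, D]

print(gather_k(seq_idx=0, seq_len=37).shape)      # (8, 37, 128)
```

(The K/V side keeps being addressed through the page tables; the packing discussed above only concerns Q and its ragged offset.)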
Hey @nvmbreughe, thanks for the quick answer!
Funny that you suggest that, because that's exactly what I tried! I was just not sure about the correctness of the output.
To give more context, we are currently working on an ML framework that uses XLA under the hood (through the PJRT abstraction). XLA has some support for cuDNN flash attention, just not the paged version, so we hacked support in ourselves. It was working great for decode but quite slow for prefill due to the naive approach, so I was wondering whether cuDNN supported packed + paged like other kernels do. So overall that's great news!
There is still a small issue on the XLA side: I'm hitting a CUDA error about misalignment. I tweaked the cudnn-frontend samples to reproduce it, but the error doesn't trigger there. The cuDNN log output is the same in both cases, so I guess it's an issue in XLA; I just wanted to see if that rings a bell?
As for:
> In theory packing could be combined with paged K and V caches: it would be the page tables themselves that would be packed then. However the amount of compression you get from this is minimal. But please let us know if you disagree or if you see other good use cases for that!
I was mainly looking for Q packed support so I don't think that's really required to compress the page tables.
Cheers!
Happy to help, @Corendos.
> Funny that you suggest that, because that's exactly what I tried!
Ha, great! That will work as long as cuDNN backend is at least v9.7.
> There is still a small issue on the XLA side: I'm hitting a CUDA error about misalignment. I tweaked the cudnn-frontend samples to reproduce it, but the error doesn't trigger there. The cuDNN log output is the same in both cases, so I guess it's an issue in XLA; I just wanted to see if that rings a bell?
Not sure. compute-sanitizer "may" tell you more, so maybe try running it through there? Do you have a way to get the starting addresses of each tensor XLA allocated? Does it happen with paged caches only?
> Not sure. compute-sanitizer "may" tell you more, so maybe try running it through there? Do you have a way to get the starting addresses of each tensor XLA allocated? Does it happen with paged caches only?
Ok, so this was a mistake on my side. Due to the way XLA works, the UID used when building the graph and the one used when executing it can be different. I introduced a mismatch, so the wrong tensors were passed to the kernel.
> Ha, great! That will work as long as cuDNN backend is at least v9.7.
About that, I think I discovered a bug. From the documentation, I understand that in the case of packed layout, Q's first dimension (let's call it T) can be different from the batch size B.
However, when I try to build a graph with such a difference, I get an error. Here are the cuDNN logs:
[cudnn_frontend]
{"context":{"compute_data_type":"FLOAT","intermediate_data_type":"FLOAT","io_data_type":"HALF","name":"","sm_count":-1},"cudnn_backend_version":"9.7.1","cudnn_frontend_version":11000,"json_version":"1.0","nodes":[{"alibi_mask":false,"attn_scale_value":"3DB504F3","diagonal_alignment":"TOP_LEFT","dropout_probability":null,"inputs":{"K":"container_K","Page_table_K":"page_table_k","Page_table_V":"page_table_v","Q":"Q","SEQ_LEN_KV":"seq_kv","SEQ_LEN_Q":"seq_q","V":"container_V"},"is_inference":true,"left_bound":null,"max_seq_len_kv":4096,"name":"flash_attention","outputs":{"O":"flash_attention::O"},"padding_mask":true,"right_bound":null,"tag":"SDPA_FWD"}],"tensors":{"Q":{"data_type":null,"dim":[17,32,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"Q","pass_by_value":null,"reordering_type":"NONE","stride":[4096,128,128,1],"uid":1,"uid_assigned":true},"container_K":{"data_type":null,"dim":[32768,8,16,128],"is_pass_by_value":false,"is_virtual":false,"name":"container_K","pass_by_value":null,"reordering_type":"NONE","stride":[16384,128,1024,1],"uid":2,"uid_assigned":true},"container_V":{"data_type":null,"dim":[32768,8,16,128],"is_pass_by_value":false,"is_virtual":false,"name":"container_V","pass_by_value":null,"reordering_type":"NONE","stride":[16384,128,1024,1],"uid":3,"uid_assigned":true},"flash_attention::O":{"data_type":null,"dim":[17,32,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"flash_attention::O","pass_by_value":null,"reordering_type":"NONE","stride":[4096,128,128,1],"uid":4,"uid_assigned":true},"page_table_k":{"data_type":"INT32","dim":[16,1,256,1],"is_pass_by_value":false,"is_virtual":false,"name":"page_table_k","pass_by_value":null,"reordering_type":"NONE","stride":[256,256,1,1],"uid":9,"uid_assigned":true},"page_table_v":{"data_type":"INT32","dim":[16,1,256,1],"is_pass_by_value":false,"is_virtual":false,"name":"page_table_v","pass_by_value":null,"reordering_type":"NONE","stride":[256,256,1,1],"uid":10,"uid_assigned":true},"seq_kv":{"data_type":"INT32","dim":[17,1,1,1],"is_pass_by_value":false,"is_virtual":false,"name":"seq_kv","pass_by_value":null,"reordering_type":"NONE","stride":[1,1,1,1],"uid":8,"uid_assigned":true},"seq_q":{"data_type":"INT32","dim":[17,1,1,1],"is_pass_by_value":false,"is_virtual":false,"name":"seq_q","pass_by_value":null,"reordering_type":"NONE","stride":[1,1,1,1],"uid":7,"uid_assigned":true}}}
[cudnn_frontend] INFO: Validating SDPANode flash_attention...
[cudnn_frontend] INFO: Validating SDPANode flash_attention...
[cudnn_frontend] INFO: Inferrencing properties for Scaled_dot_product_flash_attention node flash_attention...
[cudnn_frontend] INFO: Validating PagedCacheLoadNode paged_k_cache_operation...
[cudnn_frontend] INFO: Inferrencing properties for matmul node bmm1...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node attn_scale...
[cudnn_frontend] INFO:attn_scale::OUT_0 stride computed from bmm1::C
[cudnn_frontend] INFO: Inferrencing properties for pointwise node gen_row_idx_padding...
[cudnn_frontend] INFO:gen_row_idx_padding::OUT_0 stride computed from attn_scale::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node gen_col_idx_padding...
[cudnn_frontend] INFO:gen_col_idx_padding::OUT_0 stride computed from attn_scale::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node lt_row_sq_padding...
[cudnn_frontend] INFO:lt_row_sq_padding::OUT_0 stride computed from gen_row_idx_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node lt_col_skv_padding...
[cudnn_frontend] INFO:lt_col_skv_padding::OUT_0 stride computed from gen_col_idx_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node and_row_col_padding...
[cudnn_frontend] INFO:and_row_col_padding::OUT_0 stride computed from lt_col_skv_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node select_padding...
[cudnn_frontend] INFO:select_padding::OUT_0 stride computed from and_row_col_padding::OUT_0
[cudnn_frontend] INFO: Validating SoftmaxNode softmax...
[cudnn_frontend] INFO: Inferrencing properties for Softmax node softmax.
[cudnn_frontend] INFO: Inferrencing properties for reduction node M...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node sub...
[cudnn_frontend] INFO:sub_M stride computed from select_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node exp...
[cudnn_frontend] INFO:exp_sub_M stride computed from sub_M
[cudnn_frontend] INFO: Inferrencing properties for reduction node sum...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node log...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node add...
[cudnn_frontend] INFO: stride computed from log::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node div...
[cudnn_frontend] INFO: stride computed from exp_sub_M
[cudnn_frontend] INFO: Validating PagedCacheLoadNode paged_v_cache_operation...
[cudnn_frontend] INFO: Inferrencing properties for matmul node bmm2...
[cudnn_frontend] INFO: Creating cudnn tensors for node named 'flash_attention':
[cudnn_frontend] INFO: Creating Backend Tensor named 'attn_scale::IN_1' with UID 5
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["FLOAT"] Id: 5 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 1,1,1,1 ] Str [ 1,1,1,1 ] isVirtual: 0 isByValue: 1 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] INFO: Creating Backend Tensor named 'container_V' with UID 3
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["HALF"] Id: 3 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 32768,8,16,128 ] Str [ 16384,128,1024,1 ] isVirtual: 0 isByValue: 0 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] INFO: Creating Backend Tensor named 'container_K' with UID 2
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["HALF"] Id: 2 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 32768,8,16,128 ] Str [ 16384,128,1024,1 ] isVirtual: 0 isByValue: 0 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] INFO: Creating Backend Tensor named 'Q' with UID 1
[cudnn_frontend] INFO: Creating Backend Tensor named 'ragged_offset_q' with UID 12
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["INT32"] Id: 12 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 17,1,1,1 ] Str [ 1,1,1,1 ] isVirtual: 0 isByValue: 0 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] ERROR: CUDNN_BACKEND_TENSOR_DESCRIPTOR cudnnFinalize failedptrDesc->finalize() cudnn_status: CUDNN_STATUS_BAD_PARAM. ["CUDNN_BACKEND_API_FAILED"] because (e.getCudnnStatus() != CUDNN_STATUS_SUCCESS) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/cudnn_interface.h:86
[cudnn_frontend] ERROR: detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/node_interface.h:395
[cudnn_frontend] ERROR: create_cudnn_tensors_node(uid_to_backend_tensors, potential_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/node_interface.h:242
[cudnn_frontend] ERROR: sub_node->create_cudnn_tensors_subtree(uid_to_backend_tensors, potential_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/node_interface.h:244
[cudnn_frontend] ERROR: create_cudnn_tensors_subtree(uid_to_tensors, start_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/graph_interface.h:566
[cudnn_frontend] ERROR: this->build_operation_graph(handle) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/graph_interface.h:1502
If you want to reproduce, here is a gist containing the modified sample I used: https://gist.github.com/Corendos/ab4712e1c53b72ff114b108635bc5c1f
I saw that there was a recent release of the cuDNN backend (9.8.0); do you know by any chance if this was fixed in that version?
After a bit more investigation, it seems that the error originates when the Q Tensor is built here:
https://github.com/NVIDIA/cudnn-frontend/blob/91b7532f3386768bba4f444ee7672b497f34da8a/include/cudnn_frontend_Tensor.h#L502
I also tried with the 9.8.0 release of the cuDNN backend, but the error still triggers.
I'll try to see if the error also happens when the kernel is used in a non-paged way and keep you posted.
Just found out that you can enable logging in the cuDNN backend with CUDNN_LOGDEST_DBG=stdout CUDNN_LOGLEVEL_DBG=3. The interesting part of the output is:
I! CuDNN (v90800 87) function cudnnBackendFinalize() called:
i! descriptor: type=CUDNN_BACKEND_TENSOR_DESCRIPTOR:
i! type: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i! nbDims: type=int; val=4;
i! dimA: type=int; val=[256,32,1,128];
i! strideA: type=int; val=[4096,128,128,1];
i! uid: type=int64_t; val=1;
i! alignmentInBytes: type=int64_t; val=16;
i! isVirtual: type=bool; val=false;
i! isByVal: type=bool; val=false;
i! Time: 2025-03-07T09:16:49.176468 (0d+0h+0m+1s since start)
i! Process=354365; Thread=354365; GPU=NULL; Handle=NULL; StreamId=NULL.
I! CuDNN (v90800 87) function cudnnBackendFinalize() called:
i! status: type=cudnnStatus_t; val=CUDNN_STATUS_BAD_PARAM (2000);
i! Time: 2025-03-07T09:16:49.176476 (0d+0h+0m+1s since start)
i! Process=354365; Thread=354365; GPU=NULL; Handle=NULL; StreamId=NULL.
E! CuDNN (v90800 87) function cudnnBackendFinalize() called:
e! Info: Traceback contains 3 message(s)
e! Error: CUDNN_STATUS_BAD_PARAM; Reason: CUDNN_ATTR_TENSOR_RAGGED_OFFSET_DESC ragged dim should match dim value + 1 of original tensor. All other offset dim values should be singleton. at: offset_dimA[dim] != this->_dimA[dim] + 1 && offset_dimA[dim] != 1
e! Error: CUDNN_STATUS_BAD_PARAM; Reason: finalize_internal()
e! Error: CUDNN_STATUS_BAD_PARAM; Reason: ptrDesc->finalize()
e! Time: 2025-03-07T09:16:49.176482 (0d+0h+0m+1s since start)
e! Process=354365; Thread=354365; GPU=NULL; Handle=NULL; StreamId=NULL.
So this seems like a potential issue, as the purpose of packed tensors is to allow a Q tensor with more than one token per batch dimension.
Is there a way to report this directly to the cuDNN backend team? I'd love to help if needed; please let me know how I can contribute!
@Corendos thanks for reporting the bug. @nvmbreughe can you create an NVBug next week?
@Corendos would you be interested in connecting to discuss your use cases?
Also, please note that we tried NOP-ing out the assertion inside cudnn_graph.so, and unfortunately it fails later.
Hi @Corendos / @steeve ,
From the sample, it looks like there is a mismatch between our documentation and the expected ragged offset and Q tensor shapes.
Looking at the tensors:
- "Q" -> dim [17,32,1,128]
- "ragged_offset_q" -> dim [17,1,1,1]
- "page_table_k" -> dim [16,1,256,1] (INT32)
The expectation is:
- Q is B,H,S,D
- Ragged offset is B+1,1,1,1
- Page_table is B, ...
The reason the ragged offset is B+1 is that the first offset starts at 0.
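As an illustration (hypothetical sequence lengths, not the cuDNN API), the B+1 offsets are just an exclusive prefix sum of the per-sequence token counts, starting at 0:

```python
# Illustrative sketch only: deriving B+1 ragged offsets from per-sequence
# token counts (hypothetical values). Whether cuDNN expects the offsets in
# tokens or in elements should be checked against the documentation.
import numpy as np

seq_len_q = np.array([3, 1, 5, 2], dtype=np.int32)   # B = 4 sequences
ragged_offset_q = np.concatenate(([0], np.cumsum(seq_len_q))).astype(np.int32)
print(ragged_offset_q)  # [ 0  3  4  9 11] -> B + 1 = 5 entries
# Sequence i occupies rows ragged_offset_q[i] : ragged_offset_q[i+1]
# of the packed Q buffer (total tokens T = 11).
```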
Hope that makes sense
Regards, Anerudhan
Hi all !
> @Corendos would you be interested in connecting to discuss your use cases?
@mnicely I would love to; how do you want to proceed?
> The expectation is:
> Ragged offset is B+1,1,1,1. Page_table is B, ... The reason the ragged offset is B+1 is that the first offset starts at 0.
It's true that the kernel documentation says that, but there is also the Supported Tensor Layouts section, which introduces a new dimension name. It says that in the case of packed layout, Q has a so-called THD shape, with T = sum(seq_len), which allows T and the batch size to differ.
Also, in my understanding, forcing the ragged offset to be of size B + 1 and Q's first dimension to be B is no different from the non-packed layout: you get a 1-to-1 mapping between Q "slots" and offsets, which is equivalent to non-packed. The use case I see for this kernel (and also how other popular paged attention kernels work) is prefill, where you process multiple input tokens per sequence. In other words, you have T >> B, and that currently seems to be impossible.
Hi @Corendos ,
Apologies for the delay. Ragged tensors/THD are mainly used in cases where the number of tokens varies for each sequence in your batch. This is useful for prefill with Q, and for both prefill and decode when K/V is not paged.
While the actual dimensions are the same as before ([B,H,S,D]), you don't need to allocate space for the entire B*H*S*D block. This can save memory if many sequences have fewer than S tokens. Do note that since the layout is packed, sequences are no longer aligned on the batch dimension and thus need a ragged offset to indicate where on the B*S dimension each sequence starts.
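To illustrate that point with a rough NumPy sketch (made-up sizes, not the cuDNN API): the packed buffer only stores sum(seq_len) token rows instead of B*S, and the ragged offsets locate each sequence along that packed dimension.

```python
# Illustrative sketch: padded (BHSD-style) vs packed (THD-style) Q storage
# for variable-length sequences. All sizes are hypothetical.
import numpy as np

B, H, S_max, D = 4, 32, 1024, 128
seq_len_q = np.array([700, 12, 1024, 95], dtype=np.int32)

padded_rows = B * S_max             # token rows allocated in the padded layout
packed_rows = int(seq_len_q.sum())  # token rows actually needed when packed
print(padded_rows, packed_rows)     # 4096 vs 1831

# Packed buffer: one row per real token, sequences stored back to back
q_packed = np.zeros((packed_rows, H, D), dtype=np.float16)
ragged_offset_q = np.concatenate(([0], np.cumsum(seq_len_q)))

# Recover sequence i from the packed buffer via its offsets
i = 2
q_seq_i = q_packed[ragged_offset_q[i]:ragged_offset_q[i + 1]]
print(q_seq_i.shape)                # (1024, 32, 128)
```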
Does this help answer your question or am I missing something?
Hey @nvmbreughe no worries
I think I understand. What you are saying is that you need to pass bogus dimension values (B,H,S,D), and that the kernel will ignore them because it only reads offsets from the ragged offset tensor?
I kind of expected cuDNN to accept a tensor with T,H,D dimensions (since those are the actual dimensions of the tensor), because that's how other popular paged attention kernels work.