[QST] Questions about ex77 Blackwell Flash Attention
- Why do we need to include padding in front of the first batch in example 77? What should the padding size be?
- How does example 77 handle variable sequence length with TMA? I was skimming the code and expected to see some tensormap.replace calls, but I don't see any.
@v0i0, would you please take a look?
Rawn, nice to see you here again.
Hey Rawn,
There are a few ways to implement variable sequence length, and this seemed simplest when I wrote this - I'd likely handle it differently today. The first basic question for var seq len support is whether we want our implementation to be correct no matter what data is in the tensor (think inf or nan), or whether we have control over that (this is often the case). The second question is whether we have extra memory, either padding in front or in the back.
You are correct that if we want to make no assumptions about the data, and we can't have any padding, then group-gemm-like tensormap.replace is the only option.
If we can have padding but make no assumptions about the data, we can do what is implemented here: we use one "offset dimension" to shift to the beginning of the sequence (with box size 1), and one "masking dimension" to offset such that the TMA-based masking is correct. We then initialize the TMA coords so that their mutual effect is neutral. So instead of having a TMA descriptor of (seqlen, ...) with box (blocking, ...), we have a TMA descriptor of (max_seqlen, max_seqlen) with box (1, blocking), and then index it as (seqlen - max_seqlen, max_seqlen - seqlen). But notice how the first term is negative. So we offset the base ptr by max_seqlen * seqlen_stride, and index by (seqlen, max_seqlen - seqlen). Since TMA descriptors need the entire memory area they address to be addressable, we thus need padding of max_seqlen * seqlen_stride (which is often equal to batch_stride).
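For concreteness, here is a rough host-side sketch of that coordinate arithmetic. The names (`VarlenTmaIndex`, `make_varlen_index`) are illustrative and not from the actual example 77 code; it just assumes a per-batch layout where consecutive sequence positions are `seqlen_stride` bytes apart.

```cpp
// Minimal sketch (not example 77's code) of the offset/masking-dimension trick.
// The descriptor is (max_seqlen, max_seqlen) with box (1, blocking); both of
// these dimensions stride by seqlen_stride bytes.
#include <cstddef>
#include <cstdint>

struct VarlenTmaIndex {
  std::byte const* base_ptr;  // pointer the TMA descriptor is built on
  int offset_coord;           // coordinate in the box-size-1 "offset dimension"
  int masking_coord;          // coordinate in the box-size-`blocking` "masking dimension"
};

// Logically we want to index at (seqlen - max_seqlen, max_seqlen - seqlen),
// whose address contributions cancel, but the first coordinate is negative.
// Folding max_seqlen into the base pointer makes both coordinates non-negative,
// at the cost of needing max_seqlen * seqlen_stride bytes of addressable
// padding in front of the tensor.
inline VarlenTmaIndex make_varlen_index(std::byte const* batch_ptr,
                                        int seqlen, int max_seqlen,
                                        std::int64_t seqlen_stride) {
  return VarlenTmaIndex{
      batch_ptr - std::int64_t(max_seqlen) * seqlen_stride,  // needs front padding
      seqlen,               // offset dimension: lands on the start of this sequence
      max_seqlen - seqlen   // masking dimension: rows past `seqlen` fall outside
                            // the descriptor extent and are masked by TMA
  };
}
```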
Lastly, if we can assume no NaN or Inf in the data, you could also implement this with only masking during softmax, without all the offset logic or tensormap.replace. Zeroing out the P tensor where appropriate would be sufficient, I believe. This is probably the way I would recommend today.
Hi @v0i0, thanks for writing this kernel!
We are finding this very promising for my team's use case, and we have integrated it with PyTorch using Python bindings. The variable sequence length case is critical for us, but the front padding is leading to a noticeable perf hit: we can't figure out an efficient way to add the padding to the PyTorch tensor without a DtoD memcpy. From your comment, it seems like it should be possible to remove the padding, and I wanted to understand which way you recommend for our use case.
IIUC, you are suggesting two ways to remove the padding:
- tensormap.replace (on device modification of the descriptor)
- Assume no NaN in data and mask P (We can safely assume no NaN in our case)
A few questions:
- Do you have any recommendations on which way would be easier to implement? Also, how would we store the output with (2) without using tensormap.replace?
- We are a bit confused by the statement "Since TMA descriptors need the entire memory area they address to be addressable", since the CUDA docs state that out-of-bounds accesses are automatically zeroed out. Could we keep indexing as (seqlen - max_seqlen, max_seqlen - seqlen) and remove the padding?
- Do you have any plans to add varlen support for the backward pass?
Our team is quite interested in this kernel, and we might have bandwidth to make these changes, but might require some guidance. TIA!
cc: @y-sq
[Comment edited based on my updated understanding after some internal discussion]
Hey @devashishshankar!
I would expect masking P, i.e. (2), to be much easier; however, it will not fix your problem with the output tensor. There, you'd either have to use tensormap.replace, some sort of padding (e.g. to the tile size), or a SIMT epilogue.
The TMA assumes that all memory encoded in the TMA descriptor is mapped (this may not be spelled out explicitly, since you are not going to run into it easily, but it is true at least on Blackwell; feel free to try it out). Note that your suggestion means that (seqlen - max_seqlen) is always negative and will trigger masking to zero from the TMA. That's why we fold the max_seqlen offset into the pointer instead.
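As one illustration of the padding-to-tile-size option for the output tensor mentioned above: a hypothetical host-side allocation could round each sequence's output extent up to the epilogue's store-tile size, so that the final partial tile lands in padding rather than unmapped memory. The names and packed-varlen layout below are assumptions for the sketch, not example 77's API.

```cpp
// Hypothetical allocation sketch: in a packed variable-length output layout,
// give each sequence a row count rounded up to the epilogue M tile so TMA
// stores for the last (partial) tile stay inside the allocation.
#include <cstddef>
#include <cstdint>
#include <vector>

inline std::int64_t round_up(std::int64_t x, std::int64_t m) {
  return (x + m - 1) / m * m;
}

// Returns per-sequence row offsets into the padded output; total_rows is the
// number of rows to allocate overall.
inline std::vector<std::int64_t>
padded_out_offsets(std::vector<std::int64_t> const& seqlens,
                   std::int64_t tile_m, std::int64_t& total_rows) {
  std::vector<std::int64_t> offsets(seqlens.size() + 1, 0);
  for (std::size_t b = 0; b < seqlens.size(); ++b) {
    offsets[b + 1] = offsets[b] + round_up(seqlens[b], tile_m);
  }
  total_rows = offsets.back();
  return offsets;
}
```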
> If we can assume no NaN or Inf in the data, you could also implement this with only masking during softmax, without all the offset logic or tensormap.replace. Zeroing out the P tensor where appropriate would be sufficient, I believe. This is probably the way I would recommend today.
Can someone explain it in more detail? It's still abstract to me!
@liuqi123123
If there are no NaNs or Infs in Q and K, then Q@K (contraction along head_dim) always produces a valid P matrix within the seqlen boundary, and the out-of-bounds (OOB) values are also neither NaN nor Inf.
Masking during softmax means adding a large negative value to the OOB results pre-softmax; the exponentiation in the (safe) softmax then makes them zero.
For P@V, P contains zeros at the OOB seqlen positions, so as long as V has no NaN or Inf, P@V always produces a valid result.
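To make the three steps above concrete, here is a small scalar sketch of masking during softmax (plain C++ for readability, not the warp-level code a real kernel would use; the function name and constant are illustrative):

```cpp
// Illustrative only: bias OOB score positions with a large negative value
// before a safe softmax, so their probabilities become zero and contribute
// nothing to P@V. Valid as long as Q, K, V contain no NaN/Inf.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> masked_safe_softmax(std::vector<float> scores, int valid_len) {
  constexpr float kNegLarge = -1e30f;  // large negative bias for OOB positions
  for (std::size_t j = static_cast<std::size_t>(valid_len); j < scores.size(); ++j) {
    scores[j] = kNegLarge;
  }
  // Safe softmax: subtract the row max before exponentiating.
  float row_max = *std::max_element(scores.begin(), scores.end());
  float sum = 0.f;
  for (float& s : scores) {
    s = std::exp(s - row_max);  // OOB entries underflow to 0
    sum += s;
  }
  for (float& s : scores) {
    s /= sum;  // OOB probabilities stay exactly 0, so they drop out of P@V
  }
  return scores;
}
```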
In practice, assumptions about OOB values are very bad. Asking users to pad their tensor is also an extremely bad design choice!