Does the FA3 varlen func support padding between sequences?
Hi, I noticed that the varlen function takes both cu_seqlens_q/k and seqused_q/k. Does this mean it can handle padding between sequences?
For example, given [q1, q1, PAD, q2, q2, q2, PAD, q3, PAD], if I pass cu_seqlens_q = [0, 3, 7, 9] and seqused_q = [2, 3, 1], would that be correct?
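For concreteness, here is how those two tensors would be built for that layout (only the cu_seqlens_q/seqused_q construction is shown; the attention call itself and any other arguments are omitted):

import torch

# Packed layout: [q1, q1, PAD, q2, q2, q2, PAD, q3, PAD]  -> 9 rows in total
# cu_seqlens_q marks the start/end offset of each sequence's slot,
# *including* its trailing padding rows.
cu_seqlens_q = torch.tensor([0, 3, 7, 9], dtype=torch.int32, device="cuda")
# seqused_q gives the number of valid (non-padding) tokens per sequence,
# so rows 2, 6 and 8 of the packed q tensor are padding.
seqused_q = torch.tensor([2, 3, 1], dtype=torch.int32, device="cuda")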
Yes, that's right. Note, however, that the kernel won't touch the output memory of the padding tokens, so their output will be uninitialized (it could contain arbitrary values, including Inf or NaN). If you feed this output into the next operation (e.g. a matmul), you can get Inf or NaN gradients.
Thanks for the explanation! There's one thing I'm still not quite sure about. Do you mean that after calling the kernel, I need to manually zero out the padding parts of the output? Something like:
o = fwd(q, k, v)
set_zeros(o)   # zero the padding rows of the forward output
...
dq, dk, dv = bwd(do)
set_zeros(dq)  # zero the padding rows of the gradients
set_zeros(dk)
set_zeros(dv)
Also, why is it designed this way?
If you need to, you can zero out the uninitialized parts of the output and gradients (i.e. the padding tokens) yourself. This API isn't really designed for padding tokens (though you can use it that way). It's meant for cases like context parallelism, where each iteration does attention on only part of the input. We aim for efficiency: we do not touch memory that we don't need to touch. I personally think padding tokens should not exist.
I think only minor modifications to the kernel are needed. A new parameter actual_seq_qk_padding (cu_seqlens_q/k[bid+1] - cu_seqlens_q/k[bid]) can be added to the epilogue predicates; meanwhile, apply actual_seq_q (seqused_q[bid]) in the mask function and add a mask along the row (seq_q) direction in the forward: if (row >= actual_seq_q) { acc_s[i] = -inf; }
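To make the intent concrete, here is a rough pure-PyTorch sketch of the semantics being proposed, not the kernel change itself; the function name and shapes are only for illustration:

import torch

def varlen_attn_reference(q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k):
    """Reference semantics only. q/k/v are packed (total_tokens, nheads, headdim)
    tensors; padding rows of the output come out as zeros, and padding columns
    of K/V never enter the scores."""
    out = torch.zeros_like(q)  # padding rows stay zero instead of uninitialized
    scale = q.shape[-1] ** -0.5
    for b in range(len(seqused_q)):
        qs, ks = int(cu_seqlens_q[b]), int(cu_seqlens_k[b])
        sq, sk = int(seqused_q[b]), int(seqused_k[b])
        qb, kb, vb = q[qs:qs + sq], k[ks:ks + sk], v[ks:ks + sk]
        # In-kernel, the proposal masks rows >= seqused_q[b] (acc_s = -inf) and
        # extends the epilogue predicate to the padded boundary; this reference
        # simply skips those rows and relies on the zero-initialized `out`.
        scores = torch.einsum("qhd,khd->hqk", qb.float(), kb.float()) * scale
        out[qs:qs + sq] = torch.einsum(
            "hqk,khd->qhd", scores.softmax(dim=-1), vb.float()
        ).to(q.dtype)
    return out

With this behavior, no separate zeroing pass over the output (or over dq/dk/dv in the backward) would be needed.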
If your padding tokens are only on Q and not on K & V, you can just pretend they are legitimate tokens and skip seqused_q, right? Then the output for the padding tokens will be written.
This works for the forward pass, but there will be issues in the backward. By the way, if we assigned zero values to the padding region of q/k/v, we would only need cu_seqlens to get the expected result, and seqused_q/k would lose its design value.
Therefore, I think we should modify the mask function and the predicates in the epilogue (almost no additional performance cost), rather than initializing values for the padding part; that would also give better performance.
In our case, Q, K, and V all require padding, and the padded parts are not zeros. So without the modification JerryChen mentioned, we would have to set the output at the padded positions to zero ourselves, which is not what we want.
As for the comment “I personally think padding tokens should not exist,” our case comes from a scenario in Transformer Engine, where we currently need to use FusedAttention (see: https://github.com/NVIDIA/TransformerEngine/blob/v2.9/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L663-L670).
If you have a kernel that zeros out the padding tokens (something like out[padding_indices, :, :] = 0.0), then you could apply it to the output, to the incoming gradient (dout), and to dQ. The gradients w.r.t. Q, K, and V would then be correct.
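In case it helps, here is a sketch of that zeroing step in plain PyTorch (a dedicated kernel would do the same thing); padding_row_indices is a hypothetical helper built from cu_seqlens_q and seqused_q:

import torch

def padding_row_indices(cu_seqlens: torch.Tensor, seqused: torch.Tensor) -> torch.Tensor:
    """Indices of the padding rows in the packed (total_tokens, ...) layout,
    i.e. rows in [cu_seqlens[b] + seqused[b], cu_seqlens[b + 1]) for each b."""
    idx = [
        torch.arange(int(cu_seqlens[b]) + int(seqused[b]), int(cu_seqlens[b + 1]),
                     device=cu_seqlens.device)
        for b in range(len(seqused))
    ]
    return torch.cat(idx)

# pad_idx = padding_row_indices(cu_seqlens_q, seqused_q)
# out[pad_idx] = 0.0    # after the forward
# dout[pad_idx] = 0.0   # before calling the backward
# dq[pad_idx] = 0.0     # after the backward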
We are concerned that a standalone zeroing kernel is suboptimal from a performance perspective. Incorporating the zeroing logic into the attention kernel, following JerryChen's approach, should preserve performance while reducing overhead, and it would also make FA3 applicable and robust in more cases. Given that the seqused interface is already exposed, making it complete feels like a natural improvement.
Would you be supportive if we prepare a patch with this kernel change and open a PR for review?