
Does the FA3 varlen func support padding between sequences?

Open · QiZhangNV opened this issue 1 month ago • 9 comments

Hi, I noticed that the varlen function takes both cu_seqlens_q/k and seqused_q/k. Does this mean it can handle padding between sequences?
For example, given [q1, q1, PAD, q2, q2, q2, PAD, q3, PAD], if I pass cu_seqlens_q = [0, 3, 7, 9] and seqused_q = [2, 3, 1], would that be correct?

QiZhangNV avatar Nov 06 '25 03:11 QiZhangNV
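
For reference, a minimal sketch of the layout in the question, assuming an FA3-style flash_attn_varlen_func that takes seqused_q/seqused_k keyword arguments as described in this thread (the exact signature and return convention may differ, so the call is left commented out):

```python
import torch
# from flash_attn_interface import flash_attn_varlen_func  # FA3 interface (assumed import path)

# Packed layout from the example: [q1, q1, PAD, q2, q2, q2, PAD, q3, PAD]
total_tokens, nheads, headdim = 9, 8, 64
q = torch.randn(total_tokens, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Segment boundaries include the padding slots; seqused counts only the real tokens.
cu_seqlens_q = torch.tensor([0, 3, 7, 9], dtype=torch.int32, device="cuda")
seqused_q = torch.tensor([2, 3, 1], dtype=torch.int32, device="cuda")

# out = flash_attn_varlen_func(
#     q, k, v,
#     cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_q,
#     max_seqlen_q=4, max_seqlen_k=4,            # longest segment, including padding
#     seqused_q=seqused_q, seqused_k=seqused_q,
# )
```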

Yes, that's right. But note that the kernel won't touch the output memory of the padding tokens, so the output for those tokens will be uninitialized (it could contain arbitrary values, including Inf or NaN). If you use this output in the next operation (e.g. a matmul), you can get Inf or NaN gradients.

tridao avatar Nov 06 '25 16:11 tridao
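
A small illustration of that hazard (not from the thread): if the padding rows of the output happen to contain NaN and the output is fed into a matmul, the weight gradient becomes NaN even when the loss only reads the valid rows, because the backward pass multiplies NaN by zero.

```python
import torch

# Simulate the varlen output from the example: 9 rows, where rows 2, 6, 8 are
# padding the kernel never wrote (filled with NaN to mimic uninitialized memory).
out = torch.randn(9, 8)
out[[2, 6, 8]] = float("nan")
out.requires_grad_(True)

W = torch.randn(8, 8, requires_grad=True)
y = out @ W

# The loss uses only the valid rows, yet dW = out.T @ dy contains NaN * 0 = NaN.
valid = [0, 1, 3, 4, 5, 7]
y[valid].sum().backward()
print(torch.isnan(W.grad).any())  # tensor(True)
```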

Thanks for your explanation! There's one thing I'm still not quite sure about. Do you mean that after calling the kernel, I need to manually zero out the padding parts of the output? Something like:

o = fwd(q, k, v)
set_zeros(o)       # zero the padding rows of the forward output
...
dq, dk, dv = bwd(do)
set_zeros(dq)      # zero the padding rows of the gradients
set_zeros(dk)
set_zeros(dv)

Also, why is it designed this way?

QiZhangNV avatar Nov 07 '25 02:11 QiZhangNV

If you need to, you can zero out the parts that are not initialized in the output and the gradients (i.e. the padding tokens) yourself. This API isn't really designed for padding tokens (though you can use it that way). It's meant for cases like context parallelism, where each iteration does attention on only part of the input. We aim for efficiency: we do not touch memory that we don't need to touch. I personally think padding tokens should not exist.

tridao avatar Nov 07 '25 02:11 tridao

I think only minor modifications to the kernel are needed. A new parameter, actual_seq_qk_padding (cu_seqlens_q/k[bid+1] - cu_seqlens_q/k[bid]), can be added for the predicates in the epilogue. Meanwhile, apply actual_seq_q (seqused_q[bid]) in the mask function and add a mask along the row (seq_q) direction in the forward: if (row >= actual_seq_q) { acc_s[i] = -inf; }

NVIDIA-JerryChen avatar Nov 07 '25 02:11 NVIDIA-JerryChen
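
For clarity, here is a PyTorch reference (not the kernel change itself) of the end behavior this proposal would give: valid rows attend over only the used tokens, and rows beyond seqused are written with defined values (zeros here) rather than left untouched. The function name and the choice of zeros for fully masked rows are assumptions for illustration.

```python
import torch

def varlen_attention_reference(q, k, v, cu_seqlens, seqused):
    # q, k, v: [total_tokens, nheads, headdim], packed as in the example above.
    out = torch.zeros_like(q)          # padding rows end up defined (zero), not uninitialized
    scale = q.shape[-1] ** -0.5
    for b in range(len(seqused)):
        start, used = int(cu_seqlens[b]), int(seqused[b])
        qb = q[start:start + used].transpose(0, 1).float()   # [nheads, used, headdim]
        kb = k[start:start + used].transpose(0, 1).float()
        vb = v[start:start + used].transpose(0, 1).float()
        attn = torch.softmax(qb @ kb.transpose(-1, -2) * scale, dim=-1)
        out[start:start + used] = (attn @ vb).transpose(0, 1).to(q.dtype)
        # Rows in [start + used, cu_seqlens[b + 1]) stay zero: the behavior the
        # proposed epilogue/mask change would provide.
    return out
```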

If your padding tokens are only in Q and not in K & V, you can just pretend those are legitimate tokens and skip seqused_q, right? Then the output for the padding tokens will be written.

tridao avatar Nov 07 '25 02:11 tridao

This works for the forward pass, but there will be issues in the backward pass. By the way, if we assign zeros to the padding regions of Q/K/V, we only need cu_seqlens to get the expected result, and seqused_q/k loses its design value.

Therefore, I think we should modify the mask function and the predicates in the epilogue (almost no additional overhead) rather than initializing the padding regions ourselves, which would also give better performance.

NVIDIA-JerryChen avatar Nov 07 '25 02:11 NVIDIA-JerryChen

In our case, Q, K, and V all require padding, and the padded parts are not zeros. Therefore, without the modification that JerryChen mentioned, we would have to manually set the output of the padded positions to zero ourselves, which is not what we want.

As for the comment “I personally think padding tokens should not exist,” our case comes from a scenario in Transformer Engine, where we currently need to use FusedAttention (see: https://github.com/NVIDIA/TransformerEngine/blob/v2.9/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L663-L670).

QiZhangNV avatar Nov 07 '25 04:11 QiZhangNV

If you have a kernel that zeros out the padding tokens (something like out[padding_indices, :, :] = 0.0), you could apply it to the output, the incoming gradient (dout), and dQ. The gradients with respect to Q, K, and V would then be correct.

tridao avatar Nov 07 '25 05:11 tridao
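
A sketch of that zeroing step at the PyTorch level, assuming cu_seqlens/seqused have the layout from the example above; the helper name is made up for illustration:

```python
import torch

def padding_indices_from(cu_seqlens, seqused):
    # Token positions inside each segment that lie beyond the used length.
    idx = [torch.arange(int(cu_seqlens[b]) + int(seqused[b]), int(cu_seqlens[b + 1]))
           for b in range(len(seqused))]
    return torch.cat(idx) if idx else torch.empty(0, dtype=torch.long)

cu_seqlens_q = torch.tensor([0, 3, 7, 9])
seqused_q = torch.tensor([2, 3, 1])
padding_indices = padding_indices_from(cu_seqlens_q, seqused_q)  # tensor([2, 6, 8])

# Applied as suggested, around the attention calls:
# out[padding_indices] = 0.0    # forward output
# dout[padding_indices] = 0.0   # incoming gradient
# dq[padding_indices] = 0.0     # gradient w.r.t. Q
```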

We are concerned that a standalone zeroing kernel is suboptimal from a performance perspective. Incorporating the zeroing logic into the attention kernel, following JerryChen's approach, should avoid that overhead while also enhancing FA3's applicability and robustness in more cases. Given that the interface is already exposed, making it more complete feels like a natural improvement.

Would you be supportive if we prepare a patch with this kernel change and open a PR for review?

QiZhangNV avatar Nov 07 '25 13:11 QiZhangNV