
Variable memory allocation with varlen kernels

CodeCreator opened this issue 1 year ago · 1 comment

Hey!

I'm a big fan of the flash attention varlen kernels; they are fantastic for saving the memory and compute otherwise spent on pad tokens.

When training with fixed batches of N tokens, I've noticed that memory use varies substantially depending on cu_seqlens and max_seqlen. I suspect this is because softmax_lse is allocated in a padded format ([num_seqs, num_heads, max_seqlen]) in https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L655, which reintroduces padding for shorter sequences.

I wonder how feasible it would be to store the softmax_lse in an unpadded format ([num_tokens, num_heads]) for the varlen kernel (at least when storing activations for the backward pass).

Do you think this would achieve approximately constant memory use when training with batches of a fixed number of tokens? Thank you!
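
For a rough sense of the difference, here is a back-of-the-envelope sketch comparing the softmax_lse footprint under the two layouts. This is not the library's actual allocation code; the lse_bytes helper and the example cu_seqlens below are made up purely for illustration.

```python
import torch

def lse_bytes(cu_seqlens, num_heads, dtype=torch.float32):
    """Hypothetical helper: bytes used by softmax_lse under padded vs. unpadded layouts."""
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]          # per-sequence lengths
    num_seqs = len(seqlens)
    max_seqlen = int(seqlens.max())
    total_tokens = int(seqlens.sum())
    elem = torch.tensor([], dtype=dtype).element_size()
    padded = num_seqs * num_heads * max_seqlen * elem   # [num_seqs, num_heads, max_seqlen]
    unpadded = total_tokens * num_heads * elem          # [num_tokens, num_heads]
    return padded, unpadded

# Example: 4 sequences totaling 8192 tokens, one of them much longer than the rest.
cu_seqlens = torch.tensor([0, 512, 1024, 2048, 8192], dtype=torch.int32)
padded, unpadded = lse_bytes(cu_seqlens, num_heads=32)
print(f"padded:   {padded / 2**20:.1f} MiB")   # grows with num_seqs * max_seqlen
print(f"unpadded: {unpadded / 2**20:.1f} MiB") # grows with the total token count only
```

The unpadded layout depends only on the total number of tokens per batch, which is why it would give roughly constant memory for fixed-token batches, whereas the padded layout scales with the longest sequence in each batch.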

CodeCreator · Jun 26 '24 21:06

There's a PR for that; it will be merged soon.

tridao · Jun 26 '24 21:06