flash-linear-attention
[RFC] Autotune should consider batch size and number of heads
Proposal
The autotuned kernel configuration should be re-selected when (batch size × number of heads) changes, rather than reusing a configuration tuned for a different shape.
Rationale
The performance of the autotuned kernel can vary significantly when (batch size × number of heads) changes, since these two dimensions determine how much parallelism is available to the kernel.
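As a concrete illustration, here is a minimal sketch (not one of the library's actual kernels; the kernel and argument names are illustrative) of keying Triton's autotuner on the batch and head arguments, so that a change in either triggers a fresh configuration search instead of reusing the config cached for the first shape seen:

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_T': 64}, num_warps=2),
        triton.Config({'BLOCK_T': 128}, num_warps=4),
        triton.Config({'BLOCK_T': 256}, num_warps=8),
    ],
    # Re-run the search whenever the batch size or the number of heads changes,
    # since they determine how many programs run in parallel.
    key=['B', 'H'],
)
@triton.jit
def scale_kernel(x_ptr, y_ptr, alpha,
                 B, H, T,
                 D: tl.constexpr, BLOCK_T: tl.constexpr):
    # One program per (batch * head) slice and per tile of the sequence dimension.
    i_bh = tl.program_id(0)
    i_t = tl.program_id(1)
    o_t = i_t * BLOCK_T + tl.arange(0, BLOCK_T)
    o_d = tl.arange(0, D)
    m_t = o_t < T
    p = i_bh * T * D + o_t[:, None] * D + o_d[None, :]
    x = tl.load(x_ptr + p, mask=m_t[:, None], other=0.)
    tl.store(y_ptr + p, x * alpha, mask=m_t[:, None])


def scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # x has shape (B, H, T, D), contiguous, with D a power of two (e.g. head dim 64).
    B, H, T, D = x.shape
    y = torch.empty_like(x)
    grid = lambda meta: (B * H, triton.cdiv(T, meta['BLOCK_T']))
    scale_kernel[grid](x, y, alpha, B, H, T, D=D)
    return y
```

With `key=['B', 'H']`, Triton caches the best configuration per distinct (B, H) pair rather than applying the configuration found for whatever shape happened to run first.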
Autotuning should also take the total sequence length into account, as the sequence length dimension provides parallelism in addition to the number of heads and batch size.
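Along the same lines, the sequence length could be added to the key in the sketch above (`key=['B', 'H', 'T']`), with the caveat raised below that T changes frequently during training and decoding, so every new value would trigger another tuning pass. For reference, a short usage note (hypothetical shapes) showing how the key drives re-tuning:

```python
# Hypothetical shapes: the two calls differ only in batch size, so with
# key=['B', 'H'] each triggers its own tuning pass and cached configuration.
x1 = torch.randn(4, 8, 2048, 64, device='cuda')
x2 = torch.randn(32, 8, 2048, 64, device='cuda')
y1 = scale(x1, 2.0)   # tunes and caches a config for (B=4,  H=8)
y2 = scale(x2, 2.0)   # tunes again for (B=32, H=8) instead of reusing the first config
```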
I would consider this issue, but since the token length keeps changing during training and inference, autotuning over the token length still needs further consideration.