Flash attention soft capping support
In JAX's experimental Pallas kernels for TPU, attention-logit softcapping is supported for paged attention but not for flash attention. It would be great if support could also be added to the Pallas flash attention kernel, since it could then be used from PyTorch/XLA as well as in the vLLM implementation. The Gemma 2 9B model works even when logit softcapping is skipped, but the 27B model does not.
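For context, the softcapping itself is just a tanh squashing of the attention logits before the softmax (Gemma 2 uses an attention cap of 50.0, if I recall correctly). Below is a minimal JAX sketch of the transform using a plain, non-fused attention for illustration only; the function names are not the Pallas kernel API, and a real flash kernel would apply the same transform to each q @ kᵀ block inside its online-softmax loop.

```python
import jax
import jax.numpy as jnp


def softcap_logits(logits, cap):
    # Tanh soft capping: squashes raw logits into the open interval (-cap, cap).
    return cap * jnp.tanh(logits / cap)


def reference_attention_with_softcap(q, k, v, cap=50.0):
    # Plain (non-fused) attention with logit softcapping, for reference only.
    # A Pallas flash kernel would apply softcap_logits to each q @ k^T block
    # before its online softmax accumulation.
    logits = jnp.einsum("...qhd,...khd->...hqk", q, k) / jnp.sqrt(q.shape[-1])
    logits = softcap_logits(logits, cap)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...hqk,...khd->...qhd", weights, v)


# Example shapes: batch=1, seq=128, heads=8, head_dim=64
q = k = v = jnp.zeros((1, 128, 8, 64), dtype=jnp.bfloat16)
out = reference_attention_with_softcap(q, k, v)
```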
Related:
- PR adding soft-capping support to the paged attention kernel
- PyTorch/XLA custom kernel integration for paged attention
- Need for flash attention softcapping support to run Gemma 2 with vLLM on TPUs