
Add Flex Attention Monkey Patch for LLAMA

Open · austin362667 opened this issue 10 months ago · 0 comments

Summary

We need FlexAttention to support custom attention variants and masks while keeping good performance (for example, shared-prefix masking).
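
To make the motivation concrete, here is a minimal sketch (not part of this PR) of how a shared-prefix mask looks with PyTorch's FlexAttention. It assumes torch >= 2.5, a CUDA device, and a hypothetical `PREFIX_LEN`; shapes and names are placeholders.

```python
# Minimal sketch: FlexAttention with a custom "shared prefix + causal" mask.
# Assumes torch >= 2.5, which ships torch.nn.attention.flex_attention.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 1, 8, 128, 64
PREFIX_LEN = 16  # hypothetical shared-prefix length

def prefix_causal_mask(b, h, q_idx, kv_idx):
    # Every query may attend to the shared prefix; otherwise attention is causal.
    return (kv_idx < PREFIX_LEN) | (q_idx >= kv_idx)

block_mask = create_block_mask(prefix_causal_mask, B, H, S, S)

q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
```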

There are two ways to enable FlexAttention in Liger:

  1. Set `attn_implementation` in the model config (for instance, `LlamaConfig`) from PyTorch `sdpa`/`eager` to `flex_attention`. This switches `config._attn_implementation` to the FlexAttention implementation.
  2. (This PR) Patch the HuggingFace attention-implementations dict so that every entry points to FlexAttention. The original default key (say, `sdpa`) keeps working, but it now dispatches to `flex_attention` instead. Sketches of both options follow this list.
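
A minimal sketch of option 1, assuming a recent `transformers` release that accepts `attn_implementation="flex_attention"`; the model id is a placeholder:

```python
# Sketch of option 1: ask transformers to use FlexAttention directly via the config.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",             # placeholder model id
    attn_implementation="flex_attention",  # instead of "sdpa" or "eager"
)
```

And a rough sketch of the monkey-patch idea behind option 2. The names `ALL_ATTENTION_FUNCTIONS` and `flex_attention_forward` are assumptions about the transformers internals around the time of this PR, and `apply_liger_flex_attention_patch` is a hypothetical helper, not Liger's actual API:

```python
# Sketch of option 2: remap every registered attention key to the FlexAttention forward,
# so models that keep the default "sdpa" key silently run FlexAttention.
# ALL_ATTENTION_FUNCTIONS / flex_attention_forward are assumed transformers internals.
from transformers.integrations.flex_attention import flex_attention_forward
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def apply_liger_flex_attention_patch():
    # Hypothetical patch helper: point every attention key at the FlexAttention forward.
    for key in list(ALL_ATTENTION_FUNCTIONS.keys()):
        ALL_ATTENTION_FUNCTIONS[key] = flex_attention_forward
```

The appeal of option 2 is that callers do not have to touch their model configs: once the patch is applied, models created with the default `sdpa` key pick up FlexAttention automatically.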

Testing Done

  • Hardware Type: <BLANK>
  • [ ] run make test to ensure correctness
  • [X] run make checkstyle to ensure code style
  • [X] run make test-convergence to ensure convergence

austin362667 · Jan 25 '25 17:01