Add Flex Attention Monkey Patch for LLAMA
Summary
We need flex attention to support custom attention variants and masks (for example, a shared prefix) with better performance.
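As a minimal sketch (not part of this PR) of why flex attention helps here, the snippet below builds a shared-prefix mask with PyTorch's `torch.nn.attention.flex_attention` API (PyTorch >= 2.5): the first `PREFIX_LEN` tokens are visible to every query, the rest stay causal. `PREFIX_LEN`, the shapes, and the mask function are illustrative assumptions, not Liger code.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 2, 8, 256, 64
PREFIX_LEN = 32  # hypothetical shared-prefix length

def prefix_causal_mask(b, h, q_idx, kv_idx):
    # Attend to the shared prefix, or to earlier (causal) positions.
    return (kv_idx < PREFIX_LEN) | (q_idx >= kv_idx)

# B=None / H=None broadcast the mask across batch and heads.
block_mask = create_block_mask(prefix_causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

out = flex_attention(q, k, v, block_mask=block_mask)
```

In practice `flex_attention` is usually wrapped in `torch.compile` to get the fused kernel rather than the eager fallback.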
There are two ways to enable flex attention in Liger:

- Set the `attn_implementation` of the model config (for instance, `LlamaConfig`) from PyTorch's `sdpa`/`eager` to `flex_attention`. By doing so, `config._attn_implementation` switches to the flex attention implementation.
- (This PR) Patch the attention implementation dict in Hugging Face so that the original default attention key, say `sdpa`, can still be used but now maps to `flex_attention` instead (see the sketch after this list).
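A rough sketch of the patching idea, assuming a recent transformers version where attention implementations are looked up in the `ALL_ATTENTION_FUNCTIONS` registry (older versions use per-model dicts such as `LLAMA_ATTENTION_CLASSES`). The wrapper below is hypothetical and omits mask handling; it is not the actual Liger kernel.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def liger_flex_attention_forward(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
    # query/key/value arrive as (batch, num_heads, seq_len, head_dim);
    # enable_gqa lets flex attention handle fewer KV heads than query heads.
    # Mask/block-mask handling is omitted for brevity.
    attn_output = flex_attention(query, key, value, scale=scaling, enable_gqa=True)
    # transformers expects (batch, seq_len, num_heads, head_dim) back.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None


def apply_flex_attention_monkey_patch():
    # Way 1: register under a dedicated key, then opt in per model via
    # attn_implementation="flex_attention" in the config.
    ALL_ATTENTION_FUNCTIONS["flex_attention"] = liger_flex_attention_forward
    # Way 2 (this PR): overwrite the existing default key so configs that
    # still request "sdpa" transparently run flex attention instead.
    ALL_ATTENTION_FUNCTIONS["sdpa"] = liger_flex_attention_forward
```

The trade-off: way 1 is explicit but requires changing model configs, while way 2 keeps existing configs untouched at the cost of silently changing what `sdpa` means once the patch is applied.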
Testing Done
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence