flash-attention
Implemented Flash Attention2 for Intel GPU hardware
Hi,
@tridao, we'd like to add an Intel backend for flash-attn. This PR implements the mha_fwd function for Intel hardware such as Intel(R) Arc(TM) B580 Graphics (BMG) and Intel® Core™ Ultra 7 processors (Lunar Lake). Other APIs, e.g. mha_bwd/varlen_mha, are work in progress (WIP). The C++ API follows the same design as the original CUDA/ROCm implementation, and the Python interface is reused without changes.
This implementation works seamlessly with stock PyTorch and has no third-party dependencies. It does not affect the existing support for NVIDIA or ROCm hardware.
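Since the Python interface is reused without changes, usage on Intel hardware should look the same as on CUDA/ROCm. Here is a minimal sketch, assuming this backend is installed and that stock PyTorch exposes the Intel GPU as the "xpu" device (device name and build details are assumptions, not confirmed by this PR text):

```python
# Minimal usage sketch (assumes an XPU-enabled flash-attn build and a
# PyTorch installation that exposes Intel GPUs via the "xpu" device).
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="xpu", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Same Python call as on CUDA/ROCm; only the forward pass (mha_fwd) is
# implemented for Intel hardware in this PR.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # (batch, seqlen, nheads, headdim)
```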
cc @pengzhao-intel
@tridao any thoughts or suggestions for this PR?
Thanks for this contribution! What happens if user calls a function that's not currently supported (e.g. paged KV or varlen)?
Currently, all checks are handled on the kernel side, and an error will be raised if a feature is not supported. We are actively working on adding support for these features.
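For illustration only: based on the statement above, calling a not-yet-supported path such as varlen on the Intel backend should be rejected by the kernel-side checks. The exact error type and message are assumptions here, as is the "xpu" device name:

```python
# Hedged illustration: unsupported features (e.g. varlen) are expected to be
# rejected by kernel-side checks; the RuntimeError type/message is an assumption.
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 16, 64
seqlens = [128, 256]
total = sum(seqlens)
q = torch.randn(total, nheads, headdim, device="xpu", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, 128, 384], device="xpu", dtype=torch.int32)

try:
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    )
except RuntimeError as err:
    # The varlen path is not implemented for the Intel backend yet,
    # so the kernel-side check is expected to raise here.
    print(f"varlen not supported on this backend yet: {err}")
```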
Any word on when this will be merged? Very excited for this! I use Intel GPUs.
@Wanzizhu this is great work!
I just went through a couple of files and noted some minor possible improvements, although I did not go through the entire change.
(Tagging others FYI: @jgong5 @rbiessy @mehdi-goli @alcpz)
This is excellent work!! Looking forward to using this feature on Intel client GPUs.
I would guess this needs a rebase now. Or are you waiting on paged attention to merge?