[LinalgExt] Initial support for aten::flex_attention and rewrite to linalg_ext.attention
This PR adds comprehensive support for converting PyTorch's torch::hop_flex_attention operation to IREE's linalg_ext.attention operation. The implementation includes proper handling of score modification functions, mask functions, and correct softmax computation using base-e exponential (exp) instead of base-2 exponential (exp2).
Changes
- Added `FlexAttentionOpConversion` pattern to convert `torch.hop_flex_attention` to `linalg_ext.attention`
- Modified `AttentionOp` to pass iteration indices as block arguments to the region
- Enables score modification functions to access batch, head, query sequence, and key sequence indices (see the Python sketch after this list)
- Added lit tests for both LSE (log-sum-exp) and non-LSE cases
- Simplified region handling by using block arguments instead
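For context, here is roughly what this looks like on the PyTorch side: `flex_attention`'s `score_mod` callback receives the batch, head, query-index, and key-index values, which is why the lowered `linalg_ext.attention` region needs those iteration indices. This is an illustrative sketch only (the shapes and the causal mask are made up, and it assumes a PyTorch version that ships `torch.nn.attention.flex_attention`):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# score_mod receives the raw score plus the batch, head, query-index, and
# key-index values; these are the four indices the converted
# linalg_ext.attention region needs access to.
def causal_score_mod(score, b, h, q_idx, kv_idx):
    # Drop keys that lie "in the future" relative to the query position.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

q = torch.randn(1, 4, 128, 64)
k = torch.randn(1, 4, 128, 64)
v = torch.randn(1, 4, 128, 64)
out = flex_attention(q, k, v, score_mod=causal_score_mod)
```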
Fixes:
- PyTorch's flex attention already supplies the correct scale for base-e softmax. This commit fixes the computation to use `exp` instead of `exp2` (see the snippet below).
- The `use_exp2` flag is mostly unused in dialect conversions and passes; I presume it is used as a KernelOption. The changes here do not modify the default behavior.
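As a minimal numeric check of the base-e vs. base-2 point (illustrative, not code from this PR): an `exp2`-based softmax only matches `torch.softmax` when `log2(e)` is folded into the scores, which is exactly the factor that is missing when the incoming scale already targets base e.

```python
import torch

scores = torch.randn(8)

# Base-e softmax, matching the scale PyTorch's flex attention supplies.
ref = torch.softmax(scores, dim=-1)

# An exp2 formulation is only equivalent if log2(e) is folded in first.
LOG2E = 1.4426950408889634
num_ok = torch.exp2(scores * LOG2E)
num_bad = torch.exp2(scores)

print(torch.allclose(ref, num_ok / num_ok.sum()))    # True
print(torch.allclose(ref, num_bad / num_bad.sum()))  # False in general
```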
Testing:
- I ran the entire flex_attention_hop implementation with randomised input tensors (also see torch-mlir) through `aot.export` and compared against eager mode; I observed no accuracy loss (on CPU). A rough sketch of this flow follows the command below.
- Command: `iree-compile --iree-stream-partitioning-favor=min-peak-memory --iree-hal-target-device=local --iree-hal-local-target-device-backends=llvm-cpu --iree-llvmcpu-target-triple=x86_64-pc-linux-elf --iree-llvmcpu-debug-symbols=false ../torch-mlir/exported_ir.mlir -o ./flex_attention_cpu.vmfb`
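A rough sketch of that export-and-compare flow, assuming the iree-turbine `aot.export` API (the `FlexAttn` wrapper is hypothetical, and import paths and signatures may differ between versions):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
import iree.turbine.aot as aot  # assumed import path for iree-turbine

class FlexAttn(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v)

q, k, v = (torch.randn(1, 4, 128, 64) for _ in range(3))

# Eager-mode reference output.
ref = FlexAttn()(q, k, v)

# Export to MLIR, compile with the iree-compile command above, then run the
# resulting .vmfb and compare against `ref` (e.g. with torch.allclose).
exported = aot.export(FlexAttn(), q, k, v)
exported.save_mlir("exported_ir.mlir")
```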
I can review after you fix all the build failures. It doesn't look ready to me while CI is this red.
Please re-request review or ping me on Discord when it is ready for review. Thanks! :)
@MaheshRavishankar Thanks for the review, I'll incorporate some of the suggested changes. As for the PR bifurcation, it's a bit difficult to isolate the lowering from the required changes. The lowering relies on block arguments being supported, which is a change to the AttentionOp interface itself. Could you be more specific about how you see this split being done?
Yeah, first you should make a change to the AttentionOp itself that adds the block-argument support. Then you can make the changes to the lowering. The part adding block arguments here needs to be fleshed out a bit more. Do we really need the change to get block arguments in? One option is to not change block arguments but use `linalg_ext.index` (similar to `linalg.index`). That might be easier to land without too many breakages.
Previously, that's what the implementation looked like: I added support for `linalg_ext.index` ops inside the attention region. That worked, but it bloated the IR, so I thought this was a simpler solution.
In any case, it's easier to support index ops inside the attention region (with small modifications to the `::verify` method), so I'll switch back to that, since it's less likely to break things in the future.
Regardless, I'll have to add tiling support for these index ops in the attention region, which will be a follow-up PR.
Why would you need to add support for index ops while tiling? You shouldn't need to.
I had discussed this with @Groverkss earlier; the goal was to add support for `iree_linalg_ext.index` in the attention op region to enable index-based logic (e.g., masking for RPE). At the time, he pointed out that using the mask operand in AttentionOp was a hack.
While most tiled ops don't require index in the region, this op does: it relies on index-dependent semantics like relative position encoding, which can't be expressed correctly without index after tiling. A small example of such a score modification is sketched below.
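For illustration, here is a toy `score_mod` of that kind (a distance-dependent bias standing in for a learned relative-position table, not code from this PR); it cannot be evaluated without knowing both `q_idx` and `kv_idx` for each score element, which is why the tiled region still needs index information:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def rpe_score_mod(score, b, h, q_idx, kv_idx):
    # Penalize scores by the distance between query and key positions;
    # any real RPE would similarly need both indices per element.
    return score - 0.01 * (q_idx - kv_idx).abs()

q, k, v = (torch.randn(1, 4, 128, 64) for _ in range(3))
out = flex_attention(q, k, v, score_mod=rpe_score_mod)
```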