[LinalgExt] Initial support for aten::flex_attention and rewrite to linalg_ext.attention

Open • keshavvinayak01 opened this pull request 2 months ago • 7 comments

This PR adds support for converting PyTorch's flex attention higher-order op (torch.hop_flex_attention) to IREE's linalg_ext.attention operation. The implementation handles score modification functions and mask functions, and computes the softmax with the base-e exponential (exp) instead of the base-2 exponential (exp2).

Changes

  • Added a FlexAttentionOpConversion pattern to convert torch.hop_flex_attention to linalg_ext.attention
  • Modified AttentionOp to pass iteration indices as block arguments to its region, which lets score modification functions access the batch, head, query-sequence, and key-sequence indices (see the sketch after this list)
  • Added lit tests for both the LSE (log-sum-exp) and non-LSE cases
  • Simplified region handling by using block arguments instead of explicit index ops
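
For reference, here is a minimal sketch of the PyTorch-side callbacks that motivate exposing these indices, written against the public torch.nn.attention.flex_attention API; the score/mask functions and shapes are illustrative, not the ones used in the tests:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# score_mod receives the raw attention score plus the batch, head, query-index,
# and key/value-index coordinates: the same four indices the converted
# linalg_ext.attention region needs to see.
def rel_pos_bias(score, b, h, q_idx, kv_idx):
    return score + (kv_idx - q_idx)

# mask_mod receives only the index coordinates and returns a boolean.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128,
                               device="cpu")
out = flex_attention(q, k, v, score_mod=rel_pos_bias, block_mask=block_mask)
```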

Fixes:

  • PyTorch's flex attention already supplies the correct scale for a base-e softmax, so this change switches the computation from exp2 to exp (see the sketch after this list).
  • The use_exp2 flag is mostly unused in dialect conversions and passes; I presume it is consumed as a KernelOption. The changes here do not modify the default behavior.
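
To see why the base matters, here is a small standalone check (not code from this PR): exp2 only reproduces a base-e softmax when log2(e) is folded into the scale, and the scale flex attention supplies is the plain base-e one:

```python
import math

import torch

d = 64
scale = 1.0 / math.sqrt(d)      # the base-e scale flex attention already supplies
scores = torch.randn(128, 128)

ref = torch.softmax(scores * scale, dim=-1)        # base-e softmax

raw = torch.exp2(scores * scale)                   # exp2 without folding log2(e)
wrong = raw / raw.sum(dim=-1, keepdim=True)

fixed = torch.exp2(scores * scale * math.log2(math.e))  # exp2 with log2(e) folded in
right = fixed / fixed.sum(dim=-1, keepdim=True)

print(torch.allclose(ref, right))  # True: exp2 with the folded scale matches exp
print(torch.allclose(ref, wrong))  # False: plain exp2 uses the wrong base
```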

Testing:

  • I ran the entire flex_attention_hop implementation with randomized input tensors (also see torch-mlir) through aot.export, compared against eager mode, and observed no accuracy loss on CPU (see the sketch after this list).
  • Command: iree-compile --iree-stream-partitioning-favor=min-peak-memory --iree-hal-target-device=local --iree-hal-local-target-device-backends=llvm-cpu --iree-llvmcpu-target-triple=x86_64-pc-linux-elf --iree-llvmcpu-debug-symbols=false ../torch-mlir/exported_ir.mlir -o ./flex_attention_cpu.vmfb
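
For completeness, a hedged sketch of the eager-mode reference side of that comparison; the aot.export invocation and the IREE runtime harness are elided, and the score_mod below is illustrative rather than the one used in the actual test:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))

def score_mod(score, b, h, q_idx, kv_idx):
    # Illustrative modification; the real test exercises its own score/mask functions.
    return score * 0.5

eager_out = flex_attention(q, k, v, score_mod=score_mod)

# The same module goes through aot.export, is compiled with the iree-compile
# command above, executed via the IREE runtime, and compared roughly as:
#   torch.testing.assert_close(torch.from_numpy(iree_out), eager_out,
#                              atol=1e-5, rtol=1e-5)
```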

keshavvinayak01 avatar Oct 28 '25 07:10 keshavvinayak01

I can review after you fix all the build failures. It doesn't look ready to me while CI is this red.

hanhanW avatar Nov 12 '25 22:11 hanhanW

Please click the re-request review button or ping me on Discord when it is ready for review. Thanks! :)

hanhanW avatar Nov 14 '25 23:11 hanhanW

@MaheshRavishankar Thanks for the review; I'll incorporate some of the suggested changes. As for splitting the PR, it's a bit difficult to isolate the lowering from the required changes: the lowering relies on block arguments being supported, which is a change to the AttentionOp interface itself. Could you be more specific about how you see this split being done?

keshavvinayak01 avatar Nov 21 '25 07:11 keshavvinayak01

Yeah, first you should make a change to the AttentionOp itself to add block-argument support; then you can make the changes to the lowering. The part adding block arguments here needs to be fleshed out a bit more, though. Do we really need the change to get block arguments in? One option is to not change the block arguments but to use linalg_ext.index (similar to `linalg.index`). That might be easier to land without too many breakages.

MaheshRavishankar avatar Nov 25 '25 00:11 MaheshRavishankar

Previously, that's how the implementation looked: I added support for linalg_ext.index ops inside the attention region. That worked, but it bloated the IR, so I thought this was a simpler solution.

In any case, it's easier to support index ops inside the attention region (with small modifications to the ::verify method), so I'll switch back to that approach, since it's less likely to break things in the future.

Regardless, I'll have to add tiling support for these index ops in the attention region, which will be a follow-up PR.

keshavvinayak01 avatar Nov 25 '25 07:11 keshavvinayak01

Why would you need to add support for index ops while tiling? You shouldn't need to.

MaheshRavishankar avatar Nov 25 '25 19:11 MaheshRavishankar

I had discussed this with @Groverkss earlier; the goal was to add support for iree_linalg_ext.index in the attention op region to enable index-based logic (e.g., masking for RPE). At the time, he pointed out that using the mask operand in AttentionOp was a hack.

While most tiled ops don't require index access in their region, this op does: it relies on index-dependent semantics such as relative position encoding, which can't be expressed correctly after tiling without the indices.
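
As a concrete example of that index dependence (a sketch against the public PyTorch API, not IR from this PR), a sliding-window mask is a pure function of the absolute query/key positions, so each tile of the attention op still needs to know its global offsets to evaluate it:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

WINDOW = 16

def sliding_window(b, h, q_idx, kv_idx):
    # Visibility depends on absolute positions; after tiling, each tile must
    # still know its global q/kv offsets to evaluate this correctly.
    return (q_idx - kv_idx).abs() <= WINDOW

mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=128, KV_LEN=128,
                         device="cpu")
q, k, v = (torch.randn(1, 2, 128, 32) for _ in range(3))
out = flex_attention(q, k, v, block_mask=mask)
```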

keshavvinayak01 avatar Nov 26 '25 07:11 keshavvinayak01