
Allow Memory Efficient Attention Kernel to run when local window size is set

Open · aciddelgado opened this pull request 1 year ago · 2 comments

Description

This PR introduces a slight change to the handling of the Local Window Size parameter in the context of Memory Efficient Attention. Previously, setting the Local Window Size to any value other than -1 would disable Memory Efficient Attention. This update allows the kernel to operate regardless of the Local Window Size setting.

Motivation and Context

Models with local attention have been difficult to run on CUDA: Flash Attention supports local window attention, but only on hardware with compute capability sm_80 or higher, so lower-capability GPUs were left without a working kernel for these models.

With this PR, such models can also run on hardware below sm_80 via Memory Efficient Attention. Note, however, that because the kernel disregards the Local Window Size setting, the output may not exactly match that of a model running local attention as intended. The change broadens the range of hardware on which these models can execute, at the cost of potential variations in output.
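
To make that output difference concrete, here is a minimal NumPy sketch (not onnxruntime code; it assumes the common sliding-window convention in which each token attends to itself and the previous `window - 1` tokens) comparing full attention against local attention on toy data:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, window=None):
    # q, k, v: (seq_len, head_dim); causal attention with an optional local window.
    seq_len, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    mask = j > i                        # causal: a query cannot attend to future keys
    if window is not None:
        mask |= j < i - window + 1      # local: keep only the last `window` keys
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
full = attention(q, k, v)               # full attention (window disregarded, as described above)
local = attention(q, k, v, window=4)    # what local attention should compute
print(np.abs(full - local).max())       # > 0: outputs diverge once seq_len > window
```

Once the sequence length exceeds the window, the two results differ, which is exactly the mismatch described above.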

aciddelgado avatar Jul 10 '24 16:07 aciddelgado

This causes wrong results, and we should avoid that.

How about changing memory efficient attention to support a local window by setting non-local elements to -inf? If the change is small, we can add a patch to cutlass.

Another way is through the bias tensor, which can act like an attention mask: https://github.com/NVIDIA/cutlass/blob/56b46e2d13875b46b8f6a03f9f5ac91e2bfdc01a/examples/41_fused_multi_head_attention/kernel_forward.h#L176 That could be slower and need more memory, but it should work without changing the cutlass code.
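
As a rough illustration of the bias route (shown with PyTorch's `torch.nn.functional.scaled_dot_product_attention` rather than the cutlass kernel linked above, so only the idea carries over): the bias is an additive tensor that is 0 inside the causal local window and -inf outside it.

```python
import torch
import torch.nn.functional as F

def local_window_bias(seq_len, window, device=None):
    # Additive attention bias: 0 where attention is allowed, -inf elsewhere.
    i = torch.arange(seq_len, device=device)[:, None]
    j = torch.arange(seq_len, device=device)[None, :]
    allowed = (j <= i) & (j > i - window)     # causal + only the last `window` keys
    bias = torch.zeros(seq_len, seq_len, device=device)
    return bias.masked_fill(~allowed, float("-inf"))

q = torch.randn(1, 2, 16, 8)                  # (batch, heads, seq_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
bias = local_window_bias(seq_len=16, window=4)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)  # bias broadcasts over batch/heads
```

This reuses an existing attention-bias input instead of patching the kernel, but it materializes a seq_len × seq_len tensor for the mask, which matches the extra-memory concern mentioned above.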

tianleiwu avatar Jul 10 '24 17:07 tianleiwu

PyTorch has implemented sliding window support in efficient attention. Please take a look: https://github.com/pytorch/pytorch/blob/20b62fed21f86374b01f7d5a557a83e4d3f2d130/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h#L152

tianleiwu avatar Aug 28 '24 21:08 tianleiwu