metal-flash-attention
Faster alternative to Metal Performance Shaders
Hi, thanks a lot for this really cool library. I've taken it for a spin on an M3, and saw that GEMM seemed to perform better on MPS. Do...
Hi! I was wondering if you would be interested in adding `bf16` support to MFA, or at least to the GEMM kernels? For `mlx`, Apple defined a custom type: https://github.com/ml-explore/mlx/blob/76c919b4ecf0cccaa1cfef214d12be0ad71485cc/mlx/backend/metal/kernels/bf16.h (MIT...
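For reference, `bf16` is simply the top 16 bits of an IEEE float32 (same sign and exponent, 7 mantissa bits), so host-side conversion reduces to a bit shift. A minimal Swift sketch, with helpers that are illustrative only, not part of MFA's or mlx's API:

```swift
// Hypothetical helpers showing bfloat16 semantics: bf16 keeps the sign,
// exponent, and top 7 mantissa bits of a float32.
func bf16(from x: Float) -> UInt16 {
    let bits = x.bitPattern
    // Round to nearest even; NaN payloads are not specially handled here.
    let rounded = bits &+ (0x7FFF &+ ((bits >> 16) & 1))
    return UInt16(truncatingIfNeeded: rounded >> 16)
}

func float(from bits: UInt16) -> Float {
    // Widening back to float32 is exact: just shift into the high half.
    Float(bitPattern: UInt32(bits) << 16)
}
```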
For the grouped-query attention used in Mistral, key/value require a different head count (conventionally H_k, giving a group size of H / H_k). Conventionally, H_k is also smaller than H, and H %...
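For context, a minimal sketch of the head mapping grouped-query attention implies (the 32/8 split matches Mistral-7B's `n_heads`/`n_kv_heads`; the rest is illustrative):

```swift
// Grouped-query attention: several query heads share one K/V head.
let H = 32                       // query heads (Mistral-7B)
let Hk = 8                       // key/value heads (Mistral-7B)
precondition(H % Hk == 0, "query heads must be a multiple of KV heads")
let groupSize = H / Hk           // 4 query heads per K/V head

for h in 0..<H {
    let kvHead = h / groupSize   // query head h reads K/V from this head
    print("query head \(h) -> kv head \(kvHead)")
}
```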
I tried to optimize GEMV using shared memory to speed up I/O; theoretically speaking, GEMV with SRAM should have better bandwidth. But here comes a weird performance result. **Device: M2 Ultra 128GB**...
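For reference, a back-of-envelope roofline: the matrix dominates GEMV's memory traffic and is read exactly once, so DRAM bandwidth bounds the runtime regardless of SRAM staging. A quick Swift sketch, where the 8192 size and the 800 GB/s figure (M2 Ultra's advertised bandwidth) are illustrative assumptions, not measurements:

```swift
// Roofline lower bound for a bandwidth-bound fp16 GEMV.
let n = 8192
let bytesPerElement = 2                              // fp16
let matrixBytes = Double(n * n * bytesPerElement)    // ~134 MB; vectors are negligible
let bandwidth = 800e9                                // bytes per second (assumed)
let lowerBoundMs = matrixBytes / bandwidth * 1000    // ~0.17 ms per GEMV
print("ideal GEMV time: \(lowerBoundMs) ms")
```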
Hi, thank you for implementing flash attention in MPS; it makes it possible to run flash attention on a Mac. But there is no documentation explaining how to use it in Python or PyTorch...
An accuracy issue arises during integration with the SSD-1B model. q and k can be large enough that q*k exceeds the half-precision range. This is normally OK because the scale is usually applied on...
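A minimal sketch of the failure mode, with illustrative values (the head dimension and activation magnitudes are assumptions): accumulating the unscaled dot product overflows `Float16`, whose maximum finite value is 65504, while folding the 1/sqrt(d) softmax scale into q first keeps the result finite.

```swift
import Foundation

// Illustrative repro: uniformly large q and k of head dimension 64.
let d = 64
let q = [Float16](repeating: 32.0, count: d)
let k = [Float16](repeating: 32.0, count: d)

var raw: Float16 = 0
for i in 0..<d { raw += q[i] * k[i] }            // 64 * 1024 = 65536 -> +inf
print(raw)                                       // inf

let scale = Float16(1.0 / sqrt(Double(d)))       // 0.125 for d = 64
var scaled: Float16 = 0
for i in 0..<d { scaled += (q[i] * scale) * k[i] }
print(scaled)                                    // 8192, comfortably in range
```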
Hi, when I try to compile the GEMM kernel, I get an error: `Undefined symbol(s) for architecture 'air64': '@air.simdgroup_async_copy_2d.p1i8.p3i8', referenced from: _Z10_gemm_implIfEvPU9MTLdeviceT_S2_S2_PU14MTLthreadgroupS0_Dv3_jtt in program_sourc` I've made sure to install Xcode...
Hello Philip, great project! It is something I have been waiting on for some time now. Can you give me some guidelines on how I can replace the current flash...
In the documentation it says "Batch 8 or more compute commands into the same MTLComputeCommandEncoder." How is this done? I was under the impression that individual kernels were encoded in...
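For what it's worth, a hedged host-side sketch of one reading of that advice (the kernel, buffer sizes, and dispatch shapes below are placeholders, not MFA's actual kernels): several dispatches are recorded into a single `MTLComputeCommandEncoder` before `endEncoding()`, rather than creating one encoder per kernel.

```swift
import Metal

// Trivial placeholder kernel, compiled at runtime for a self-contained demo.
let source = """
#include <metal_stdlib>
using namespace metal;
kernel void fill(device float *out [[buffer(0)]],
                 uint gid [[thread_position_in_grid]]) {
    out[gid] = float(gid);
}
"""

let device = MTLCreateSystemDefaultDevice()!
let library = try! device.makeLibrary(source: source, options: nil)
let pipeline = try! device.makeComputePipelineState(
    function: library.makeFunction(name: "fill")!)
let queue = device.makeCommandQueue()!
let buffers = (0..<8).map { _ in
    device.makeBuffer(length: 1024 * MemoryLayout<Float>.stride,
                      options: .storageModeShared)!
}

let commandBuffer = queue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder()!
encoder.setComputePipelineState(pipeline)
for buffer in buffers {
    // Eight compute commands batched into the same encoder.
    encoder.setBuffer(buffer, offset: 0, index: 0)
    encoder.dispatchThreads(
        MTLSize(width: 1024, height: 1, depth: 1),
        threadsPerThreadgroup: MTLSize(width: 256, height: 1, depth: 1))
}
encoder.endEncoding()    // one endEncoding() for all eight dispatches
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
```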
Any plans on upgrading this repo to v2 of [flash-attention](https://github.com/Dao-AILab/flash-attention)?