metal-flash-attention
Guidelines for modifying H3 with metal-flash-attention
Hello Philip,
Great project! It is something I have been waiting on for some time now.
Can you give me some guidelines on how I can replace the current flash attention mechanism in H3 with metal-flash-attention?
Thanks in advance!
Just to follow up,
from this code sample,
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings
from flash_attn.utils.generation import GenerationMixin
One would need to replace these imports with a metal-flash-attention implementation. Could you please let me know the current status of each of them, and what would be a good starting point for implementing them?
You would have to call into Metal directly using the PyObjC bindings - check out how Tinygrad does this. Or use a combination of Swift and PythonKit.
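For reference, here is a minimal sketch of what the PyObjC route could look like. This is my own illustration, not code from this repo or from Tinygrad: the kernel is a trivial element-wise add standing in for where a real attention kernel would go, and the kernel name, sizes, and buffers are placeholders. It assumes the pyobjc Metal framework bindings and NumPy are installed.
import Metal
import numpy as np

device = Metal.MTLCreateSystemDefaultDevice()
queue = device.newCommandQueue()

# Placeholder kernel: an element-wise add, just to show the dispatch plumbing.
source = """
#include <metal_stdlib>
using namespace metal;
kernel void add(device const float *a [[buffer(0)]],
                device const float *b [[buffer(1)]],
                device float *out [[buffer(2)]],
                uint gid [[thread_position_in_grid]]) {
    out[gid] = a[gid] + b[gid];
}
"""
library, err = device.newLibraryWithSource_options_error_(source, None, None)
assert err is None, err
pipeline, err = device.newComputePipelineStateWithFunction_error_(
    library.newFunctionWithName_("add"), None)
assert err is None, err

n = 1024
a = np.random.rand(n).astype(np.float32)
b = np.random.rand(n).astype(np.float32)
opts = Metal.MTLResourceStorageModeShared
buf_a = device.newBufferWithBytes_length_options_(a.tobytes(), a.nbytes, opts)
buf_b = device.newBufferWithBytes_length_options_(b.tobytes(), b.nbytes, opts)
buf_out = device.newBufferWithLength_options_(a.nbytes, opts)

# Encode one dispatch, submit it, and wait for the result.
cmd = queue.commandBuffer()
enc = cmd.computeCommandEncoder()
enc.setComputePipelineState_(pipeline)
enc.setBuffer_offset_atIndex_(buf_a, 0, 0)
enc.setBuffer_offset_atIndex_(buf_b, 0, 1)
enc.setBuffer_offset_atIndex_(buf_out, 0, 2)
enc.dispatchThreads_threadsPerThreadgroup_(
    Metal.MTLSize(n, 1, 1), Metal.MTLSize(32, 1, 1))
enc.endEncoding()
cmd.commit()
cmd.waitUntilCompleted()

result = np.frombuffer(buf_out.contents().as_buffer(a.nbytes), dtype=np.float32)
Tinygrad's Metal backend follows the same pattern, except the kernel source is generated by its compiler rather than written by hand.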
However, H3 seems to use algorithms besides FlashAttention, such as FlashConv. I don't have Metal kernels for those, but you can try translating CUDA code to Metal. My repo provides a good workflow which gives access to simdgroup_async_copy instructions.
Hmm, I thought that FlashAttention was one mechanism and that FlashConv was an alternative to it. Do you think FlashConv will be faster in Metal compared to CUDA?
I will check Tinygrad.
Do you think FlashConv will be faster in Metal compared to CUDA?
There is no direct way to compare "speed" between Metal and CUDA, as they run on different GPUs with different GFLOPS. The only thing you can really measure is how well a shader utilizes the ALU hardware.
You'll have to look at the algorithm. See whether it has fewer total operations, and whether it parallelizes efficiently (no workload imbalance).
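To make the operation-count side concrete, here is a back-of-the-envelope comparison. This is my own illustration using standard textbook cost models, not measurements from either project: standard attention costs on the order of N²·d multiply-adds per head, while an FFT-based long convolution of the kind FlashConv accelerates costs on the order of N·log N per channel.
import math

def attention_madds(n, d):
    # Q @ K^T and softmax(QK^T) @ V each cost roughly n*n*d multiply-adds.
    return 2 * n * n * d

def fft_conv_flops(n, d):
    # Rough cost of an FFT convolution per channel: two length-2n FFTs plus a
    # pointwise multiply, using the common ~5*m*log2(m) flop estimate per FFT.
    m = 2 * n
    return d * (2 * 5 * m * math.log2(m) + 6 * m)

d = 64  # head dimension / channels (illustrative)
for n in (1024, 8192, 65536):
    print(f"N={n:6d}  attention ~{attention_madds(n, d):.2e}  "
          f"fft conv ~{fft_conv_flops(n, d):.2e}")
The asymptotics favor the convolution at long sequence lengths, but as noted above, actual speed on a given GPU depends on how well each kernel keeps the ALUs busy.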
Thanks for the reply. I will look into it.