Tri Dao
I mean flexattn in this repo (flash_attn.cute): https://github.com/Dao-AILab/flash-attention/blob/main/tests/cute/test_score_mod.py
In general a 4D attn mask isn't the right abstraction for prefix-lm (a 4D mask is too general and you'll pay for it with a slowdown). @drisspg do we have example of...
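For reference, this is the kind of thing that works better than a dense 4D mask, sketched here with PyTorch's flex_attention API (the flash_attn.cute flexattn interface may look a bit different; the prefix length is just illustrative):
```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

prefix_len = 256  # illustrative prefix length

def prefix_lm_mask(b, h, q_idx, kv_idx):
    # Bidirectional attention inside the prefix, causal everywhere else.
    return (kv_idx < prefix_len) | (kv_idx <= q_idx)

B, H, S, D = 2, 8, 1024, 64
q, k, v = [torch.randn(B, H, S, D, dtype=torch.bfloat16, device="cuda") for _ in range(3)]

# The mask is a predicate, so the kernel can skip fully-masked blocks instead of
# materializing and reading a dense (B, H, S, S) mask tensor.
block_mask = create_block_mask(prefix_lm_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, block_mask=block_mask)
```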
We choose to have Q @ K^T in fp32 for better numerical stability.
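A rough sketch of what that looks like in a reference-style implementation (names are illustrative, not the flash_attn.cute internals):
```python
import torch

def scores_fp32(q, k, softmax_scale=None):
    # Illustrative only: upcast to fp32 so Q @ K^T (and the softmax after it)
    # run at full precision even for bf16/fp16 inputs.
    # Layout assumed: (batch, seqlen, nheads, headdim).
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    s = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    return s  # fp32 scores; score_mod / masking would be applied here
```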
you can try that out
FA uses (batch, seqlen, nheads, headdim). Torch sdpa expects (batch, nheads, seqlen, headdim).
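Converting between the two layouts is just a transpose (shapes below are made up for illustration):
```python
import torch

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q_fa = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")

# FA layout (batch, seqlen, nheads, headdim) -> sdpa layout (batch, nheads, seqlen, headdim)
q_sdpa = q_fa.transpose(1, 2)
# and back
q_back = q_sdpa.transpose(1, 2)
```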
sdpa is probably just running FA2 :D
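One way to check (or force) which backend sdpa picks is PyTorch's sdpa_kernel context manager; just a sketch:
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, S, D = 2, 8, 1024, 64
q, k, v = [torch.randn(B, H, S, D, dtype=torch.bfloat16, device="cuda") for _ in range(3)]

# Restrict sdpa to the flash backend; this raises if flash can't be used here.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```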
As always, you want to check against a reference implementation: (flashattention in bf16 - reference impl in fp32) vs (reference impl in bf16 - reference impl in fp32).
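Something like this (a sketch of the check, not the repo's actual test; flash_attn_func and the shapes are just how I'd set it up):
```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # assuming the flash-attn package is installed

batch, seqlen, nheads, headdim = 2, 512, 8, 64
q, k, v = [torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
           for _ in range(3)]

def ref_attn(q, k, v, dtype):
    # Reference via sdpa in the requested dtype.
    # Convert FA layout (b, s, h, d) -> sdpa layout (b, h, s, d) and back.
    q, k, v = [t.to(dtype).transpose(1, 2) for t in (q, k, v)]
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)

out_flash = flash_attn_func(q, k, v)                      # implementation under test
out_ref_fp32 = ref_attn(q, k, v, torch.float32)
out_ref_bf16 = ref_attn(q, k, v, torch.bfloat16).float()

err_impl = (out_flash.float() - out_ref_fp32).abs().max()
err_ref = (out_ref_bf16 - out_ref_fp32).abs().max()
# The implementation's error should be on the same order as the reference's
# own bf16 rounding error.
print(err_impl.item(), err_ref.item())
```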
There's no guarantee of bitwise identical results for two different implementations since floating point math is not associative:
```
In [1]: import torch

In [2]: a = torch.randn(10, dtype=torch.bfloat16, device='cuda')
...
```
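A tiny self-contained illustration of the same point (numbers chosen just to make the rounding visible):
```python
import torch

x = torch.tensor(1e4, dtype=torch.bfloat16)
y = torch.tensor(-1e4, dtype=torch.bfloat16)
z = torch.tensor(1.0, dtype=torch.bfloat16)

# Same mathematical sum, different grouping, different bf16 result:
print((x + y) + z)  # tensor(1., dtype=torch.bfloat16)
print(x + (y + z))  # tensor(0., dtype=torch.bfloat16) -- the 1.0 is lost to rounding
```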
We have new wheels for torch 2.8 now
Yep, I'd love to understand why compile takes so much memory. We do use a lot of templating, but I don't quite get how that translates to a very large...