Tri Dao


I mean flexattn in this repo (flash_attn.cute): https://github.com/Dao-AILab/flash-attention/blob/main/tests/cute/test_score_mod.py

In general, a 4D attention mask isn't the right abstraction for prefix-LM (a 4D mask is too general and you'll pay for it with a slowdown). @drisspg do we have an example of...
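For illustration, here's a minimal sketch of a prefix-LM mask expressed as a predicate rather than a materialized 4D tensor, using PyTorch's `flex_attention` / `create_block_mask` API (the flash_attn.cute flexattn in the test linked above follows the same score_mod/mask_mod idea; `prefix_len` and the shapes here are made up for the example):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 2, 8, 1024, 64
prefix_len = 128  # hypothetical: the first 128 tokens form the bidirectional prefix

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

def prefix_lm_mask(b, h, q_idx, kv_idx):
    # Causal attention everywhere, plus full (bidirectional) attention inside the prefix.
    return (kv_idx <= q_idx) | (kv_idx < prefix_len)

# The mask is a predicate over (q_idx, kv_idx), so the kernel can skip fully masked
# blocks instead of reading a dense (B, H, S, S) mask tensor from memory.
block_mask = create_block_mask(prefix_lm_mask, None, None, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```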

We choose to have Q @ K^T in fp32 for better numerical stability.
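A minimal sketch of what that looks like in a reference-style implementation (just an illustration, not the kernel's code; the helper name is made up):

```python
import torch

def attn_ref(q, k, v, softmax_scale=None):
    # q, k, v: (batch, nheads, seqlen, headdim), e.g. in bf16/fp16.
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    # Upcast so Q @ K^T and the softmax run in fp32, mirroring fp32 accumulation of the scores.
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * softmax_scale
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v.float()).to(q.dtype)
```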

FA uses (batch, seqlen, nheads, headdim). Torch sdpa expects (batch, nheads, seqlen, headdim).
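Converting between the two is just a transpose of dims 1 and 2; a small sketch:

```python
import torch

batch, seqlen, nheads, headdim = 2, 512, 8, 64

# FlashAttention layout: (batch, seqlen, nheads, headdim)
q_fa = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)

# torch.nn.functional.scaled_dot_product_attention layout: (batch, nheads, seqlen, headdim)
q_sdpa = q_fa.transpose(1, 2)      # a view, no copy
q_back = q_sdpa.transpose(1, 2)    # back to the FA layout
assert q_back.shape == (batch, seqlen, nheads, headdim)
```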

As always, you want to check against a reference implementation: compare (FlashAttention in bf16 - reference impl in fp32) vs (reference impl in bf16 - reference impl in fp32).
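A minimal sketch of that check, assuming the `flash_attn_func` interface and a toy reference implementation (the shapes and printout are just illustrative):

```python
import torch
from flash_attn import flash_attn_func  # expects (batch, seqlen, nheads, headdim)

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

def ref_attn(q, k, v, upcast):
    # Simple reference in the (batch, seqlen, nheads, headdim) layout.
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    scores = torch.einsum("bshd,bthd->bhst", q, k) * q.shape[-1] ** -0.5
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", attn, v)

out_fa = flash_attn_func(q, k, v)
out_ref_fp32 = ref_attn(q, k, v, upcast=True)
out_ref_bf16 = ref_attn(q, k, v, upcast=False).float()

# FlashAttention's error vs the fp32 reference should be of the same order as the
# reference's own bf16 rounding error.
err_fa = (out_fa.float() - out_ref_fp32).abs().max().item()
err_ref = (out_ref_bf16 - out_ref_fp32).abs().max().item()
print(f"max err (FlashAttention bf16 vs ref fp32): {err_fa:.2e}")
print(f"max err (ref bf16 vs ref fp32):            {err_ref:.2e}")
```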

There's no guarantee of bitwise-identical results from two different implementations, since floating-point math is not associative:

```
In [1]: import torch

In [2]: a = torch.randn(10, dtype=torch.bfloat16, device='cuda')
...
```
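The snippet above is truncated; here's a minimal, self-contained sketch of the same point (the specific values are illustrative, not from the original comment):

```python
import torch

torch.manual_seed(0)
a = torch.randn(10, dtype=torch.bfloat16, device='cuda')

# Mathematically the same sum, computed in two different orders.
s_all = a.sum()
s_split = a[:5].sum() + a[5:].sum()

# In exact arithmetic these are equal; in bf16 the rounding can differ,
# so the two results need not be bitwise identical.
print(s_all.item(), s_split.item(), (s_all - s_split).item())
```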

Yep, I'd love to understand why compile takes so much memory. We do use a lot of templating, but I don't quite get how that translates to a very large...