Tri Dao

Results: 429 comments of Tri Dao

That's an interesting perspective. Both approaches have tradeoffs (more recompute vs more memory access). I did think about this approach (7 GEMMs) when we rewrote for Hopper. The 5 GEMMs...
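
For context, a sketch of what the five GEMMs likely are, following the FlashAttention-2 backward pass (dX denotes the gradient of X; the labels are my reading of the standard algorithm, not from the comment itself):

```latex
% The five GEMMs per block in the FlashAttention-2 backward pass,
% with P = softmax(S) rematerialized rather than stored:
\begin{align*}
S  &= Q K^\top      && \text{recompute attention scores} \\
dP &= dO \, V^\top  && \text{gradient w.r.t. } P \\
dV &= P^\top dO     && \\
dQ &= dS \, K       && dS \text{ obtained from } dP \text{ and } P \\
dK &= dS^\top Q     &&
\end{align*}
```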

The numbers for D=256 aren't there yet (as you said, register pressure is a big headache). Most models use hdim 128.

Btw I'm not sure doing dQ in a separate kernel will save registers. In the kernel with 4 GEMMs, we'd still need registers to hold the accumulator for all the...

Head dimensions are hard-coded, just like GEMM tile sizes are hard-coded (usually 128 x 128, 128 x 256, or 256 x 128).
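
A minimal sketch of what compile-time head-dimension dispatch can look like (hypothetical names; the actual flash-attention dispatch macros differ):

```cpp
#include <cstdio>
#include <stdexcept>

// Hypothetical kernel launcher templated on the head dimension, so tile
// shapes, register budget, and shared-memory layout are all fixed at
// compile time -- analogous to hard-coded GEMM tile sizes.
template <int kHeadDim>
void run_attention_kernel() {
    std::printf("running kernel specialized for head dim %d\n", kHeadDim);
}

// Runtime head dims map onto the small set of compiled instantiations;
// anything else would be padded up to the next supported size or rejected.
void dispatch_by_headdim(int headdim) {
    switch (headdim) {
        case 64:  run_attention_kernel<64>();  break;
        case 128: run_attention_kernel<128>(); break;
        case 256: run_attention_kernel<256>(); break;
        default:  throw std::invalid_argument("head dim not compiled in");
    }
}

int main() { dispatch_by_headdim(128); }
```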

For Hopper at least, some of the block sizes need to be divisible by 64 (since the wgmma instruction needs M=64). So you can't really decrease M that much (we...
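
A toy illustration of that constraint (assumed names, not flash-attention's actual config structs): since wgmma computes a 64-row tile, the block's M extent must be a whole number of wgmma tiles, which a compile-time check can enforce.

```cpp
template <int kBlockM, int kBlockN>
struct TileConfig {
    // Hopper's wgmma instruction operates on M=64 tiles, so kBlockM must
    // be a multiple of 64; this rules out invalid configs at compile time.
    static_assert(kBlockM % 64 == 0, "kBlockM must be a multiple of 64 for wgmma");
    static constexpr int kNumWgmmaTilesM = kBlockM / 64;
};

int main() {
    TileConfig<128, 128> ok;     // fine: 128 = 2 * 64
    // TileConfig<96, 128> bad;  // would fail the static_assert above
    (void)ok;
}
```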

It would run and get correct results. For hdim not divisible by 8, we explicitly pad (which would make it slower). This is mostly so that we can use TMA...

TMA requires addresses to be multiples of 16 bytes, so if you have D=127 TMA won't work. This is standard in GEMM as well, where e.g. Nvidia recommends that dimensions...
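
A small worked example of the arithmetic behind both comments, assuming fp16 elements (2 bytes each): 8 elements is exactly 16 bytes, so padding the head dim up to a multiple of 8 makes every row TMA-aligned, and D=127 gets padded to 128.

```cpp
#include <cstdio>

int main() {
    const int elem_bytes = 2;   // fp16
    const int tma_align  = 16;  // TMA requires 16-byte-aligned addresses

    const int hdims[] = {127, 128};
    for (int hdim : hdims) {
        int  row_bytes = hdim * elem_bytes;
        bool aligned   = (row_bytes % tma_align == 0);
        // Round up to the next multiple of 8 elements (= 16 bytes in fp16).
        int padded = (hdim + 7) / 8 * 8;
        std::printf("D=%d: row = %d bytes, TMA-aligned = %s, padded D = %d\n",
                    hdim, row_bytes, aligned ? "yes" : "no", padded);
    }
}
```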

No, FA2 has bwd for D=256. https://github.com/Dao-AILab/flash-attention/blob/32792d37ec66902e5d82e149971daacbee8b55d7/csrc/flash_attn/src/flash_bwd_launch_template.h#L301

Yes, it runs on both A100 and H100 (not using new H100 features). Util on A100 is around the same as hdim64. Util on H100 is not great (obv), like...

Oh, these are just numbers off the top of my head. Would need to find an A100 :D I mostly work with H100.