Tri Dao
The model parameters are still defined in PyTorch, so it's just `sum(p.numel() for p in model.parameters())`. For FLOPs you can calculate by hand, or search the issues on this repo.
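As a rough sketch of the by-hand estimate (my own illustration, using the common convention that an `(m, k) x (k, n)` matmul costs `2*m*k*n` FLOPs; the attention-only formula below is an assumption, not something taken from this repo):

```python
import torch

def count_params(model: torch.nn.Module) -> int:
    # Total parameter count, as in the reply above.
    return sum(p.numel() for p in model.parameters())

def attn_fwd_flops(batch, nheads, seqlen, headdim):
    # Two matmuls per head (Q @ K^T and P @ V), each 2 * seqlen^2 * headdim FLOPs.
    return 4 * batch * nheads * seqlen * seqlen * headdim

print(attn_fwd_flops(batch=8, nheads=16, seqlen=2048, headdim=64))
```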
Right, LDSM won't work for V if the data is 8-bit. We might have some way to address this soon.
You can use LDSM.T and byte-permute, then LDSM, as a way to transpose V. We'll release that code soon; I don't know if it works well without warp specialization.
Sorry, I mean LDSM.T, byte-permute, then store using STSM. That way you can transpose V.
I see, I forgot that STSM is Hopper-only. The other option is to transpose V in a separate kernel, or fuse it with a preceding kernel (e.g. a GEMM).
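To make the byte-level idea concrete, here is a minimal NumPy sketch (my own illustration, not the kernel code): a transpose done at 16-bit granularity, which is what an LDSM.T-style load provides, followed by a byte-deinterleave standing in for the byte-permute, recovers the true transpose of 8-bit data.

```python
import numpy as np

x = np.arange(64, dtype=np.uint8).reshape(8, 8)   # 8-bit matrix to transpose

# Step 1: transpose at 16-bit granularity: view adjacent byte pairs as
# uint16 elements and transpose that matrix.
t16 = x.view(np.uint16).T.copy()                  # (8, 4) uint16 -> (4, 8)

# Step 2: byte-permute-style fixup: each 16-bit-transposed row holds two
# byte-rows of the true transpose interleaved; deinterleave them.
pairs = t16.view(np.uint8).reshape(4, 8, 2)       # [row-pair, col, byte-in-pair]
xt = pairs.transpose(0, 2, 1).reshape(8, 8)

assert np.array_equal(xt, x.T)                    # true 8-bit transpose
```

In the register-level version, the store (STSM on Hopper) would then write the permuted fragments back to shared memory; the separate or fused transpose kernel mentioned above is the fallback when that instruction isn't available.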
The figure is not drawn to scale; it's just an illustration. The way we do it, softmax only has 1 MUFU (the exponential). There's no floating-point division. Division is done...
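As an illustration of what "one exponential per element, no per-element division" can look like, here is a minimal NumPy sketch of online softmax (my own sketch, not the kernel: it assumes normalization is deferred to a single reciprocal-multiply per row at the end).

```python
import numpy as np

def online_softmax_rows(s, block=4):
    """Row-wise softmax computed in column blocks, with one exponential per
    element of s and no per-element division: each row is normalized once at
    the end by multiplying with the reciprocal of its running sum."""
    m = np.full(s.shape[0], -np.inf)          # running row max
    l = np.zeros(s.shape[0])                  # running sum of exp(s - m)
    out = np.zeros_like(s, dtype=np.float64)
    for j0 in range(0, s.shape[1], block):
        blk = s[:, j0:j0 + block]
        m_new = np.maximum(m, blk.max(axis=1))
        scale = np.exp(m - m_new)             # rescale earlier accumulators
        p = np.exp(blk - m_new[:, None])      # the one exponential per element
        l = l * scale + p.sum(axis=1)
        out[:, :j0] *= scale[:, None]
        out[:, j0:j0 + block] = p
        m = m_new
    return out * (1.0 / l)[:, None]           # one reciprocal per row

x = np.random.randn(2, 16)
ref = np.exp(x - x.max(axis=1, keepdims=True))
ref /= ref.sum(axis=1, keepdims=True)
assert np.allclose(online_softmax_rows(x), ref)
```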
For that kind of profiling you'd need to record the global clock, store it to global memory, then visualize it later. It's quite manual. Triton has a profiler (Proton) that does...
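For the "visualize it later" step, a minimal sketch in Python, assuming the kernel has already written one `(sm_id, start, end)` global-clock record per thread block to a buffer and it has been copied to the host (the record layout and values below are made up for illustration):

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical dump: one (sm_id, start_clock, end_clock) record per thread block.
records = np.array([
    [0, 100, 450],
    [0, 460, 800],
    [1, 120, 500],
    [2, 130, 470],
])

fig, ax = plt.subplots()
for sm_id, start, end in records:
    # One horizontal bar per thread block, placed on its SM's row.
    ax.broken_barh([(start, end - start)], (sm_id - 0.4, 0.8))
ax.set_xlabel("global clock (ticks)")
ax.set_ylabel("SM id")
ax.set_yticks(np.unique(records[:, 0]))
plt.savefig("block_timeline.png")
```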
Yes. The original FlashAttention implementation (May 2022) didn't have any seqlen parallelism. Later on (in the v1 code) we added a kind of parallelism in the forward pass where we decide...
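To illustrate why seqlen parallelism matters (a back-of-the-envelope sketch of my own, not the repo's actual launch heuristic): with parallelism only over (batch, head), the number of thread blocks can be too small to fill the GPU when `batch * nheads` is small.

```python
import math

def num_thread_blocks(batch, nheads, seqlen, block_m, seqlen_parallel):
    # Thread blocks launched: one per (batch, head), times the number of
    # seqlen blocks if we also parallelize over the sequence dimension.
    blocks = batch * nheads
    if seqlen_parallel:
        blocks *= math.ceil(seqlen / block_m)
    return blocks

# e.g. batch=1, 16 heads, seqlen 8192, 128-row blocks, GPU with ~100+ SMs:
print(num_thread_blocks(1, 16, 8192, 128, seqlen_parallel=False))  # 16 blocks
print(num_thread_blocks(1, 16, 8192, 128, seqlen_parallel=True))   # 1024 blocks
```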