[QST] How does flash-attn compute dropout?
https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/csrc/flash_attn/src/flash_fwd_kernel.h#L345 Hi @tridao, I don't understand what the variables block_row_idx and block_col_idx really mean when FA computes dropout. Why is block_col_idx = n_block * (kBlockN / 32) and block_row_idx = m_block * (kBlockM / 16) + tidx / 32? Why is it not the element's position within the current block? Does the dropout result depend on kBlockM and kBlockN? I'm very confused about how FA computes dropout. Could you kindly explain it? Thanks.
https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/csrc/flash_attn/src/flash_fwd_kernel.h#L1080
Hi @tridao
Why 16x32? Can we choose the block size?
Does the first 32 refer to the 16x32 block, and the second 32 to the warp (wave) size?
One call to the Philox RNG gives 128 random bits: https://github.com/Dao-AILab/flash-attention/blob/a93359a2bfdedfcd054622e6f595f99d7a23c17e/csrc/flash_attn/src/philox.cuh#L31

We use 8 random bits to generate one dropout mask: https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/csrc/flash_attn/src/dropout.h#L46

So each thread can apply dropout to 128 / 8 = 16 elements per Philox call. With 32 threads in a warp, one call per thread covers 32 x 16 = 512 elements, i.e., a block of 16 x 32.
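Reading this together with the indexing in the question: block_row_idx and block_col_idx appear to count these fixed 16-row and 32-column RNG tiles across the whole attention matrix (m_block * (kBlockM / 16) converts the block's starting row into 16-row units, tidx / 32 picks the warp's tile within it, and n_block * (kBlockN / 32) converts the starting column into 32-column units), which would make an element's mask a function of its absolute position rather than of the kBlockM/kBlockN tiling.

Below is a minimal, self-contained sketch of the per-tile mechanics, not the flash-attention kernel itself. It uses cuRAND's Philox4x32-10 in place of the repo's inlined `philox()`, assumes a simplified layout where lane t of a warp owns column t of a 16 x 32 tile (the real kernel follows the MMA fragment layout), and omits the 1/p_keep rescaling of kept values; `dropout_tile_sketch` and `p_keep_in_uint8` are hypothetical names.

```cuda
// Illustrative sketch only -- NOT the flash-attention implementation.
// Shows the arithmetic from the answer above: one Philox call = 128 bits,
// 8 bits per dropout decision -> 16 decisions per thread, x 32 lanes in a
// warp = one 16 x 32 tile of the attention matrix.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <curand_kernel.h>

__global__ void dropout_tile_sketch(float *tile,             // 16 x 32, row-major
                                    unsigned long long seed,
                                    unsigned long long offset,
                                    uint8_t p_keep_in_uint8)  // floor(p_keep * 255)
{
    int lane = threadIdx.x & 31;  // simplified layout: lane t owns column t

    // One Philox state per lane; distinct subsequences decorrelate lanes.
    curandStatePhilox4_32_10_t state;
    curand_init(seed, /*subsequence=*/lane, offset, &state);

    // One Philox call -> 128 random bits.
    uint4 bits = curand4(&state);

    // Split into 16 bytes: 8 random bits back each keep/drop decision.
    uint8_t rnd[16];
    memcpy(rnd, &bits, sizeof(rnd));

    // 16 decisions per thread x 32 lanes = 512 elements = one 16 x 32 tile.
    // (The real kernel also rescales kept values by 1/p_keep; omitted here.)
    for (int row = 0; row < 16; ++row) {
        bool keep = rnd[row] <= p_keep_in_uint8;
        tile[row * 32 + lane] = keep ? tile[row * 32 + lane] : 0.f;
    }
}

int main() {
    float *tile;
    cudaMallocManaged(&tile, 16 * 32 * sizeof(float));
    for (int i = 0; i < 16 * 32; ++i) tile[i] = 1.f;

    // Keep probability 0.9 -> byte threshold floor(0.9 * 255) = 229.
    dropout_tile_sketch<<<1, 32>>>(tile, /*seed=*/42ULL, /*offset=*/0ULL, 229);
    cudaDeviceSynchronize();

    int kept = 0;
    for (int i = 0; i < 16 * 32; ++i) kept += (tile[i] != 0.f);
    printf("kept %d of 512 elements (expect ~460 at p_keep = 0.9)\n", kept);

    cudaFree(tile);
    return 0;
}
```

Because each 128-bit draw is keyed by the tile's absolute position, the same element should receive the same mask regardless of how the matrix is cut into kBlockM x kBlockN blocks, which is consistent with the fixed 16 x 32 granularity described above.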