
[QST] How does flash-attn calculate dropout?

Open zhang22222 opened this issue 1 year ago • 1 comment

https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/csrc/flash_attn/src/flash_fwd_kernel.h#L345 Hi @tridao, I don't understand the real meaning of the variables block_row_idx and block_col_idx when FA computes dropout. Why is block_col_idx = n_block * (kBlockN / 32) and block_row_idx = m_block * (kBlockM / 16) + tidx / 32? Why isn't it the element position within the current block? Does the dropout result depend on kBlockM and kBlockN? I'm very confused about how FA computes dropout. Could you kindly explain it? Thanks.

zhang22222 avatar Jul 30 '24 08:07 zhang22222

https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/csrc/flash_attn/src/flash_fwd_kernel.h#L1080

tridao avatar Jul 30 '24 16:07 tridao

Hi @tridao, why 16x32? Can we choose the block size? (screenshot omitted) Does the first 32 refer to the 16x32 tile, and the second 32 to the warp size?

zhang22222 avatar Dec 25 '24 10:12 zhang22222

One call to the Philox RNG gives 128 random bits: https://github.com/Dao-AILab/flash-attention/blob/a93359a2bfdedfcd054622e6f595f99d7a23c17e/csrc/flash_attn/src/philox.cuh#L31 We use 8 random bits to generate one dropout mask: https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/csrc/flash_attn/src/dropout.h#L46 So each thread can do dropout on 128 / 8 = 16 elements per Philox call. There are 32 threads in a warp -> 32 x 16 = 512 elements, so that's a block of 16 x 32.

tridao avatar Jan 10 '25 15:01 tridao
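For anyone following along, here is a minimal host-side C++ sketch of the counting above and of the tile indices quoted in the question. It only illustrates the arithmetic (128 Philox bits, 8 bits per mask, 32 threads per warp, and the block_row_idx / block_col_idx formulas), not the actual kernel; the kBlockM/kBlockN values, block coordinates, and thread index used below are hypothetical placeholders.

```cpp
// Illustrative sketch, not flash-attention source code.
#include <cstdio>

int main() {
    // Counting from the comment above.
    const int philox_bits_per_call = 128;  // one Philox call -> 128 random bits
    const int bits_per_element     = 8;    // 8 random bits decide one dropout mask
    const int elems_per_thread     = philox_bits_per_call / bits_per_element;  // 16
    const int threads_per_warp     = 32;
    const int elems_per_warp_call  = elems_per_thread * threads_per_warp;      // 512

    printf("elements per thread per Philox call: %d\n", elems_per_thread);
    printf("elements per warp per Philox call:   %d (a 16 x 32 tile)\n", elems_per_warp_call);

    // Tile indices as quoted in the question: rows are counted in units of 16
    // and columns in units of 32 within the kBlockM x kBlockN attention block,
    // i.e. each warp-level Philox call covers one 16 x 32 dropout tile.
    const int kBlockM = 128, kBlockN = 64;  // hypothetical block sizes
    const int m_block = 2, n_block = 1;     // hypothetical block coordinates
    const int tidx    = 96;                 // hypothetical thread index (warp 3)
    const int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
    const int block_col_idx = n_block * (kBlockN / 32);
    printf("block_row_idx = %d, block_col_idx = %d\n", block_row_idx, block_col_idx);
    return 0;
}
```

Under this reading, the indices name 16x32 tiles rather than individual elements, which would be why they are scaled by kBlockM / 16 and kBlockN / 32 instead of giving an element position within the current block.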