
FA-V3 deterministic usage

defei-coder opened this issue 7 months ago • 1 comment

Hi @tridao, I'm glad you used the semaphore solution to solve the problem of deterministic backward computation (which is used for dq, and for dk and dv when using GQA). I found that the code has already been written (commit), but it is not enabled right now. In the code:

void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
//             BOOL_SWITCH(params.deterministic, Deterministic, [&] {
            // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
            run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
//             });
        });
    });
}

From the kernel code, the deterministic path looks ready, so I am a little confused about why this feature is not enabled. I can provide some help if needed.
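
If it helps, re-enabling it would presumably just mean restoring the commented-out switch so that `params.deterministic` is dispatched on instead of being hardcoded to `false`, roughly like the sketch below (assuming the `Deterministic` template parameter is still wired through to the kernel):

void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
    VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
            // Dispatch on the runtime flag instead of hardcoding false.
            BOOL_SWITCH(params.deterministic, Deterministic, [&] {
                run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, Deterministic, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
            });
        });
    });
}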

defei-coder • Apr 16, 2025

Oh, it just hasn't been tested very well. The dq semaphore should work except for hdim256. I'm not sure the dk & dv semaphores (when using GQA) have worked yet.

tridao • Apr 16, 2025
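
For background on why a semaphore is needed at all (illustrative only, not FA-3's actual kernel code): in the backward pass several thread blocks accumulate partial gradients into the same dQ (and, with GQA, into the same dK/dV). Plain atomicAdd lets those floating-point additions happen in whatever order the scheduler picks, so results can differ bitwise from run to run. A per-tile counter ("semaphore") instead forces the partial sums to be applied in a fixed block order, which makes the accumulation reproducible. The names below (dq_semaphore, n_block, chunk_elems) are hypothetical:

// Illustrative sketch of semaphore-ordered accumulation; not FA-3 code.
// Each (m, n) block adds its partial dQ tile into global dQ, but only after
// every block with a smaller n index for the same tile has finished, so the
// floating-point summation order is fixed and the result is deterministic.
__global__ void accumulate_dq_in_order(float *dq, const float *dq_partial,
                                        int *dq_semaphore, int chunk_elems) {
    int m_block = blockIdx.x;   // which dQ tile this block owns
    int n_block = blockIdx.y;   // which K/V block produced this partial dQ
    const float *src = dq_partial
        + ((size_t)n_block * gridDim.x + m_block) * chunk_elems;
    float *dst = dq + (size_t)m_block * chunk_elems;

    // Spin until the counter for this tile reaches our n_block, i.e. all
    // earlier contributions have already been added.
    if (threadIdx.x == 0) {
        while (atomicAdd(&dq_semaphore[m_block], 0) != n_block) { }
    }
    __syncthreads();

    for (int i = threadIdx.x; i < chunk_elems; i += blockDim.x) {
        dst[i] += src[i];
    }

    // Make our writes visible, then hand the tile to the next n_block.
    __syncthreads();
    __threadfence();
    if (threadIdx.x == 0) {
        atomicExch(&dq_semaphore[m_block], n_block + 1);
    }
}
// Note: a real kernel must ensure blocks with smaller n_block get scheduled
// before the ones spinning on them, otherwise this pattern can deadlock.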