Uni-Fold
Flash-Attention
Hi, if you want to use the optimized flash attention code, you can check out the code here. This document may also be helpful. Hope this helps.
I run into NaNs if I enable flash attention.
unicore.nan_detector | NaN detected in output of model.evoformer.blocks.47.tri_att_end.mha.linear_v, shape: torch.Size([1, 292, 292, 128]), backward
WARNING | unicore.nan_detector | NaN detected in output of model.evoformer.blocks.21.msa_att_row.mha.linear_v, shape: torch.Size([1, 256, 184, 256]), backward
I also get lots of new warnings: UserWarning: Using non-full backward hooks on a Module that does not return a single Tensor or a tuple of Tensors is deprecated and will be removed in future versions. This hook will be missing some of the grad_output. Please use register_full_backward_hook to get the documented behavior.
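As an aside, that deprecation warning refers to PyTorch's newer hook API. Below is a minimal sketch of a NaN-detecting backward hook registered via register_full_backward_hook; this is an illustration only, not unicore's actual nan_detector, and the helper names are made up:

```python
import torch
import torch.nn as nn

def make_nan_hook(name):
    # grad_output is a tuple of gradients flowing back into this module's outputs
    def hook(module, grad_input, grad_output):
        for g in grad_output:
            if g is not None and torch.isnan(g).any():
                print(f"NaN detected in backward of {name}, shape: {tuple(g.shape)}")
    return hook

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        # register_full_backward_hook replaces the deprecated register_backward_hook
        module.register_full_backward_hook(make_nan_hook(name))

out = model(torch.randn(4, 128, requires_grad=True))
out.sum().backward()
```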
Is it working for you @Xreki?
A100 with bfloat16 enabled
Can you provide some details about your installation of flash attention? It seems that the backward pass did not work correctly.
@lhatsk It seems OK for me. I used the docker image dptechnology/unicore:latest-pytorch1.12.1-cuda11.6-flashattn and tested the monomer model with the demo data on one A100 GPU, using bfloat16, and got no NaNs.
I installed flash attention from source according to the README, with torch 1.12.1 + CUDA 11.2. I tested it with multimer fine-tuning on 4 GPUs distributed over two nodes. The NaNs don't appear right away. Interestingly, I also get NaNs with OpenFold when I enable flash attention (different data, different cluster, different software setup, monomer), but there it happens in the pTM computation.
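When comparing the two setups, a small environment-check snippet like the following can help rule out version mismatches (standard torch APIs only; the flash_attn version attribute may not exist in older releases):

```python
import torch

print("torch:", torch.__version__)
print("CUDA (torch build):", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))
print("bf16 supported:", torch.cuda.is_bf16_supported())

try:
    import flash_attn
    print("flash_attn:", getattr(flash_attn, "__version__", "unknown"))
except ImportError:
    print("flash_attn not importable")
```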
Can you write a single test for the flash_attn interface with an input of shape like [1, 292, 292, 128], so that we can check whether the function works properly?
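A standalone check of that kind could look roughly like the sketch below. It assumes the flash_attn package's flash_attn_func, which takes q/k/v of shape (batch, seqlen, nheads, head_dim) and is only available in newer releases; Uni-Fold's internal _flash_attn wrapper and older flash-attn versions expose different interfaces, so adjust accordingly. Splitting the 128 channels from the log into 8 heads of 16 dims is also just an assumption for illustration:

```python
import torch
from flash_attn import flash_attn_func  # newer releases; older ones only ship the "unpadded" variants

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 1, 292, 8, 16  # 8 * 16 = 128 channels, mirroring the logged shape

q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# forward + backward through the fused kernel
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
out.sum().backward()

for name, t in [("out", out), ("dq", q.grad), ("dk", k.grad), ("dv", v.grad)]:
    assert not torch.isnan(t).any(), f"NaN in {name}"
print("forward/backward OK, no NaNs")

# compare against a plain PyTorch reference computed in float32
scale = headdim ** -0.5
qf = q.detach().float().transpose(1, 2)  # (batch, nheads, seqlen, headdim)
kf = k.detach().float().transpose(1, 2)
vf = v.detach().float().transpose(1, 2)
attn = torch.softmax(qf @ kf.transpose(-2, -1) * scale, dim=-1)
ref = (attn @ vf).transpose(1, 2)
print("max abs diff vs. fp32 reference:", (out.detach().float() - ref).abs().max().item())
```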
Just running _flash_attn(q, k, v) works without NaNs. I have now also tested it with the pre-compiled package and Uni-Fold monomer, and I still get NaNs. It seems to happen after two or three samples.
You can now use this branch to try flash attention: https://github.com/dptech-corp/Uni-Fold/tree/flash-attn