
xformers error when fine-tuning open_llama_3B with memory_efficient_attention

Open EliverQ opened this issue 1 year ago • 5 comments

Hi, I'm confused by this error when using memory_efficient_attention. It seems that the embedding dimension per head you chose isn't compatible with xformers?

NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(4, 512, 32, 100) (torch.bfloat16)
     key         : shape=(4, 512, 32, 100) (torch.bfloat16)
     value       : shape=(4, 512, 32, 100) (torch.bfloat16)
     attn_bias   : <class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>
     p           : 0.1
`flshattF` is not supported because:
    query.shape[-1] % 8 != 0
`tritonflashattF` is not supported because:
    dropout > 0.0
    query.shape[-1] % 8 != 0
    key.shape[-1] % 8 != 0
    value.shape[-1] % 8 != 0
`cutlassF` is not supported because:
    query.shape[-1] % 8 != 0
    value.shape[-1] % 8 != 0
`smallkF` is not supported because:
    dtype=torch.bfloat16 (supported: {torch.float32})
    max(query.shape[-1] != value.shape[-1]) > 32
    attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMask'>
    unsupported embed per head: 100
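
For context, a standalone call with the same shapes hits the same error (a minimal sketch with random tensors, not my actual training code; it assumes a CUDA device and the `xformers.ops` API):

```python
import torch
import xformers.ops as xops

# Same shapes as the traceback: batch 4, seq 512, 32 heads,
# head_dim 100, bf16, causal mask, dropout 0.1.
q = torch.randn(4, 512, 32, 100, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# head_dim 100 is not a multiple of 8, so every fused backend is rejected
# and xformers raises NotImplementedError.
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),
    p=0.1,
)
```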

I'd appreciate it if you could help me.

EliverQ avatar Aug 13 '23 15:08 EliverQ

By the way, I thought the problem might be the dtype I'm using (bf16), but the dtype in your config is fp16 and it still doesn't work.

EliverQ avatar Aug 13 '23 15:08 EliverQ

For the 3B model, since there's no official LLaMA 3B, we defined the model size ourselves, and it might not agree with the 3B model sizes in other implementations.
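
Concretely, a quick check (assuming the hidden_size of 3200 and 32 attention heads from the released open_llama_3b config; those values are assumptions here, though they match the shapes in your traceback) shows why the fused xformers kernels reject it:

```python
# Assumed open_llama_3b config values (consistent with the 32 heads and
# head dim of 100 shown in the traceback above).
hidden_size = 3200
num_attention_heads = 32

head_dim = hidden_size // num_attention_heads
print(head_dim)      # 100
print(head_dim % 8)  # 4 -> violates the `query.shape[-1] % 8 == 0` requirement
                     #      of flshattF / tritonflashattF / cutlassF
```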

young-geng avatar Aug 14 '23 09:08 young-geng

> For the 3B model, since there's no official LLaMA 3B, we defined the model size ourselves, and it might not agree with the 3B model sizes in other implementations.

But I'm just using the HF code and checkpoint you released, without modifying anything.

EliverQ avatar Aug 14 '23 09:08 EliverQ

Hmm, then that might be a bug on the HF side. We've tested it in HF transformers without memory_efficient_attention, and it works as expected.
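
For reference, a minimal sketch of a plain HF transformers load that avoids the xformers path entirely (the checkpoint id `openlm-research/open_llama_3b` and the generation settings here are just illustrative):

```python
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

model_path = "openlm-research/open_llama_3b"  # released checkpoint on the HF Hub

tokenizer = LlamaTokenizer.from_pretrained(model_path)
# No xformers / memory_efficient_attention patching here; the stock
# LlamaAttention implementation has no head_dim % 8 restriction.
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)

inputs = tokenizer("Q: What is the largest animal?\nA:", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```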

young-geng avatar Aug 14 '23 09:08 young-geng

Thank you very much! Perhaps I've been using the code incorrectly all along.

EliverQ avatar Aug 14 '23 09:08 EliverQ