flash-attention
Incorrect "RuntimeError: FlashAttention only support fp16 and bf16 data type"
Flash-attn 2.5.7 always complains about the input data type, even when it is clearly a supported one.
I'm using the base image nvcr.io/nvidia/pytorch:24.03-py3.
>>> import torch, flash_attn
>>> from flash_attn.flash_attn_interface import flash_attn_func
>>> x=torch.randn(1, 4096, 8, 128, dtype=torch.bfloat16, device="cuda")
>>> flash_attn.__version__
'2.5.7'
>>> flash_attn_func(x,x,x)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 831, in flash_attn_func
return FlashAttnFunc.apply(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 572, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 511, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
File "/home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
Exception raised from mha_fwd at /home/runner/work/flash-attention/flash-attention/csrc/flash_attn/flash_api.cpp:340 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x99 (0x7fffec083d89 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x6a (0x7fffec0335ac in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: mha_fwd(at::Tensor&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, bool, std::optional<at::Generator>) + 0x18e9 (0x7ffea50183d9 in /home/user/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x1391c9 (0x7ffea50341c9 in /home/user/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x135819 (0x7ffea5030819 in /home/user/.local/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #11: THPFunction_apply(_object*, _object*) + 0xf59 (0x7fffeb26c0c9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #29: <unknown function> + 0x29d90 (0x7ffff7a00d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #30: __libc_start_main + 0x80 (0x7ffff7a00e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
Thanks for the report, I can reproduce it. Investigating now. It might be because of the way torch (in C++) handles dtypes.
Hmm, compiling from scratch seems to work fine, so something is wrong with the wheel we built.
I'm guessing this is because 24.03 uses CUDA 12.4 and the wheels built with nvcc 12.2 are somehow not compatible.
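If it helps with debugging, here is a quick sanity check (just a sketch) of which CUDA toolkit PyTorch was built against and which flash-attn build Python is actually importing:
>>> import torch, flash_attn, flash_attn_2_cuda
>>> torch.__version__, torch.version.cuda        # CUDA version torch was built with (12.4 in the 24.03 container)
>>> flash_attn.__version__, flash_attn.__file__  # which flash_attn install is on the path
>>> flash_attn_2_cuda.__file__                   # the compiled extension that appears in the traceback above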
What is the recommended fix, then? Rebuild flash-attention from source?
I just did a fresh install using python setup.py install, but I still get the same error message. Either I did something wrong, or there is still a problem somewhere.
File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py", line 624, in forward
attn_function_output = self._apply_dense_attention(
File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py", line 441, in _apply_dense_attention
attn_output_unpad = flash_attn_varlen_kvpacked_func(
File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 978, in flash_attn_varlen_kvpacked_func
return FlashAttnVarlenKVPackedFunc.apply(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 572, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 432, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 86, in _flash_attn_varlen_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
Do you think I can comment out the code that does the check and reinstall, or is that just a sign that things will break down the line?
Ok, it sounds like the issue was with the Phi3-small code, not with the library. The assert was triggering appropriately. Sorry for the noise.
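For anyone else landing here: the check fires whenever fp32 tensors reach flash-attn, so the error message is accurate and it is the inputs that need fixing. A quick REPL sketch of the difference, using the same shapes as the repro at the top:
>>> import torch
>>> from flash_attn import flash_attn_func
>>> x = torch.randn(1, 4096, 8, 128, device="cuda")   # torch.randn defaults to fp32
>>> x.dtype
torch.float32
>>> flash_attn_func(x, x, x)        # raises "FlashAttention only support fp16 and bf16 data type"
>>> x = x.to(torch.bfloat16)
>>> flash_attn_func(x, x, x)        # works; the output has the same shape and dtype as the query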
How did you solve the issue? I am also trying to fine-tune Phi-3, but I get the same error message and honestly don't know what to do with it.
I'm hitting the same problem. @tridao, do you have a recommended way to solve it?
Sorry, I forgot the details, but I think it was that you had to set use_reentrant to a value other than the default.
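For anyone searching later: that refers to the use_reentrant option of gradient checkpointing. A sketch of how it is typically set with the Hugging Face transformers API (these are the standard transformers arguments rather than anything from this thread, so double-check them against your version):

# assuming `model` is the Phi-3 model being fine-tuned
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# or equivalently through the Trainer arguments
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)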