import flash_attn_2_cuda as flash_attn_cuda fails
Please check that this issue hasn't been reported before.
- [X] I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
It should train a model.
I have tried essentially every installation setup (conda, pip, different versions of torch, etc.) and every solution that has already been reported for this error.
Current behaviour
import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

Full traceback (the same ImportError is raised on each rank):

  File "/beegfs/.global1/ws/dyfe751f-MEDICALLLMTRAIN/axolotl/src/axolotl/core/trainer_builder.py", line 40, in <module>
    from axolotl.utils.callbacks import (
  File "/beegfs/.global1/ws/dyfe751f-MEDICALLLMTRAIN/axolotl/src/axolotl/utils/callbacks/__init__.py", line 18, in <module>
    from optimum.bettertransformer import BetterTransformer
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/optimum/bettertransformer/__init__.py", line 14, in <module>
    from .models import BetterTransformerManager
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/optimum/bettertransformer/models/__init__.py", line 17, in <module>
    from .decoder_models import (
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/optimum/bettertransformer/models/decoder_models.py", line 18, in <module>
    from transformers.models.bart.modeling_bart import BartAttention
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/transformers/models/bart/modeling_bart.py", line 58, in <module>
    from flash_attn import flash_attn_func, flash_attn_varlen_func
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn/__init__.py", line 3, in <module>
    from flash_attn.flash_attn_interface import (
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
    import flash_attn_2_cuda as flash_attn_cuda
ImportError: /home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi
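The undefined symbol in the traceback above is informative: `_ZN3c104cuda9SetDeviceEi` demangles to `c10::cuda::SetDevice(int)`, a libtorch function, which suggests the prebuilt flash-attn binary was linked against a different torch ABI than the torch installed in this environment. A quick check (a diagnostic sketch of mine, not part of the original report) is to compare the versions the environment actually loads:

```python
# Diagnostic sketch (assumption: run inside the same conda env).
# The missing symbol _ZN3c104cuda9SetDeviceEi demangles to
# c10::cuda::SetDevice(int), a libtorch function, so the flash-attn
# extension was likely built against a different torch than this one.
import torch

print("torch:", torch.__version__)                   # torch the env imports
print("torch built with CUDA:", torch.version.cuda)  # CUDA toolkit torch was built with

# This import is exactly what fails in the traceback above; in a broken
# env it raises the same undefined-symbol ImportError at link time.
import flash_attn
print("flash-attn:", flash_attn.__version__)  # must be a build matching the torch above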
E0503 20:53:54.127000 140556082500608 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 194756) of binary: /home/dyfe751f/.conda/envs/med_llm_venv/bin/python
Traceback (most recent call last):
  File "/home/dyfe751f/.conda/envs/med_llm_venv/bin/accelerate", line 8, in <module>
    args.func(args)
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 1048, in launch_command
    multi_gpu_launcher(args)
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/accelerate/commands/launch.py", line 702, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dyfe751f/.conda/envs/med_llm_venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
Steps to reproduce
Install as described in the README, then try all of the suggested fixes; the error persists.
Config yaml
No changes to the example config; launched with `accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml`.
Possible solution
No response
Which Operating Systems are you using?
- [X] Linux
- [ ] macOS
- [ ] Windows
Python Version
3.11
axolotl branch-commit
main
Acknowledgements
- [X] My issue title is concise, descriptive, and in title casing.
- [X] I have searched the existing issues to make sure this bug has not been reported yet.
- [X] I am using the latest version of axolotl.
- [X] I have provided enough information for the maintainers to reproduce and diagnose the issue.
Uninstall flash-attention first:
pip uninstall flash-attn
Then install flash-attention from source:
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention
python setup.py install
This works for me.
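For context on why this works (my explanation, not from the comment above): building from source compiles the `flash_attn_2_cuda` extension against the torch already installed in the environment, so the `c10` symbols line up, whereas the prebuilt wheel was linked against a different torch ABI. If a full source build is impractical, the flash-attn README's `pip install flash-attn --no-build-isolation` should achieve the same result, since without build isolation pip compiles against the installed torch rather than a temporary build environment; note that flash-attn's setup.py may still fetch a prebuilt wheel unless `FLASH_ATTENTION_FORCE_BUILD=TRUE` is set (an environment variable its setup.py honors, to my knowledge).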
https://github.com/Dao-AILab/flash-attention/issues/931
The fix above worked for me as well.