flash-attention
Error with PyTorch containers
GPU: 2x RTX 4090
Memory: 128GB
CPU: 64 cores
CUDA: 12.3.52
NVIDIA Driver: 545.23.08
PyTorch Container: 23.11
nvcc --version (on host machine): Cuda compilation tools, release 12.3, V12.3.107 Build cuda_12.3.r12.3/compiler.33567101_0
I am trying to fine-tune Mistral 7B Instruct v0.2 and am running into the following errors:
First, I run into an error due to protobuf (I have 4.24.4, and the error message suggests downgrading it to 3.20.x or lower).
(Upon downgrading, pip gives this error: "cudf 23.10.0 requires protobuf<5,>=4.21, but you have protobuf 3.20.3 which is incompatible", but things continue to work.)
TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
1. Downgrade the protobuf package to 3.20.x or lower.
2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).
More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "mistral_finetuning_script.py", line 84, in <module>
from trl import SFTTrainer
File "./python3.10/site-packages/trl/__init__.py", line 5, in <module>
from .core import set_seed
File "./python3.10/site-packages/trl/core.py", line 25, in <module>
from transformers import top_k_top_p_filtering
File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1355, in __getattr__
value = getattr(module, name)
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1354, in __getattr__
module = self._get_module(self._class_to_module[name])
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1366, in _get_module
raise RuntimeError(
RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
1. Downgrade the protobuf package to 3.20.x or lower.
2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).
More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
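For reference, a minimal sketch of the second workaround the error message suggests, applied at the top of the training script (pure-Python protobuf parsing is slower; downgrading protobuf is what I actually did):

# Force the pure-Python protobuf implementation before trl/transformers (and
# therefore protobuf) get imported; avoids the descriptor error without
# downgrading protobuf.
import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from trl import SFTTrainer  # import only after the env var is set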
After downgrading to 3.20.3, I run into an issue with flash-attn:
trainer = SFTTrainer(
File "./python3.10/site-packages/trl/trainer/sft_trainer.py", line 163, in __init__
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
File "./python3.10/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
return model_class.from_pretrained(
File "./python3.10/site-packages/transformers/modeling_utils.py", line 3588, in from_pretrained
config = cls._autoset_attn_implementation(
File "./python3.10/site-packages/transformers/modeling_utils.py", line 1387, in _autoset_attn_implementation
cls._check_and_enable_flash_attn_2(
File "./python3.10/site-packages/transformers/modeling_utils.py", line 1483, in _check_and_enable_flash_attn_2
raise ImportError(
ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: you need flash_attn package version to be greater or equal than 2.1.0. Detected version 2.0.4. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.
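Before upgrading, a quick way to confirm which flash-attn version transformers actually sees, plus a hypothetical fallback that sidesteps FlashAttention2 entirely (the model id and the sdpa route here are assumptions for illustration, not what I ended up doing):

from importlib.metadata import version

import torch
from transformers import AutoModelForCausalLM

print("flash_attn:", version("flash_attn"))  # transformers requires >= 2.1.0 here

# Hypothetical fallback: load the model with PyTorch's SDPA attention kernels
# instead of FlashAttention2 (supported by recent transformers releases).
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)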
After upgrading flash-attn to 2.3.6 (the latest at the time of the PyTorch container release), I get this pip warning:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. transformer-engine 1.0.0+66d91d5 requires flash-attn<=2.0.4,>=1.0.6, but you have flash-attn 2.3.6 which is incompatible.
and this error:
Traceback (most recent call last):
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1364, in _get_module
return importlib.import_module("." + module_name, self.__name__)
File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 883, in exec_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "./python3.10/site-packages/transformers/generation/utils.py", line 93, in <module>
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
File "./python3.10/site-packages/accelerate/__init__.py", line 3, in <module>
from .accelerator import Accelerator
File "./python3.10/site-packages/accelerate/accelerator.py", line 35, in <module>
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
File "./python3.10/site-packages/accelerate/checkpointing.py", line 24, in <module>
from .utils import (
File "./python3.10/site-packages/accelerate/utils/__init__.py", line 153, in <module>
from .launch import (
File "./python3.10/site-packages/accelerate/utils/launch.py", line 33, in <module>
from ..utils.other import is_port_in_use, merge_dicts
File "./python3.10/site-packages/accelerate/utils/other.py", line 36, in <module>
from .transformer_engine import convert_model
File "./python3.10/site-packages/accelerate/utils/transformer_engine.py", line 21, in <module>
import transformer_engine.pytorch as te
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/__init__.py", line 11, in <module>
from .attention import DotProductAttention
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 61, in <module>
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
File "/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py", line 3, in <module>
from flash_attn.flash_attn_interface import (
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "mistral_finetuning_script.py", line 84, in <module>
from trl import SFTTrainer
File "./python3.10/site-packages/trl/__init__.py", line 5, in <module>
from .core import set_seed
File "./python3.10/site-packages/trl/core.py", line 25, in <module>
from transformers import top_k_top_p_filtering
File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1355, in __getattr__
value = getattr(module, name)
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1354, in __getattr__
module = self._get_module(self._class_to_module[name])
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1366, in _get_module
raise RuntimeError(
RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
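The undefined c10 symbol usually means the prebuilt flash-attn wheel was compiled against a different PyTorch than the one shipped in the container. A small sketch to print the combination in play (package names as pip reports them):

import torch
from importlib.metadata import version

print("torch             :", torch.__version__)
print("torch CUDA        :", torch.version.cuda)
print("flash_attn        :", version("flash_attn"))
print("transformer-engine:", version("transformer-engine"))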
Even when I install the latest version, 2.5.2 (the latest at the time of writing), I get a similar error:
Traceback (most recent call last):
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1364, in _get_module
return importlib.import_module("." + module_name, self.__name__)
File "/usr/lib/python3.10/importlib/__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 883, in exec_module
File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
File "./python3.10/site-packages/transformers/generation/utils.py", line 93, in <module>
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
File "./python3.10/site-packages/accelerate/__init__.py", line 3, in <module>
from .accelerator import Accelerator
File "./python3.10/site-packages/accelerate/accelerator.py", line 35, in <module>
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
File "./python3.10/site-packages/accelerate/checkpointing.py", line 24, in <module>
from .utils import (
File "./python3.10/site-packages/accelerate/utils/__init__.py", line 153, in <module>
from .launch import (
File "./python3.10/site-packages/accelerate/utils/launch.py", line 33, in <module>
from ..utils.other import is_port_in_use, merge_dicts
File "./python3.10/site-packages/accelerate/utils/other.py", line 36, in <module>
from .transformer_engine import convert_model
File "./python3.10/site-packages/accelerate/utils/transformer_engine.py", line 21, in <module>
import transformer_engine.pytorch as te
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/__init__.py", line 11, in <module>
from .attention import DotProductAttention
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 61, in <module>
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
File "/usr/local/lib/python3.10/dist-packages/flash_attn/__init__.py", line 3, in <module>
from flash_attn.flash_attn_interface import (
File "/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py", line 10, in <module>
import flash_attn_2_cuda as flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2_
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "mistral_finetuning_script.py", line 84, in <module>
from trl import SFTTrainer
File "./python3.10/site-packages/trl/__init__.py", line 5, in <module>
from .core import set_seed
File "./python3.10/site-packages/trl/core.py", line 25, in <module>
from transformers import top_k_top_p_filtering
File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1355, in __getattr__
value = getattr(module, name)
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1354, in __getattr__
module = self._get_module(self._class_to_module[name])
File "./python3.10/site-packages/transformers/utils/import_utils.py", line 1366, in _get_module
raise RuntimeError(
RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2_
Some other things I've tried:
- Uninstalling transformer-engine
- Using the latest PyTorch container 24.01
Try flash-attn 2.5.1 on the nvcr 23.12 or 24.01 containers.
The symbol "ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2" is a mangled CPP function name, to demangle it use this Demangler tool
the function is
at::_ops::sum_IntList_out::call(at::Tensor const&, c10::OptionalArrayRef<long>, bool, std::optional<c10::ScalarType>, at::Tensor&)
So what happens is that flash-attn was built against a PyTorch version that does not match the one in the container; I'm not sure which version it should be.
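For anyone reproducing this, the demangling can also be done locally with binutils' c++filt; a small sketch:

import subprocess

# The undefined symbol reported by the ImportError above.
sym = "_ZN2at4_ops15sum_IntList_out4callERKNS_6TensorEN3c1016OptionalArrayRefIlEEbSt8optionalINS5_10ScalarTypeEERS2_"

# c++filt (from binutils) demangles Itanium-ABI C++ symbol names.
print(subprocess.run(["c++filt", sym], capture_output=True, text=True).stdout.strip())
# -> at::_ops::sum_IntList_out::call(at::Tensor const&, c10::OptionalArrayRef<long>,
#    bool, std::optional<c10::ScalarType>, at::Tensor&)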
@tsvisab, not sure if you've resolved this, but for anyone who comes across this: installing flash-attn from source solved it for me.
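For reference, a hypothetical way to force that source build from inside the container, so the extension compiles against the container's own PyTorch (the flash-attention README documents --no-build-isolation; the exact flags here are an assumption, not necessarily what was used above):

import subprocess
import sys

# --no-binary forces pip to build flash-attn from source; --no-build-isolation
# lets the build see the torch already installed in the container instead of a
# clean, isolated build environment.
subprocess.check_call([
    sys.executable, "-m", "pip", "install",
    "--no-build-isolation", "--no-binary", "flash-attn", "flash-attn",
])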