FastChat
Does anybody know which version of `flash_attn` is used for fine-tuning?
When attempting to execute the FastChat\scripts\train_vicuna_7b.sh script, it fails with the following error:
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 16, in <module>
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
ImportError: cannot import name 'flash_attn_unpadded_func' from 'flash_attn.flash_attn_interface' (/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py)
Does anyone know why this error occurs? Additionally, why doesn't the repository provide a requirements.txt file specifying the environment required for fine-tuning the model?
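For what it's worth, the 2.x line of flash-attn renamed `flash_attn_unpadded_func` to `flash_attn_varlen_func`, so this ImportError usually means a 2.x build is installed while `transformer_engine` expects the 1.x API. A minimal sketch of a version check (the helper name is mine, not from either library):

```python
# Sketch: report which flash-attn API generation is installed.
# Assumption: 1.x releases export flash_attn_unpadded_func, while the
# 2.x line renamed it to flash_attn_varlen_func.
from importlib.metadata import version, PackageNotFoundError

def has_unpadded_func(ver: str) -> bool:
    """True for 1.x releases, which still export flash_attn_unpadded_func."""
    return int(ver.split(".")[0]) < 2

try:
    installed = version("flash-attn")
    api = ("flash_attn_unpadded_func (1.x API)" if has_unpadded_func(installed)
           else "flash_attn_varlen_func (2.x API)")
    print(f"flash-attn {installed} -> {api}")
except PackageNotFoundError:
    print("flash-attn is not installed")
```

If it reports a 2.x version, pinning `flash-attn<2` (or upgrading transformer_engine to a release that targets the 2.x API) is one way to resolve the mismatch.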
Hello, I am also facing a similar issue during fine-tuning. Despite setting all the environment variables correctly, the issue persists.
I am able to run the model locally (with both the CPU and GPU options). If you have figured out a solution, could you please share your approach?
Thanks
Yes, actually I recommend creating a Docker container for fine-tuning.
For example:
FROM nvcr.io/nvidia/pytorch:23.11-py3
then install the basic packages below and run train_vicuna_7b.sh:
pip install peft==0.5.0 \
transformers==4.37.1 \
transformers-stream-generator==0.0.4 \
deepspeed==0.12.3 \
accelerate==0.26.1 \
gunicorn==20.1.0 \
flask==2.1.2 \
flask_api==3.1 \
langchain==0.1.4 \
fastapi==0.109.1 \
uvicorn==0.19.0 \
jinja2==3.1.2 \
huggingface_hub==0.20.3 \
grpcio-tools==1.60.0 \
bitsandbytes==0.42.0 \
sentencepiece==0.1.99 \
safetensors==0.4.2 \
datasets==2.16.1 \
texttable==1.7.0 \
toml==0.10.2 \
numpy==1.24.4 \
scikit-learn==1.3.2 \
loguru==0.7.0 \
protobuf==4.24.4 \
pydantic==2.5.1 \
python-dotenv==1.0.0 \
tritonclient[all]==2.41.1 \
sse-starlette==2.0.0 \
boto3==1.34.30 \
jsonlines==4.0.0
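Putting the base image and the package list above together, a minimal Dockerfile sketch might look like the following (paths and the COPY/CMD lines are assumptions about your checkout layout, not from the thread; the full pip list above can be substituted for the abbreviated one here):

```dockerfile
# Sketch only: base image and version pins taken from this thread.
FROM nvcr.io/nvidia/pytorch:23.11-py3

# Core training dependencies (abbreviated; see the full pinned list above).
RUN pip install peft==0.5.0 transformers==4.37.1 deepspeed==0.12.3 \
    accelerate==0.26.1 bitsandbytes==0.42.0 sentencepiece==0.1.99 \
    safetensors==0.4.2 datasets==2.16.1

# Assumes the build context is a FastChat checkout.
WORKDIR /workspace/FastChat
COPY . .
CMD ["bash", "scripts/train_vicuna_7b.sh"]
```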
Hope this can help you.