
Does anybody know which version of `flash_attn` is used for fine-tuning?

Open Oscarjia opened this issue 1 year ago • 2 comments

When I run the FastChat\scripts\train_vicuna_7b.sh script, it fails with the following error:

File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 16, in <module>
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
ImportError: cannot import name 'flash_attn_unpadded_func' from 'flash_attn.flash_attn_interface' (/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py)

Does anyone know why this error occurred? Additionally, why hasn't the repository provided a requirements.txt file to specify the required environment for fine-tuning the model?

Oscarjia avatar Jan 28 '24 11:01 Oscarjia
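A possible explanation, offered as a guess rather than a confirmed diagnosis: the flash-attn 2.x releases renamed `flash_attn_unpadded_func` to `flash_attn_varlen_func`, so a `transformer_engine` build that still imports the old name would fail in exactly this way when flash-attn 2.x is installed. The sketch below only reports what is installed; `probe_flash_attn` is an illustrative helper name, not part of FastChat or flash-attn.

```python
import importlib
from importlib import metadata

# The old (1.x) and renamed (2.x) entry points discussed above.
NAMES = ("flash_attn_unpadded_func", "flash_attn_varlen_func")


def probe_flash_attn():
    """Report the installed flash-attn version and which entry points it exposes."""
    report = {"version": None}
    report.update({name: False for name in NAMES})

    try:
        report["version"] = metadata.version("flash-attn")
    except metadata.PackageNotFoundError:
        # flash-attn is not installed in this environment.
        return report

    try:
        iface = importlib.import_module("flash_attn.flash_attn_interface")
    except ImportError:
        # Installed but not importable (for example, a CUDA/PyTorch mismatch).
        return report

    for name in NAMES:
        report[name] = hasattr(iface, name)
    return report


if __name__ == "__main__":
    print(probe_flash_attn())
```

If the report shows only `flash_attn_varlen_func`, the installed flash-attn is newer than what the failing import expects, and the two packages need to be brought to compatible versions.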


Hello, I am also facing a similar issue during fine-tuning. It persists even though I have set all the environment variables correctly.

I am able to run the model locally (with both the CPU and GPU options). If you have found a solution, could you please share your approach?

Thanks

dhruvpes avatar May 09 '24 14:05 dhruvpes

Yes. I recommend creating a Docker container for fine-tuning. For example, start the image with FROM nvcr.io/nvidia/pytorch:23.11-py3, install the packages below, and then run train_vicuna_7b.sh:

pip install peft==0.5.0 \
    transformers==4.37.1 \
    transformers-stream-generator==0.0.4 \
    deepspeed==0.12.3 \
    accelerate==0.26.1 \
    gunicorn==20.1.0 \
    flask==2.1.2 \
    flask_api==3.1 \
    langchain==0.1.4 \
    fastapi==0.109.1 \
    uvicorn==0.19.0 \
    jinja2==3.1.2 \
    huggingface_hub==0.20.3 \
    grpcio-tools==1.60.0 \
    bitsandbytes==0.42.0 \
    sentencepiece==0.1.99 \
    safetensors==0.4.2 \
    datasets==2.16.1 \
    texttable==1.7.0 \
    toml==0.10.2 \
    numpy==1.24.4 \
    scikit-learn==1.3.2 \
    loguru==0.7.0 \
    protobuf==4.24.4 \
    pydantic==2.5.1 \
    python-dotenv==1.0.0 \
    "tritonclient[all]==2.41.1" \
    sse-starlette==2.0.0 \
    boto3==1.34.30 \
    jsonlines==4.0.0
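
After the install finishes, it can be worth confirming that the environment actually matches the pins before launching a long training run. A minimal sketch, using only a handful of the pins above; `PINS` and `find_mismatches` are illustrative names, not part of FastChat:

```python
from importlib import metadata

# A subset of the pinned versions from the pip command above.
PINS = {
    "peft": "0.5.0",
    "transformers": "4.37.1",
    "deepspeed": "0.12.3",
    "accelerate": "0.26.1",
}


def find_mismatches(pins):
    """Return {package: (wanted, found)} for every pin that is not satisfied.

    `found` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f"{name}: wanted {wanted}, found {found or 'not installed'}")
```

An empty result means every listed package is installed at the pinned version.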

Hope this helps.

Oscarjia avatar May 15 '24 10:05 Oscarjia