
Runtime error when training with lora and flash_attention together


I am using fschat to finetune a vicuna model with lora and get the following error. I tried normal training with flash attn as well and it works fine. It's probably something wrong on the lora training configuration side, but I can not figure it out. It seems I am using the exact training script from https://github.com/lm-sys/FastChat/pull/138#issuecomment-1495289110. Does anyone have clues?

RuntimeError: Expected q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.
[2023-05-03 05:03:05,497] [INFO] [runner.py:540:main] cmd = /usr/local/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgM119 --master_addr=127.0.0.1 --master_port=9849 --enable_each_rank_log=None /tmp/FastChat/fastchat/train/train_lora.py --deepspeed /data/lora/deepspeed-config.json --lora_r 8 --lora_alpha 16 --lora_dropout 0.05 --model_name_or_path /mnt/vicuna/13b-v1.1 --data_path /mnt/data/all_in_one_v1.1.json --bf16 True --output_dir /mnt/vicuna-finetune/v3.4 --num_train_epochs 3 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 1200 --save_total_limit 100 --learning_rate 2e-5 --weight_decay 0. --warmup_ratio 0.03 --lr_scheduler_type cosine --logging_steps 1 --tf32 True --model_max_length 2048


time to load utils op: 0.00040268898010253906 seconds
  0%|                                                                                                                                                                                            | 0/573 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/tmp/FastChat/fastchat/train/train_lora.py", line 151, in <module>
    train()
  File "/tmp/FastChat/fastchat/train/train_lora.py", line 141, in train
    trainer.train()
  File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 1929, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2699, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2731, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1675, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/peft/peft_model.py", line 678, in forward
    return self.base_model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 687, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 577, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/tmp/FastChat/fastchat/train/llama_flash_attn_monkey_patch.py", line 88, in forward
    output_unpad = flash_attn_unpadded_qkvpacked_func(
  File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 256, in flash_attn_unpadded_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 58, in forward
    out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
  File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 21, in _flash_attn_forward
    softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
RuntimeError: Expected q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Here's my configuration

    deepspeed /tmp/FastChat/fastchat/train/train_lora.py \
        --deepspeed /data/lora/deepspeed-config.json \
        --lora_r 8 \
        --lora_alpha 16 \
        --lora_dropout 0.05 \
        --model_name_or_path $MODEL_WEIGHTS_PATH \
        --data_path $DATA_PATH \
        --bf16 True \
        --output_dir $CHECKPOINT_PATH \
        --num_train_epochs 3 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --gradient_accumulation_steps 1 \
        --evaluation_strategy "no" \
        --save_strategy "steps" \
        --save_steps 1200 \
        --save_total_limit 100 \
        --learning_rate 2e-5 \
        --weight_decay 0. \
        --warmup_ratio 0.03 \
        --lr_scheduler_type "cosine" \
        --logging_steps 1 \
        --tf32 True \
        --model_max_length 2048

deepspeed-config.json

{
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true
    },
    "train_micro_batch_size_per_gpu": 1
}

/cc @ZYHowell

Jeffwan avatar May 02 '23 18:05 Jeffwan

I do use the apply_lora.py script to merge the weights, and it does use float16 as torch_dtype.

https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/model/apply_lora.py#L20-L29
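Roughly, those lines do something like the following (just a sketch for context, not a verbatim copy; the lora checkpoint path is a placeholder):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Sketch of the merge flow: both the base model and the adapter are loaded in float16.
base = AutoModelForCausalLM.from_pretrained(
    "/mnt/vicuna/13b-v1.1", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
lora = PeftModel.from_pretrained(base, "/path/to/lora-checkpoint", torch_dtype=torch.float16)
merged = lora.merge_and_unload()  # bake the adapter weights into the base model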

Update: I notice that train_lora.py doesn't set float16, so the model defaults to float32. I tried adding torch_dtype=torch.float16 but it doesn't seem to work.

https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/train/train_lora.py#L98-L100

Jeffwan avatar May 02 '23 18:05 Jeffwan

Does your GPU support bfloat16? If not, please try to remove --bf16 True.

ZYHowell avatar May 03 '23 17:05 ZYHowell

@ZYHowell Thanks for the reply. I am using 8 * A100-SXM-80G, which should support bf16. I will remove it and give it another try, though bf16 performance seems similar to fp16 on A100.

The weird thing is that even if I change the code to explicitly use torch.float16,

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    torch_dtype=torch.float16)
print(model.dtype)
print(model.config.torch_dtype)

it shows torch.float32 and torch.float16 respectively. If I use model.half(), both are torch.float16, but then it reports RuntimeError: output tensor must have the same type as input tensor. Do you have any clues?

Jeffwan avatar May 03 '23 18:05 Jeffwan

model.dtype controls the output dtype of each layer, so please print out the dtype of the tensor that is passed to flash-attn. If it is fp32, please further check self.q_proj.dtype, which is supposed to be fp16/bf16.
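For example, something like this right before the flash-attn call (illustrative debugging lines only; this assumes the usual hidden_states / query_states names in the patched forward):

# inside llama_flash_attn_monkey_patch.forward, before flash_attn_unpadded_qkvpacked_func:
print("hidden_states:", hidden_states.dtype)       # should be torch.float16 or torch.bfloat16
print("q_proj weight:", self.q_proj.weight.dtype)  # dtype of the (LoRA-wrapped) projection
print("query_states:", query_states.dtype)         # what flash-attn actually receives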

ZYHowell avatar May 03 '23 18:05 ZYHowell

@ZYHowell I did some tests and here are the results. It seems the tensor passed to forward is already fp32.

  1. removing bf16 doesn't help.
  2. query_states.dtype is torch.float32
  3. hidden_states is torch.float32 as well
  4. self.q_proj is a Linear
Linear(
  in_features=5120, out_features=5120, bias=False
  (lora_dropout): ModuleDict(
    (default): Dropout(p=0.05, inplace=False)
  )
  (lora_A): ModuleDict(
    (default): Linear(in_features=5120, out_features=8, bias=False)
  )
  (lora_B): ModuleDict(
    (default): Linear(in_features=8, out_features=5120, bias=False)
  )
)

Jeffwan avatar May 03 '23 22:05 Jeffwan

So you need to make the dtype of the model float16/bfloat16. If you are using the train_lora script, I think you need to add these lines to your deepspeed config:

"bf16": {
    "enabled": true
},

or

"fp16": {
    "enabled": true
}

For more details about the fp16 configuration, please see here.
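Combined with the config from the issue description, it would look roughly like this (only the bf16 block is new):

{
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true
    },
    "train_micro_batch_size_per_gpu": 1
}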

ZYHowell avatar May 03 '23 22:05 ZYHowell

@ZYHowell Thanks a lot! You saved my day!

After adding bf16 to the deepspeed config, it works fine. One problem I encountered is an OOM issue; after I changed train_micro_batch_size_per_gpu from 4 to 1, it runs fine.

The last confusing parts:

  1. I notice that train_mem.py actually runs with train_micro_batch_size_per_gpu=4 without an OOM issue, but the lora script can only run with train_micro_batch_size_per_gpu=1 (I set gradient_accumulation_steps=2 for both of them). Do you happen to know what the reason could be?

  2. The loss decreases slowly compared to train_mem.py. Not sure if that's expected or not.

## train_mem.py
{'loss': 1.272, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.01}
{'loss': 0.8334, 'learning_rate': 1.9854595177171968e-05, 'epoch': 0.25}
...
{'loss': 0.735, 'learning_rate': 1.903074732324533e-05, 'epoch': 0.5}
.....
{'loss': 0.5697, 'learning_rate': 1.5484518712493188e-05, 'epoch': 1.01}
{'loss': 0.4537, 'learning_rate': 1.0455145991329639e-05, 'epoch': 1.5}
{'loss': 0.2661, 'learning_rate': 5.198430651940846e-06, 'epoch': 2.0}
{'loss': 0.3011, 'learning_rate': 1.339745962155613e-06, 'epoch': 2.5}
{'loss': 0.2498, 'learning_rate': 6.478088794448223e-10, 'epoch': 2.97}


## train_lora.py
{'loss': 1.5374, 'learning_rate': 5.714285714285715e-07, 'epoch': 0.0}                                                                                                        
{'loss': 1.0186, 'learning_rate': 1.142857142857143e-06, 'epoch': 0.01}                                                                                                                                                                                                      
{'loss': 1.0084, 'learning_rate': 1.9976980418510628e-05, 'epoch': 0.15}                                                                                                                                                                                               
{'loss': 1.1323, 'learning_rate': 1.9851603302355373e-05, 'epoch': 0.25}                                                                                                                                                                                                        
{'loss': 1.0513, 'learning_rate': 1.9054761481647815e-05, 'epoch': 0.5}                                                                                                       
.....                                                                                           
{'loss': 0.9736, 'learning_rate': 1.6696907901061803e-05, 'epoch': 0.87}  

BTW, I will help answer other LoRA-related issues in the community to reduce your support burden. I really appreciate your help!

Jeffwan avatar May 03 '23 22:05 Jeffwan

Thanks for your willingness to help the community, we sincerely appreciate it! In my experience with OOM, I basically enable gradient checkpointing, and then I can train with a larger batch size.

Gradient checkpointing requires applying a monkey patch to the HF trainer or the llama model definition. There will be a warning with a pointer on how to do the monkey patch. You can also read here for more details.
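The common workaround looks roughly like this (a sketch only, not the exact FastChat patch; the model load mirrors train_lora.py with a placeholder path):

import transformers

# Placeholder path; in train_lora.py this comes from model_args.model_name_or_path.
model = transformers.AutoModelForCausalLM.from_pretrained("/path/to/base-model")

model.gradient_checkpointing_enable()

# With LoRA the base weights are frozen, so the checkpointed activations would not
# require grad; newer transformers versions expose a helper for this:
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
else:
    def make_inputs_require_grad(module, inputs, output):
        output.requires_grad_(True)
    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)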

ZYHowell avatar May 04 '23 00:05 ZYHowell

@Jeffwan Have you tried llama-30b-hf with lora?

better629 avatar May 06 '23 03:05 better629

@ZYHowell https://github.com/lm-sys/FastChat/issues/170 mentioned that you have a plan for 30b with lora. Do you use load_in_8bit in the lora script? It seems that it takes more than 40G of GPU memory if you only have an A100-40G.
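For reference, 8-bit LoRA loading usually looks something like this (a hypothetical sketch; the model path and target modules are placeholders, and this is not what train_lora.py does by default):

import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# Hypothetical 8-bit setup: the frozen base weights stay in int8 via bitsandbytes,
# which is what would let a 30B model fit on a single 40G card.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "/path/to/llama-30b-hf",
    load_in_8bit=True,
    device_map="auto",
)
model = prepare_model_for_int8_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)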

better629 avatar May 06 '23 03:05 better629

We have multiple A100-40G

ZYHowell avatar May 09 '23 15:05 ZYHowell

@better629 I have not tried 30B yet. We will explore 30B or 65B later and let you know the results.

Jeffwan avatar May 23 '23 21:05 Jeffwan

Why does qlora's loss decrease more slowly? I ran into the same question.

alphanlp avatar Jul 01 '23 16:07 alphanlp