FastChat
Encountering a runtime error when training with LoRA and flash_attention together
I am using fschat to finetune a vicuna model with LoRA and get the following error. I tried normal training with flash attention as well and it works fine, so it's probably something wrong on the LoRA training configuration side, but I cannot figure it out. I seem to be using the exact training scripts from https://github.com/lm-sys/FastChat/pull/138#issuecomment-1495289110 Does anyone have clues?
RuntimeError: Expected q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
[2023-05-03 05:03:05,497] [INFO] [runner.py:540:main] cmd = /usr/local/bin/python3 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgM119 --master_addr=127.0.0.1 --master_port=9849 --enable_each_rank_log=None /tmp/FastChat/fastchat/train/train_lora.py --deepspeed /data/lora/deepspeed-config.json --lora_r 8 --lora_alpha 16 --lora_dropout 0.05 --model_name_or_path /mnt/vicuna/13b-v1.1 --data_path /mnt/data/all_in_one_v1.1.json --bf16 True --output_dir /mnt/vicuna-finetune/v3.4 --num_train_epochs 3 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 1200 --save_total_limit 100 --learning_rate 2e-5 --weight_decay 0. --warmup_ratio 0.03 --lr_scheduler_type cosine --logging_steps 1 --tf32 True --model_max_length 2048
time to load utils op: 0.00040268898010253906 seconds
0%| | 0/573 [00:00<?, ?it/s]Traceback (most recent call last):
File "/tmp/FastChat/fastchat/train/train_lora.py", line 151, in <module>
train()
File "/tmp/FastChat/fastchat/train/train_lora.py", line 141, in train
trainer.train()
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 1662, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 1929, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2699, in training_step
loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2731, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1675, in forward
loss = self.module(*inputs, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/peft/peft_model.py", line 678, in forward
return self.base_model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 687, in forward
outputs = self.model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 577, in forward
layer_outputs = decoder_layer(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/tmp/FastChat/fastchat/train/llama_flash_attn_monkey_patch.py", line 88, in forward
output_unpad = flash_attn_unpadded_qkvpacked_func(
File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 256, in flash_attn_unpadded_qkvpacked_func
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 58, in forward
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 21, in _flash_attn_forward
softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
RuntimeError: Expected q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
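For context, the flash-attn kernel only accepts half-precision inputs: fp16 on all supported GPUs, and bf16 additionally on sm8x/sm90 hardware (A100/H100). A minimal sanity check mirroring that requirement might look like the sketch below (the qkv argument is just an assumption about the packed tensor the monkey patch builds):
import torch

def check_flash_attn_dtype(qkv: torch.Tensor) -> None:
    # Mirrors the constraint behind the RuntimeError above: fp16 is always
    # accepted, bf16 only on Ampere/Hopper-class GPUs; fp32 is rejected.
    if qkv.dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError(f"flash-attn expects fp16/bf16 inputs, got {qkv.dtype}")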
Here's my configuration
deepspeed /tmp/FastChat/fastchat/train/train_lora.py \
--deepspeed /data/lora/deepspeed-config.json \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--model_name_or_path $MODEL_WEIGHTS_PATH \
--data_path $DATA_PATH \
--bf16 True \
--output_dir $CHECKPOINT_PATH \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1200 \
--save_total_limit 100 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048
deepspeed-config.json
{
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true
},
"train_micro_batch_size_per_gpu": 1
}
/cc @ZYHowell
I do use the apply_lora.py script to merge the weights, and it does use float16 as torch_dtype:
https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/model/apply_lora.py#L20-L29
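For reference, the merge that apply_lora.py performs boils down to roughly the following sketch (paths are placeholders; merge_and_unload is the standard PEFT helper, and the actual script may differ in details):
import torch
import transformers
from peft import PeftModel

# Hedged sketch of merging LoRA weights into the base model in float16.
base = transformers.AutoModelForCausalLM.from_pretrained(
    "path/to/base-model", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
lora_model = PeftModel.from_pretrained(base, "path/to/lora-adapter", torch_dtype=torch.float16)
merged = lora_model.merge_and_unload()  # fold the LoRA deltas back into the base weights
merged.save_pretrained("path/to/merged-model")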
Update: I notice that train_lora.py doesn't pass torch_dtype, so the model defaults to float32. I tried adding torch_dtype=torch.float16, but it doesn't seem to work:
https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/train/train_lora.py#L98-L100
Does your GPU support bfloat16? If not, please try removing --bf16 True.
@ZYHowell Thanks for the reply. I am using 8 x A100-SXM-80G, which should support bf16. I will remove it and give it another try, though it seems bf16 performance is similar to fp16 on A100.
The weird thing is that even if I change the code to explicitly use torch.float16,
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
torch_dtype=torch.float16)
print(model.dtype)
print(model.config.torch_dtype)
it shows torch.float32 and torch.float16 respectively. If I use model.half(), both are torch.float16, but then it reports RuntimeError: output tensor must have the same type as input tensor
Do you have any clues?
The model.dtype controls the output dtype of each layer, so please print out the dtype of this tensor. If it is fp32, please further check self.q_proj.dtype, which is supposed to be fp16/bf16.
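A quick way to do that is a small helper called from inside the patched forward in llama_flash_attn_monkey_patch.py (a rough sketch; attribute names follow the stock LLaMA attention module):
import torch
from torch import nn

def report_attn_dtypes(attn: nn.Module, hidden_states: torch.Tensor) -> None:
    # Hypothetical debug helper: shows where fp32 sneaks in before flash-attn is called.
    print("hidden_states dtype:", hidden_states.dtype)        # activations entering attention
    print("q_proj weight dtype:", attn.q_proj.weight.dtype)   # parameter dtype, expected fp16/bf16
    print("query_states dtype:", attn.q_proj(hidden_states).dtype)  # what flash-attn will receive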
@ZYHowell I did some tests and here are the results. It seems the tensor passed to forward is already fp32:
- removing bf16 doesn't help.
- query_states.dtype is torch.float32
- hidden_states is torch.float32 as well
- self.q_proj is a Linear:
Linear(
in_features=5120, out_features=5120, bias=False
(lora_dropout): ModuleDict(
(default): Dropout(p=0.05, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=5120, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=5120, bias=False)
)
)
So you need to set the model's dtype to float16/bfloat16. If you are using the train_lora script, I think you need to add these lines to your deepspeed config:
"bf16": {
"enabled": true
},
or
"fp16": {
"enabled": true
}
For more details about the fp16 configuration, please see here.
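For reference, since the run above passes --bf16 True, the updated config would presumably end up like the sketch below (written as a small Python script that emits the JSON; the only addition to the config shown earlier is the bf16 block):
import json

# Hedged sketch of the updated DeepSpeed config; use an "fp16" block instead of
# "bf16" on GPUs without bfloat16 support.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 1,
}

with open("deepspeed-config.json", "w") as f:
    json.dump(ds_config, f, indent=2)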
@ZYHowell Thanks a lot! You saved my day!
After adding bf16 to the deepspeed config, it works fine. One problem I encountered is an OOM issue; after I changed train_micro_batch_size_per_gpu from 4 to 1, it runs fine.
The last confusing parts:
- I notice that train_mem.py actually runs with train_micro_batch_size_per_gpu=4 without an OOM issue, but the LoRA script can only run with train_micro_batch_size_per_gpu=1; I set gradient_accumulation_steps=2 for both of them. Do you happen to know what the reason could be?
- The loss decreases slowly compared to train_mem.py. Not sure if that's expected or not.
## train_mem.py
{'loss': 1.272, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.01}
{'loss': 0.8334, 'learning_rate': 1.9854595177171968e-05, 'epoch': 0.25}
...
{'loss': 0.735, 'learning_rate': 1.903074732324533e-05, 'epoch': 0.5}
.....
{'loss': 0.5697, 'learning_rate': 1.5484518712493188e-05, 'epoch': 1.01}
{'loss': 0.4537, 'learning_rate': 1.0455145991329639e-05, 'epoch': 1.5}
{'loss': 0.2661, 'learning_rate': 5.198430651940846e-06, 'epoch': 2.0}
{'loss': 0.3011, 'learning_rate': 1.339745962155613e-06, 'epoch': 2.5}
{'loss': 0.2498, 'learning_rate': 6.478088794448223e-10, 'epoch': 2.97}
## train_lora.py
{'loss': 1.5374, 'learning_rate': 5.714285714285715e-07, 'epoch': 0.0}
{'loss': 1.0186, 'learning_rate': 1.142857142857143e-06, 'epoch': 0.01}
{'loss': 1.0084, 'learning_rate': 1.9976980418510628e-05, 'epoch': 0.15}
{'loss': 1.1323, 'learning_rate': 1.9851603302355373e-05, 'epoch': 0.25}
{'loss': 1.0513, 'learning_rate': 1.9054761481647815e-05, 'epoch': 0.5}
.....
{'loss': 0.9736, 'learning_rate': 1.6696907901061803e-05, 'epoch': 0.87}
BTW, I will help answer other LoRA-related issues in the community to reduce your support burden. I really appreciate your help!
Thanks for your willingness to help the community, we sincerely appreciate it! From my experience with OOM, I basically turn on gradient checkpointing, and then I can train with a larger batch size.
Gradient checkpointing needs a monkey patch applied to the HF trainer or the llama model definition. There will be a warning with a pointer about how to do the monkey patch. You can also read here for more details.
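As a rough sketch of what that amounts to with a LoRA-wrapped model (standard transformers/PEFT calls with placeholder paths and hyperparameters, not the exact FastChat patch):
import torch
import transformers
from peft import LoraConfig, get_peft_model

# Hedged sketch: enable gradient checkpointing before wrapping with LoRA so that
# larger micro-batches fit in memory.
base = transformers.AutoModelForCausalLM.from_pretrained(
    "path/to/base-model", torch_dtype=torch.bfloat16
)
base.gradient_checkpointing_enable()  # recompute activations in backward instead of storing them
base.enable_input_require_grads()     # keeps grads flowing through checkpointed blocks when base weights are frozen
base.config.use_cache = False         # the KV cache conflicts with checkpointing during training

model = get_peft_model(
    base,
    LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
               target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
)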
@Jeffwan Have you tried llama-30b-hf with LoRA?
@ZYHowell https://github.com/lm-sys/FastChat/issues/170 mentioned that you have a plan for 30B with LoRA.
Do you use load_in_8bit in the LoRA script? It seems to cost more than 40G of GPU memory if you only use an A100-40G. We have multiple A100-40G GPUs.
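In case it helps, 8-bit loading with LoRA typically looks roughly like the sketch below (using the peft/bitsandbytes APIs of that era, with a placeholder model path; this is not the FastChat script itself):
import torch
import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# Hedged sketch of load_in_8bit + LoRA; requires bitsandbytes and accelerate.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "path/to/llama-30b-hf",   # placeholder path
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
model = prepare_model_for_int8_training(model)  # casts norms/lm_head, enables input grads
model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05,
               target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
)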
@better629 I have not tried 30B yet; we will explore 30B or 65B later and let you know the results.
Why does QLoRA's loss decrease more slowly? I ran into the same question.