Firefly icon indicating copy to clipboard operation
Firefly copied to clipboard

RuntimeError: FlashAttention only support fp16 and bf16 data type

Open sankexin opened this issue 1 year ago • 1 comment

python train.py --train_args_file train_args/sft/qlora/llama3-8b-sft-qlora.json

GPU: NVIDIA A800 80GB PCIe. Max memory: 79.151 GB. Platform = Linux. Pytorch: 2.3.0+cu121. CUDA = 8.0. CUDA Toolkit = 12.1. Bfloat16 = TRUE. Xformers = 0.0.26.post1. FA = True.

Traceback (most recent call last):
  File "/home/sotanv/Firefly/train.py", line 439, in <module>
    main()
  File "/home/sotanv/Firefly/train.py", line 427, in main
    train_result = trainer.train()
  File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "<string>", line 354, in _fast_inner_training_loop
  File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2902, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2925, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 817, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 805, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 882, in PeftModelForCausalLM_fast_forward
    return self.base_model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 813, in _CausalLM_fast_forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 668, in LlamaModel_fast_forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/usr/local/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 487, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 262, in forward
    outputs = run_function(*args)
  File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 664, in custom_forward
    return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 433, in LlamaDecoderLayer_fast_forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 359, in LlamaAttention_fast_forward
    A = flash_attn_func(Q, K, V, causal = True)
  File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 831, in flash_attn_func
    return FlashAttnFunc.apply(
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 511, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
0%|

sankexin avatar May 20 '24 02:05 sankexin

请问这个问题解决了吗？

TonyUSTC avatar Aug 17 '24 13:08 TonyUSTC