
Training with Flash Attention enabled fails with an error

onenotell opened this issue on Mar 14, 2024 · 0 comments

Reminder

  • [X] I have read the README and searched the existing issues.

Reproduction

SFT training of the chinese-alpaca-2-7b model on two H800 GPUs fails when Flash Attention acceleration is enabled. Please take a look, thanks. Log output below:

```
[INFO|trainer.py:1812] 2024-03-14 12:07:36,974 >> ***** Running training *****
[INFO|trainer.py:1813] 2024-03-14 12:07:36,974 >> Num examples = 48,818
[INFO|trainer.py:1814] 2024-03-14 12:07:36,974 >> Num Epochs = 1
[INFO|trainer.py:1815] 2024-03-14 12:07:36,974 >> Instantaneous batch size per device = 4
[INFO|trainer.py:1818] 2024-03-14 12:07:36,974 >> Total train batch size (w. parallel, distributed & accumulation) = 32
[INFO|trainer.py:1819] 2024-03-14 12:07:36,974 >> Gradient Accumulation steps = 4
[INFO|trainer.py:1820] 2024-03-14 12:07:36,974 >> Total optimization steps = 5
[INFO|trainer.py:1821] 2024-03-14 12:07:36,975 >> Number of trainable parameters = 6,929,256,448
  0%|          | 0/5 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/root/llm/train.py", line 91, in <module>
    main()
  File "/root/llm/train.py", line 50, in main
    run_exp()
  File "/root/llm/src/llmtuner/train/tuner.py", line 32, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/root/llm/src/llmtuner/train/sft/workflow.py", line 73, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1961, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2902, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2925, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 1852, in forward
    loss = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1176, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1008, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 451, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 230, in forward
    outputs = run_function(*args)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 740, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/llm/src/llmtuner/extras/patches/llama_patch.py", line 127, in llama_flash_attn_forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
TypeError: LlamaRotaryEmbedding.forward() missing 1 required positional argument: 'position_ids'
  0%|          | 0/5 [00:00<?, ?it/s]
[2024-03-14 12:07:39,691] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 16822
[2024-03-14 12:07:42,010] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 16823
```

Expected behavior

No response

System Info

  • python 3.10
  • transformers 4.38.2
  • torch 2.1.2

Others

No response
