DPO + FSDP

Open • rkinas opened this issue 8 months ago • 8 comments

Reminder

  • [x] I have read the above rules and searched the existing issues.

System Info

Problem with DPO + FSDP: training fails at the first step, during loss calculation, because the data and the model are on different devices (CPU vs. GPU). ORPO and SimPO work.
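To narrow down which side is on the wrong device, a small check like the one below can be run on a model and a batch just before the failing forward pass. This is a minimal debugging sketch; `report_devices` is a hypothetical helper, not part of LLaMA-Factory or TRL.

```python
from collections import Counter

import torch
from torch import nn


def report_devices(model: nn.Module, batch: dict) -> dict:
    """Hypothetical debugging helper: where do the model's parameters and the batch tensors live?"""
    # Count parameters per device, e.g. {"cpu": 291} or {"cuda:0": 291}.
    param_devices = Counter(str(p.device) for p in model.parameters())
    # Record the device of every tensor in the batch; non-tensor entries are skipped.
    batch_devices = {
        name: str(value.device)
        for name, value in batch.items()
        if isinstance(value, torch.Tensor)
    }
    all_devices = set(param_devices) | set(batch_devices.values())
    return {
        "params": dict(param_devices),
        "batch": batch_devices,
        # True when more than one device is involved, e.g. CPU weights with CUDA inputs.
        "mismatch": len(all_devices) > 1,
    }


if __name__ == "__main__":
    embed = nn.Embedding(10, 4)
    batch = {"input_ids": torch.tensor([[1, 2, 3]]), "labels": [1]}
    print(report_devices(embed, batch))
```

Calling it on the reference model and on the policy model separately should show which of the two still has CPU weights. Note that an FSDP-wrapped model may report flattened or sharded parameters, so the result is easiest to read for a model that is not wrapped.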

Reproduction

Put your message here.

Others

No response

rkinas • Apr 08 '25 13:04

Check this out @lindy-test

brokewq • Apr 08 '25 14:04

I reproduced it on two machines:

  • SLURM cluster
  • local machine with 4x GPU

[rank3]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:3!

DPO:

  • full fine-tuning
  • fsdp
  • accelerate launch

rkinas • Apr 08 '25 16:04

@rkinas Could you please provide the full traceback

hiyouga • Apr 08 '25 16:04

[INFO|trainer.py:2423] 2025-04-08 16:33:42,733 >> Number of trainable parameters = 399,126,912
0%| | 0/1940 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/train.py", line 28, in <module>
[rank0]:     main()
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/train.py", line 19, in main
[rank0]:     run_exp()
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/tuner.py", line 107, in run_exp
[rank0]:     _training_function(config={"args": args, "callbacks": callbacks})
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/tuner.py", line 75, in _training_function
[rank0]:     run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/workflow.py", line 83, in run_dpo
[rank0]:     train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2245, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2560, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3736, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 272, in compute_loss
[rank0]:     return super().compute_loss(model, inputs, return_outputs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1408, in compute_loss
[rank0]:     loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 241, in get_batch_loss_metrics
[rank0]:     reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 220, in compute_reference_log_probs
[rank0]:     reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 189, in concatenated_forward
[rank0]:     all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 821, in forward
[rank0]:     outputs: BaseModelOutputWithPast = self.model(
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 527, in forward
[rank0]:     inputs_embeds = self.embed_tokens(input_ids)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 190, in forward
[rank0]:     return F.embedding(
[rank0]:   File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2551, in embedding
[rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank0]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
[rank2]: Traceback (most recent call last):
[rank2]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/train.py", line 28, in <module>
[rank2]:     main()
[rank2]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/train.py", line 19, in main
[rank2]:     run_exp()
[rank2]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/tuner.py", line 107, in run_exp
[rank2]:     _training_function(config={"args": args, "callbacks": callbacks})
[rank2]:   File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/tuner.py", line 75, in _training_function
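The traceback fails inside `compute_reference_log_probs` → `concatenated_forward(ref_model, batch)` → `embed_tokens(input_ids)`, which suggests (my inference, not confirmed in this thread) that the reference model's embedding weights are still on the CPU while `input_ids` are already on the GPU. A crude workaround sketch, assuming each GPU has room for a full, unsharded copy of the reference model, is to move that model onto the inputs' device before the reference forward. `move_to_input_device` is a hypothetical helper, not an existing LLaMA-Factory or TRL function.

```python
import torch
from torch import nn


def move_to_input_device(ref_model: nn.Module, batch: dict, key: str = "input_ids") -> nn.Module:
    """Hypothetical workaround: keep an unsharded, frozen reference model on the inputs' device."""
    target = batch[key].device
    # Only move if some weight lives elsewhere (e.g. on CPU while input_ids are on cuda:N).
    if any(p.device != target for p in ref_model.parameters()):
        # Module.to() moves parameters and buffers in place. This is a full,
        # unsharded copy per rank, so it costs one whole model of GPU memory.
        # It also fails for weights that are still on the "meta" device.
        ref_model.to(target)
    # The DPO reference model only provides log-probs, so freeze it.
    ref_model.eval()
    ref_model.requires_grad_(False)
    return ref_model


if __name__ == "__main__":
    embed = nn.Embedding(10, 4)
    batch = {"input_ids": torch.tensor([[1, 2, 3]])}
    out = move_to_input_device(embed, batch)(batch["input_ids"])
    print(out.shape)
```

This does not shard the reference model at all, so it only helps when memory allows it; a sharded reference model would have to go through FSDP wrapping like the policy model does.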

rkinas • Apr 08 '25 16:04

Have you turned on FSDP CPU offload?

hiyouga • Apr 08 '25 17:04

No. I use this Accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16  # or fp16
num_machines: 1  # the number of nodes
num_processes: 2  # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
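CPU offload is off here, but `fsdp_cpu_ram_efficient_loading: true` is on. As far as I understand it, that option loads the full checkpoint only on the first local rank and relies on FSDP wrapping (together with `fsdp_sync_module_states: true`) to distribute the weights, so a second model that is never wrapped, such as the DPO reference model in the traceback, could plausibly keep its weights off the GPU. That is a hypothesis, not something confirmed in this thread. A quick way to see which flags actually reached each training process is to read the environment variables that `accelerate launch` exports; the variable names below are assumed from recent Accelerate versions and should be checked against the installed one.

```python
import os

# Variable names assumed from recent `accelerate launch` versions; check your install.
FSDP_ENV_VARS = (
    "ACCELERATE_USE_FSDP",
    "FSDP_OFFLOAD_PARAMS",
    "FSDP_CPU_RAM_EFFICIENT_LOADING",
    "FSDP_SYNC_MODULE_STATES",
)


def fsdp_launch_flags(env=None) -> dict:
    """Return the FSDP-related flags visible to this process (None if unset)."""
    env = os.environ if env is None else env
    return {name: env.get(name) for name in FSDP_ENV_VARS}


if __name__ == "__main__":
    rank = os.environ.get("RANK", "?")
    for name, value in fsdp_launch_flags().items():
        print(f"[rank{rank}] {name}={value}")
```

If `FSDP_CPU_RAM_EFFICIENT_LOADING` is `true`, rerunning with `fsdp_cpu_ram_efficient_loading: false` is a cheap experiment to test the hypothesis (at the cost of every rank loading the full checkpoint into CPU RAM); it is not a confirmed fix.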

rkinas • Apr 08 '25 17:04

Five minutes ago I ran training with plain TRL (ORPO), and it works with FSDP.

rkinas • Apr 08 '25 17:04

Maybe related to https://github.com/huggingface/trl/issues/1147 and https://github.com/huggingface/trl/pull/2539

hiyouga • Apr 08 '25 17:04

Same problem. Any progress? Training with LoRA works, but full-parameter training does not.

qZhang88 • Jun 16 '25 13:06

Same problem here.

zzfoutofspace • Aug 27 '25 00:08