Reminder
- [x] I have read the above rules and searched the existing issues.
System Info
Problem with DPO + FSDP: training fails at the first step, during loss calculation, because the data and the model end up on different devices (cpu vs. gpu). ORPO and SimPO work.
Reproduction
Put your message here.
Others
No response
I reproduced it on two machines:
- SLURM cluster
- local machine with 4x GPU
[rank3]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:3!
DPO:
- full finetuning
- fsdp
- accelerate launch
@rkinas Could you please provide the full traceback?
[INFO|trainer.py:2423] 2025-04-08 16:33:42,733 >> Number of trainable parameters = 399,126,912
0%| | 0/1940 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/train.py", line 28, in
[rank0]: main()
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/train.py", line 19, in main
[rank0]: run_exp()
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/tuner.py", line 107, in run_exp
[rank0]: _training_function(config={"args": args, "callbacks": callbacks})
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/tuner.py", line 75, in _training_function
[rank0]: run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/workflow.py", line 83, in run_dpo
[rank0]: train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2245, in train
[rank0]: return inner_training_loop(
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2560, in inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3736, in training_step
[rank0]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 272, in compute_loss
[rank0]: return super().compute_loss(model, inputs, return_outputs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1408, in compute_loss
[rank0]: loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 241, in get_batch_loss_metrics
[rank0]: reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 220, in compute_reference_log_probs
[rank0]: reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 189, in concatenated_forward
[rank0]: all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]: return model_forward(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in call
[rank0]: return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
[rank0]: return func(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]: output = func(self, *args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank0]: return func(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 821, in forward
[rank0]: outputs: BaseModelOutputWithPast = self.model(
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]: output = func(self, *args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 527, in forward
[rank0]: inputs_embeds = self.embed_tokens(input_ids)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 190, in forward
[rank0]: return F.embedding(
[rank0]: File "/mnt/sda/llm/training/LLaMA-Factory/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2551, in embedding
[rank0]: return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank0]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
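The failing call is the reference-model forward inside concatenated_forward: F.embedding sees its weight and the batch's input_ids on different devices (cpu vs. cuda:0), so the reference model and the batch are apparently not placed on the same device when DPO runs under FSDP. A minimal debugging/workaround sketch along those lines (my assumption, not the project's actual fix; ref_model and batch are the objects passed into concatenated_forward, debug_ref_device is a made-up helper):

```python
# Hypothetical sketch: compare the reference model's device with the batch's
# device and, as a crude workaround, move the batch tensors over before the
# forward pass. Not a proper FSDP fix, just a way to confirm the mismatch.
import torch

def debug_ref_device(ref_model, batch):
    ref_device = next(ref_model.parameters()).device
    batch_device = batch["input_ids"].device
    print(f"reference model on {ref_device}, batch on {batch_device}")

    if ref_device != batch_device:
        batch = {k: v.to(ref_device) if torch.is_tensor(v) else v for k, v in batch.items()}
    return batch
```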
No. I use this accelerate config:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16 # or fp16
num_machines: 1 # the number of nodes
num_processes: 2 # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
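A quick per-rank sanity check that can be dropped into the DPO workflow to see where the policy and the reference model actually live after accelerator preparation (a hypothetical helper; report_devices, policy_model and ref_model are illustrative names, not LLaMA-Factory or TRL APIs):

```python
# Hypothetical debugging helper: print the parameter device of the prepared
# policy model and of the DPO reference model on every rank.
import torch.distributed as dist

def report_devices(policy_model, ref_model):
    rank = dist.get_rank() if dist.is_initialized() else 0
    policy_device = next(policy_model.parameters()).device
    ref_device = (
        next(ref_model.parameters()).device
        if ref_model is not None
        else "none (no separate reference model)"
    )
    print(f"[rank {rank}] policy model on {policy_device}, reference model on {ref_device}")
```

If the reference model reports cpu while the policy model reports cuda, that matches the traceback above.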
5 minutes ago I ran a training run on pure TRL (ORPO) and it works with FSDP.
Maybe related to https://github.com/huggingface/trl/issues/1147 and https://github.com/huggingface/trl/pull/2539
Same problem, any progress? Training with LoRA is OK, but full-parameter training is not.