
FSDP Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

Open · qZhang88 opened this issue 1 year ago · 4 comments

Running DPO with Qwen, I hit this tensor-flattening problem. My FSDP config is as follows:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
num_machines: 1
num_processes: 2
main_training_function: main
mixed_precision: bf16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

qZhang88 · Jun 11 '24 11:06

Could you share a more detailed error message?

vwxyzjn · Jun 11 '24 15:06

The detailed traceback is as follows:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 270, in <module>
[rank0]:     main()
[rank0]:   File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 261, in main
[rank0]:     trainer.train()
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 1885, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 2032, in _inner_training_loop
[rank0]:     self.model = self.accelerator.prepare(self.model)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1292, in prepare
[rank0]:     result = tuple(
[rank0]:              ^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1293, in <genexpr>
[rank0]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1169, in _prepare_one
[rank0]:     return self.prepare_model(obj, device_placement=device_placement)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1459, in prepare_model
[rank0]:     model = FSDP(model, **kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 485, in __init__
[rank0]:     _auto_wrap(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank0]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:                                         ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:                                         ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]:                                         ^^^^^^^^^^^^^^^^
[rank0]:   [Previous line repeated 2 more times]
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
[rank0]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
[rank0]:     return wrapper_cls(module, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 511, in __init__
[rank0]:     _init_param_handle_from_module(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 598, in _init_param_handle_from_module
[rank0]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 610, in _init_param_handle_from_params
[rank0]:     handle = FlatParamHandle(
[rank0]:              ^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank0]:     self._init_flat_param_and_metadata(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank0]:     ) = self._validate_tensors_to_flatten(params)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten
[rank0]:     raise ValueError(
[rank0]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
[rank1]: Traceback (most recent call last):
[rank1]:   ... (identical to the rank0 traceback above) ...
[rank1]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32

I am running code modified from this script: https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py

I am running with QLoRA, and the BitsAndBytes config code is modified to support the bnb_4bit_quant_storage parameter:

    --load_in_4bit True \
    --use_bnb_nested_quant True \
    --bnb_4bit_quant_storage bfloat16 \

If QLoRA is not used, FSDP works fine, but training hits OOM errors on some long training examples. So I am trying to use FSDP with QLoRA.
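
For reference, those flags translate to roughly this BitsAndBytesConfig (a sketch, not my exact code; the model id is a placeholder, nf4 is assumed, and bnb_4bit_quant_storage needs a transformers version that accepts it, >= 4.39):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # The quant storage dtype must match the non-quantized weights
    # (bfloat16 here) so FSDP sees a uniform dtype when flattening.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",              # assumed quant type
        bnb_4bit_use_double_quant=True,         # nested quantization
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_storage=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-7B-Instruct",               # placeholder model id
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,             # keep other weights in bf16
    )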

qZhang88 · Jun 13 '24 08:06

Checking the note in SFTTrainer, the error is caused by peft_module_casting_to_bf16 or prepare_model_for_kbit_training:

                # Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
                # peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
                # QLoRA + FSDP / DS-Zero3
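
So a workaround is to guard that call in your own script. A rough sketch (detecting FSDP via the ACCELERATE_USE_FSDP environment variable is an assumption about how Accelerate signals it, so verify on your setup; model is the quantized model from your script):

    import os

    from peft import prepare_model_for_kbit_training

    # Per the note above: never call prepare_model_for_kbit_training
    # when doing QLoRA + FSDP / DS-Zero3.
    using_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    if not using_fsdp:
        model = prepare_model_for_kbit_training(model)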

qZhang88 · Jun 18 '24 05:06

+1

Minami-su · Jul 02 '24 00:07

+1

atc0m · Jul 08 '24 19:07

🥺

Minami-su · Jul 11 '24 05:07

As was said above, remove the call to prepare_model_for_kbit_training.

This solved the issue for me:

https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L128
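
For context, what that helper does (paraphrased from the linked PEFT source, not an exact copy) is upcast every fp16/bf16 parameter to float32 for training stability. That is exactly how the model ends up with mixed bfloat16/float32 tensors for FSDP to flatten:

    import torch
    from torch import nn

    def upcast_non_quantized_params(model: nn.Module) -> None:
        # Paraphrase of the upcast inside prepare_model_for_kbit_training:
        # layernorms, embeddings, etc. get cast to float32, while the 4-bit
        # weights keep their bfloat16 quant storage, hence the mixed dtypes.
        for param in model.parameters():
            if param.dtype in (torch.float16, torch.bfloat16):
                param.data = param.data.to(torch.float32)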

mrT23 · Jul 26 '24 21:07

> As was said above, remove the call to prepare_model_for_kbit_training.
>
> This solved the issue for me:
>
> https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L128

Hi, thank you for your solution, but I don't quite get it. Should I remove all the code that calls prepare_model_for_kbit_training?

vananh0905 · Jul 28 '24 05:07

As far as I understand, the solution is to comment out https://github.com/huggingface/trl/blob/332062372d35eaa33912b3283be0de9b0555652b/trl/trainer/dpo_trainer.py#L279
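
An alternative that avoids editing the TRL source (an untested sketch; model, training_args, train_dataset, and tokenizer come from your existing script, and target_modules is illustrative) is to wrap the model with PEFT yourself and not pass peft_config, so the trainer never reaches that call:

    from peft import LoraConfig, get_peft_model
    from trl import DPOTrainer

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],  # illustrative
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)  # wrap outside the trainer

    trainer = DPOTrainer(
        model=model,               # already a PeftModel
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        # no peft_config here, so the trainer skips its own k-bit preparation
    )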

qgallouedec · Aug 05 '24 15:08

We are still facing this issue when trying to train with QLoRA + PEFT + FSDP. Is the only fix really to comment out a line in the TRL trainer source code?

JohnGiorgi · Feb 28 '25 01:02

We are also experiencing the same error. Everything works well with the SFT trainer, but with the CPO/DPO trainers we get this error. Should this issue be reopened? @qgallouedec

We are using trl==0.16.1; maybe this is related: https://github.com/huggingface/trl/issues/2537

Also, I have always wondered which is the better practice: passing a PEFT model to the trainers, or passing a peft_config and letting the trainers do the work? I always thought the latter, but the trainers seem to behave differently.

shon-otmazgin-wix · Apr 24 '25 20:04