FSDP: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
Running DPO with Qwen, I hit this flattening problem. My FSDP config is as follows:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0
num_machines: 1
num_processes: 2
main_training_function: main
mixed_precision: bf16
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Could you share a more detailed error message?
Sure, the detailed error trace is as follows.
[rank0]: Traceback (most recent call last):
[rank0]: File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 270, in <module>
[rank0]: main()
[rank0]: File "/ws/alpha_llms/alignment/DPO/run_dpo.py", line 261, in main
[rank0]: trainer.train()
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 1885, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/transformers/trainer.py", line 2032, in _inner_training_loop
[rank0]: self.model = self.accelerator.prepare(self.model)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1292, in prepare
[rank0]: result = tuple(
[rank0]: ^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1293, in <genexpr>
[rank0]: self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1169, in _prepare_one
[rank0]: return self.prepare_model(obj, device_placement=device_placement)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py", line 1459, in prepare_model
[rank0]: model = FSDP(model, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 485, in __init__
[rank0]: _auto_wrap(
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank0]: _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank0]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: [Previous line repeated 2 more times]
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
[rank0]: return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
[rank0]: return wrapper_cls(module, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 511, in __init__
[rank0]: _init_param_handle_from_module(
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 598, in _init_param_handle_from_module
[rank0]: _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py", line 610, in _init_param_handle_from_params
[rank0]: handle = FlatParamHandle(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank0]: self._init_flat_param_and_metadata(
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank0]: ) = self._validate_tensors_to_flatten(params)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/lib/python3.12/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten
[rank0]: raise ValueError(
[rank0]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
[rank1]: Traceback (most recent call last): (identical traceback on rank 1, ending in the same ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32)
I am running code modified from this script:
https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py
I am running with QLoRA, and the source code for the BnB config is modified to support the bnb_4bit_quant_storage param:
--load_in_4bit True \
--use_bnb_nested_quant True \
--bnb_4bit_quant_storage bfloat16 \
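For reference, those flags roughly correspond to a quantization config like the sketch below (the model id and surrounding code are placeholders; the key point is that bnb_4bit_quant_storage and torch_dtype both match the bf16 used elsewhere):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: keep the 4-bit storage dtype and the non-quantized weights in
# the same dtype (bfloat16) so FSDP can flatten everything into one dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # --load_in_4bit True
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,         # --use_bnb_nested_quant True
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # --bnb_4bit_quant_storage bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",               # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,             # non-quantized layers in bf16 too
)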
If QLoRA is not used, FSDP works fine, but training hits OOM errors on some long training examples, so I am trying to use FSDP with QLoRA.
Checking the note in SFTTrainer, the error is caused by peft_module_casting_to_bf16 or prepare_model_for_kbit_training:
# Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
# peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
# QLoRA + FSDP / DS-Zero3
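A quick way to see where the float32 tensors come from is to dump the parameter dtypes right before trainer.train(); a small diagnostic sketch (model is whatever run_dpo.py has built at that point):

from collections import defaultdict

# Group parameter names by dtype; with QLoRA + FSDP everything FSDP flattens
# should end up in a single dtype (bfloat16 here), so any float32 entries
# point at the modules that were upcast (typically norm layers / LoRA params).
def dtype_report(model):
    by_dtype = defaultdict(list)
    for name, param in model.named_parameters():
        by_dtype[param.dtype].append(name)
    for dtype, names in by_dtype.items():
        print(f"{dtype}: {len(names)} params, e.g. {names[:3]}")

dtype_report(model)  # call this right before trainer.train()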
+1
+1
🥺
🥺
As was said above, removing prepare_model_for_kbit_training solved the issue for me:
https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L128
Hi. Thank you for your solution, but I don't get it. Should I remove all the code that calls prepare_model_for_kbit_training?
As far as I understand, the solution is to comment out https://github.com/huggingface/trl/blob/332062372d35eaa33912b3283be0de9b0555652b/trl/trainer/dpo_trainer.py#L279
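If you would rather not patch the installed trl source, a workaround I have seen is to cast any stray float32 parameters back to bfloat16 after the PEFT wrapping and before the trainer/FSDP sees the model; a hedged sketch (peft_model is assumed to be your already-wrapped model):

import torch

# Force every remaining float32 parameter/buffer back to bfloat16 so FSDP
# sees a single dtype when it flattens. This undoes the float32 upcasting
# that the k-bit preparation step introduces.
def cast_stray_fp32_to_bf16(model):
    for _, param in model.named_parameters():
        if param.dtype == torch.float32:
            param.data = param.data.to(torch.bfloat16)
    for _, buf in model.named_buffers():
        if buf.dtype == torch.float32:
            buf.data = buf.data.to(torch.bfloat16)

cast_stray_fp32_to_bf16(peft_model)  # before building the trainer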
We are still facing this issue when trying to train with QLoRA + PEFT + FSDP. Is the only fix to comment out a line in the TRL trainer source code?
We are also experiencing the same error. Everything works well with the SFT trainer, but with the CPO/DPO trainer we get this error. Should this issue be reopened? @qgallouedec
using trl==0.16.1
maybe this is related?
https://github.com/huggingface/trl/issues/2537
Also, I always wonder which is the better practice: passing a PEFT model to the trainers, or passing a peft_config and letting the trainer do the work? I always thought the latter, but it seems the trainers behave differently.
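For comparison, the two patterns look roughly like this (a sketch; model, tokenizer, and train_dataset are assumed to exist already, and keyword names such as processing_class differ a bit across trl versions):

from peft import LoraConfig, get_peft_model
from trl import DPOConfig, DPOTrainer

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# Option A: pass peft_config and let the trainer apply PEFT. The trainer may
# also run its own k-bit preparation/casting here, which is where the mixed
# dtypes in this issue come from.
trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=DPOConfig(output_dir="dpo-out", bf16=True),
    train_dataset=train_dataset,
    processing_class=tokenizer,
    peft_config=lora_config,
)

# Option B: wrap the model yourself and pass the PeftModel, so you control
# the dtypes before FSDP wraps the model.
peft_model = get_peft_model(model, lora_config)
trainer = DPOTrainer(
    model=peft_model,
    ref_model=None,
    args=DPOConfig(output_dir="dpo-out", bf16=True),
    train_dataset=train_dataset,
    processing_class=tokenizer,
)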