FSDP2 with LoRA training throws dtype mismatch error
System Info
Issue: FSDP2 with LoRA training throws dtype mismatch error
Description
When training with FSDP2 and LoRA, the following error occurs:
```
FSDP expects uniform original parameter dtype but got {torch.float32, torch.bfloat16}
```
However, using FSDP1 with the same setup works fine without any errors.
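For context, here is a minimal sketch of how the mixed dtypes typically arise in this setup: the base model is loaded in bfloat16 while PEFT keeps the injected LoRA adapter weights in float32. The model name and LoRA settings below are placeholders, not the actual training config.

```python
# Minimal sketch (placeholder model/LoRA settings; assumes transformers + peft).
# The base weights are bf16, but PEFT's injected LoRA A/B matrices typically end
# up in fp32, giving exactly the {float32, bfloat16} set that FSDP2 rejects.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", torch_dtype=torch.bfloat16)
model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

print({p.dtype for p in model.parameters()})  # typically {torch.bfloat16, torch.float32}
```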
Analysis
I looked into this and found that verl implements a dedicated wrap policy for FSDP1 in LoRA scenarios (the lambda policy in get_fsdp_wrap_policy), but there is no equivalent for FSDP2. As a result, FSDP2 most likely ends up placing the frozen bf16 base parameters and the trainable fp32 LoRA parameters into the same fully_shard group, which trips the uniform-dtype check.
Additional Context
The issue seems to be related to how FSDP2 handles parameter wrapping and dtype consistency in LoRA scenarios compared to FSDP1's implemented wrap policy.
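For reference, FSDP1's LoRA handling gives each trainable leaf module (the LoRA A/B layers) its own FSDP unit via a lambda auto-wrap policy. Below is a rough sketch of that idea, modeled on torch.distributed.fsdp.wrap.lambda_auto_wrap_policy; the exact predicate in verl's get_fsdp_wrap_policy may differ in details.

```python
# Rough sketch of the FSDP1-style LoRA wrap policy the analysis refers to;
# verl's actual get_fsdp_wrap_policy may differ.
import functools
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

def lambda_policy_fn(module):
    # Wrap leaf modules that carry a trainable weight (i.e. the fp32 LoRA A/B
    # layers) so they are not flattened together with frozen bf16 base weights.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

lora_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
```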
Information
- [ ] The official example scripts
- [x] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
```yaml
actor:
  strategy: fsdp2
  fsdp_config:
    fsdp_size: -1
    model_dtype: bfloat16
```
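For completeness, this snippet sits inside a larger verl config together with the LoRA options. The LoRA key names, values, and nesting below are illustrative of a typical verl LoRA setup, not taken from the original report, so adjust them to your version:

```yaml
# Illustrative placement only; LoRA key names/values are not from the original report.
actor_rollout_ref:
  model:
    lora_rank: 32           # enabling LoRA is what introduces the fp32 adapter params
    lora_alpha: 32
    target_modules: all-linear
  actor:
    strategy: fsdp2
    fsdp_config:
      fsdp_size: -1
      model_dtype: bfloat16
```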
Expected behavior
FSDP2 should work with LoRA training just like FSDP1, handling mixed dtypes properly or providing a consistent wrap policy.
Same issue
same issue
I had the same issue when using FSDP2 with LoRA. The problem was that the LoRA-related modules (which have a different dtype from the base weights) were not being sharded before the transformer layers, unlike under the FSDP1 wrap policy. Making FSDP2 shard the LoRA modules first fixed the dtype mismatch error.
I patched apply_fsdp2 in fsdp_utils.py like this:
https://github.com/volcengine/verl/blob/fc7df6f7f99bad09b394463c75bb64ef6a21191b/verl/utils/fsdp_utils.py#L507-L508
```python
def apply_fsdp2(model, fsdp_kwargs, config, is_lora=False):  # FIX: add is_lora param
    """model: AutoModelForCausalLM"""
    assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

    default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get(
        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
    )
    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]

    assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None

    lora_modules = []  # FIX: manage LoRA modules separately
    transformer_modules = []
    for name, module in model.named_modules():
        ##################################################################################
        # FIX: identify LoRA modules (following lambda_policy_fn in get_fsdp_wrap_policy)
        if is_lora and (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            lora_modules.append(module)
        ##################################################################################
        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (
            isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings
        ):
            transformer_modules.append(module)

    ######################################################################################
    # FIX: shard LoRA modules first
    for idx, module in enumerate(lora_modules):
        # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        #     print(f"wrap LoRA module {module.__class__.__name__}")
        with maybe_patch_fsdp_module(module):
            fully_shard(module, **fsdp_kwargs)
    ######################################################################################

    for idx, module in enumerate(transformer_modules):
        # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        #     print(f"wrap module {module.__class__.__name__}")
        with maybe_patch_fsdp_module(module):
            fully_shard(module, **fsdp_kwargs)

    # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
    #     print(f"wrap module {model.__class__.__name__}")
    with maybe_patch_fsdp_module(model):
        fully_shard(model, **fsdp_kwargs)  # fsdp2 will not reshard_after_forward for root module
```
Then I passed the new is_lora flag wherever apply_fsdp2 is called. In my setup that meant updating these two lines in fsdp_workers.py:
https://github.com/volcengine/verl/blob/fc7df6f7f99bad09b394463c75bb64ef6a21191b/verl/workers/fsdp_workers.py#L528
https://github.com/volcengine/verl/blob/fc7df6f7f99bad09b394463c75bb64ef6a21191b/verl/workers/fsdp_workers.py#L1409
```python
apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config, is_lora=self._is_lora)
```
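The actor-side call around fsdp_workers.py#L528 gets the analogous change; the variable names below are how it looks in my checkout and may differ slightly in other verl versions:

```python
# Actor-side call site (variable names from my checkout; may differ across verl versions)
apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config, is_lora=self._is_lora)
```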
I'm not sure if this is the proper fix, but at least it worked for me!