
FSDP2 with LoRA training throws dtype mismatch error


System Info

Issue: FSDP2 with LoRA training throws dtype mismatch error

Description

When training with FSDP2 and LoRA, the following error occurs:

FSDP expects uniform original parameter dtype but got {torch.float32, torch.bfloat16}

However, using FSDP1 with the same setup works fine without any errors.

Analysis

I investigated this and found that verl implements a LoRA-specific wrap policy for FSDP1, but no equivalent exists for FSDP2. As a result, FSDP2 groups the trainable float32 LoRA parameters and the frozen bfloat16 base weights into the same FSDP unit, which trips the uniform-dtype check above.

Additional Context

The root cause appears to be how FSDP2 wraps parameters in LoRA scenarios: unlike FSDP1's dedicated wrap policy, it has no mechanism to place the LoRA parameters in their own FSDP units, so mixed dtypes end up inside a single shard group.
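
For reference, the FSDP1 path handles this with a LoRA-aware wrap policy. Below is a simplified sketch of what get_fsdp_wrap_policy in verl/utils/fsdp_utils.py does for the LoRA case (lambda_auto_wrap_policy is the stock PyTorch helper; the shipped code may differ in its details):

import functools
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

def lambda_policy_fn(module):
    # With LoRA, the adapter weights are the only trainable parameters, so
    # wrapping every trainable leaf module puts the float32 LoRA params in
    # their own FSDP units, separate from the frozen bfloat16 base weights.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

lora_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)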

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

actor:
  strategy: fsdp2
  fsdp_config:
    fsdp_size: -1
    model_dtype: bfloat16 
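
For completeness, LoRA itself is enabled via the model section of the config; the key names below follow verl's ppo_trainer settings and the values are representative rather than exact:

model:
  lora_rank: 32          # > 0 enables LoRA
  lora_alpha: 32
  target_modules: all-linear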

Expected behavior

FSDP2 should work with LoRA training just like FSDP1, handling mixed dtypes properly or providing a consistent wrap policy.

Schilings · Sep 14 '25 13:09

Same issue

H-Jamieu · Sep 28 '25 09:09

same issue

nekoteai · Nov 03 '25 22:11

I had the same issue when using FSDP2 with LoRA. The problem was that LoRA-related modules (which have different dtype) were not being sharded before the transformer layers, unlike the FSDP1 wrap policy. Making FSDP2 shard the LoRA modules first fixed the dtype mismatch error.

I patched apply_fsdp2 in fsdp_utils.py like this:

https://github.com/volcengine/verl/blob/fc7df6f7f99bad09b394463c75bb64ef6a21191b/verl/utils/fsdp_utils.py#L507-L508

def apply_fsdp2(model, fsdp_kwargs, config, is_lora=False):  # FIX: add is_lora param
    """model: AutoModelForCausalLM"""
    assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

    default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
    fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get(
        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
    )

    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]

    assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None

    lora_modules = []  # FIX: manage LoRA modules separately
    transformer_modules = []
    for name, module in model.named_modules():
        ##################################################################################
        # FIX: identify LoRA modules (following lambda_policy_fn in get_fsdp_wrap_policy)
        if is_lora and (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            lora_modules.append(module)
        ##################################################################################

        if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or (
            isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings
        ):
            transformer_modules.append(module)

    ######################################################################################
    # FIX: shard LoRA modules first
    for idx, module in enumerate(lora_modules):
        # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        #     print(f"wrap LoRA module {module.__class__.__name__}")
        with maybe_patch_fsdp_module(module):
            fully_shard(module, **fsdp_kwargs)
    ######################################################################################

    for idx, module in enumerate(transformer_modules):
        # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
        #     print(f"wrap module {module.__class__.__name__}")
        with maybe_patch_fsdp_module(module):
            fully_shard(module, **fsdp_kwargs)

    # if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
    #     print(f"wrap module {model.__class__.__name__}")
    with maybe_patch_fsdp_module(model):
        fully_shard(model, **fsdp_kwargs)  # fsdp2 will not reshard_after_forward for root module

Then I passed the new is_lora flag wherever apply_fsdp2 is called. In my setup that meant updating these two lines in fsdp_workers.py:

https://github.com/volcengine/verl/blob/fc7df6f7f99bad09b394463c75bb64ef6a21191b/verl/workers/fsdp_workers.py#L528
https://github.com/volcengine/verl/blob/fc7df6f7f99bad09b394463c75bb64ef6a21191b/verl/workers/fsdp_workers.py#L1409

apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config, is_lora=self._is_lora)
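
The actor-side call (the first link above) gets the same change; the exact local variable name there may differ from this sketch:

apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config, is_lora=self._is_lora)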

I'm not sure if this is the proper fix, but at least it worked for me!

xxnpark · Nov 22 '25 08:11