
[Bug] Megatron model merger for Qwen3 MoE models

Open · BaiqingL opened this issue 3 weeks ago · 3 comments

System Info

verl installed from the latest git commit, using the provided Docker image.

After training Qwen3 30B A3B Instruct with Megatron, I attempted to merge the model like so:

python -m verl.model_merger merge --backend megatron --tie-word-embedding --local_dir actor --target_dir actor_merged

This yields the following error:

[rank0]:[W1130 23:53:42.191853867 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
root@129-213-94-103:/osmosis/checkpoint/sourcegraph_megatron__gspo_sft/qwen3-30B-base-grpo-sft-megatron/global_step_235# python -m verl.model_merger merge     --backend megatron     --tie-word-embedding     --local_dir actor     --target_dir actor_merged
config: ModelMergerConfig(operation='merge', backend='megatron', target_dir='actor_merged', hf_upload_path=None, private=False, test_hf_dir=None, tie_word_embedding=True, trust_remote_code=False, is_value_model=False, local_dir='actor', hf_model_config_path='actor/huggingface', hf_upload=False, use_cpu_initialization=False)
Warning: Failed to set NUMA affinity: libnuma.so: cannot open shared object file: No such file or directory
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Qwen3MoeConfig {
  "architectures": [
    "Qwen3MoeForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "decoder_sparse_step": 1,
  "dtype": "bfloat16",
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5472,
  "max_position_embeddings": 262144,
  "max_window_layers": 28,
  "mlp_only_layers": [],
  "model_type": "qwen3_moe",
  "moe_intermediate_size": 768,
  "norm_topk_prob": true,
  "num_attention_heads": 32,
  "num_experts": 128,
  "num_experts_per_tok": 8,
  "num_hidden_layers": 48,
  "num_key_value_heads": 4,
  "output_router_logits": false,
  "pad_token_id": 151643,
  "qkv_bias": false,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000000,
  "router_aux_loss_coef": 0.0,
  "shared_expert_intermediate_size": 0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.57.1",
  "use_cache": true,
  "use_qk_norm": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}

Pipeline shards: [48], total layers: 48
Overridden TransformerConfig init config: {'num_layers': 48, 'hidden_size': 2048, 'num_attention_heads': 32, 'num_query_groups': 4, 'ffn_hidden_size': 5472, 'attention_dropout': 0.0, 'hidden_dropout': 0.0, 'kv_channels': 128, 'layernorm_epsilon': 1e-06, 'add_bias_linear': False, 'activation_func': <function silu at 0x7cbff53fa7a0>, 'normalization': 'RMSNorm', 'gated_linear_unit': True, 'pipeline_dtype': torch.bfloat16, 'params_dtype': torch.bfloat16, 'bf16': True, 'tensor_model_parallel_size': 1, 'pipeline_model_parallel_size': 1, 'expert_model_parallel_size': 1, 'expert_tensor_parallel_size': 1, 'virtual_pipeline_model_parallel_size': None, 'context_parallel_size': 1, 'overlap_p2p_comm': False, 'batch_p2p_comm': False, 'sequence_parallel': False, 'variable_seq_lengths': True, 'masked_softmax_fusion': True, 'moe_token_dispatcher_type': 'alltoall', 'use_cpu_initialization': False, 'moe_ffn_hidden_size': 768, 'moe_router_bias_update_rate': 0.001, 'moe_router_topk': 8, 'num_moe_experts': 128, 'moe_aux_loss_coeff': 0.0, 'moe_router_load_balancing_type': 'none', 'moe_grouped_gemm': True, 'moe_router_score_function': 'softmax', 'persist_layer_norm': True, 'bias_activation_fusion': True, 'bias_dropout_fusion': True, 'moe_router_pre_softmax': False, 'qk_layernorm': True, 'num_layers_in_first_pipeline_stage': None, 'num_layers_in_last_pipeline_stage': None}
/usr/local/lib/python3.12/dist-packages/megatron/core/transformer/transformer_config.py:1338: UserWarning: Using a large number of experts (e.g. >=32) without fp32 routing. Consider enabling moe_router_dtype for better numerical stability.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/cpu_offload.py:695: DeprecationWarning: Offloading weights is deprecated. Using offload_weights=True does not have any effect.
  warnings.warn(
 > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 30532122624
/usr/local/lib/python3.12/dist-packages/megatron/core/dist_checkpointing/strategies/torch.py:916: FutureWarning: `load_state_dict` is deprecated and will be removed in future versions. Please use `load` instead.
  checkpoint.load_state_dict(
/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/planner_helpers.py:418: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  device = getattr(value, "device", None)
[rank0]: Traceback (most recent call last):
[rank0]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]:   File "<frozen runpy>", line 88, in _run_code
[rank0]:   File "/osmosis/verl/verl/model_merger/__main__.py", line 73, in <module>
[rank0]:     main()
[rank0]:   File "/osmosis/verl/verl/model_merger/__main__.py", line 68, in main
[rank0]:     merger.merge_and_save()
[rank0]:   File "/osmosis/verl/verl/model_merger/megatron_model_merger.py", line 496, in merge_and_save
[rank0]:     model_state_dict = self._load_state_dicts(model_ckpt_path)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/osmosis/verl/verl/model_merger/megatron_model_merger.py", line 279, in _load_state_dicts
[rank0]:     model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/osmosis/verl/verl/utils/megatron/dist_checkpointing.py", line 64, in load_dist_checkpointing
[rank0]:     state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/dist_checkpointing/serialization.py", line 161, in load
[rank0]:     loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/lib/python3.12/contextlib.py", line 81, in inner
[rank0]:     return func(*args, **kwds)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/dist_checkpointing/strategies/fully_parallel.py", line 221, in load
[rank0]:     return self.base_strategy.load(sharded_state_dict, checkpoint_dir)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/dist_checkpointing/strategies/torch.py", line 916, in load
[rank0]:     checkpoint.load_state_dict(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/typing_extensions.py", line 3004, in wrapper
[rank0]:     return arg(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 41, in load_state_dict
[rank0]:     return _load_state_dict(
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 234, in _load_state_dict
[rank0]:     central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 219, in reduce_scatter
[rank0]:     raise result
[rank0]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0])
[rank0]: Traceback (most recent call last): (RANK 0)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
[rank0]:     local_data = map_fun()
[rank0]:                  ^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/logger.py", line 87, in wrapper
[rank0]:     result = func(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 223, in local_step
[rank0]:     local_plan = planner.create_local_plan()
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/dist_checkpointing/strategies/torch.py", line 606, in create_local_plan
[rank0]:     self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/megatron/core/dist_checkpointing/strategies/torch.py", line 559, in _validate_global_shapes
[rank0]:     raise KeyError(
[rank0]: KeyError: "decoder.layers.self_attention.linear_proj.weight from model not in state dict: ['chained_0.optimizer.state.exp_avg.decoder.final_layernorm.weight', 'chained_0.optimizer.state.exp_avg.decoder.layers.mlp.router.weight', 'chained_0.optimizer.state.exp_avg.decoder.layers.pre_mlp_layernorm.weight', 'chained_0.optimizer.state.exp_avg.decoder.layers.self_attention.k_layernorm.weight', 'chained_0.optimizer.state.exp_avg.decoder.layers.self_attention.linear_proj.weight', 'chained_0.optimizer.state.exp_avg.decoder.layers.self_attention.linear_qkv.layer_norm_weight', 'chained_0.optimizer.state.exp_avg.decoder.layers.self_attention.linear_qkv.weight', 'chained_0.optimizer.state.exp_avg.decoder.layers.self_attention.q_layernorm.weight', 'chained_0.optimizer.state.exp_avg.embedding.word_embeddings.weight', 'chained_0.optimizer.state.exp_avg.output_layer.weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.final_layernorm.weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.layers.mlp.router.weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.layers.pre_mlp_layernorm.weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.layers.self_attention.k_layernorm.weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.layers.self_attention.linear_proj.weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.layers.self_attention.linear_qkv.layer_norm_weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.layers.self_attention.linear_qkv.weight', 'chained_0.optimizer.state.exp_avg_sq.decoder.layers.self_attention.q_layernorm.weight', 'chained_0.optimizer.state.exp_avg_sq.embedding.word_embeddings.weight', 'chained_0.optimizer.state.exp_avg_sq.output_layer.weight', 'chained_0.optimizer.state.fp32_param.decoder.final_layernorm.weight', 'chained_0.optimizer.state.fp32_param.decoder.layers.mlp.router.weight', 'chained_0.optimizer.state.fp32_param.decoder.layers.pre_mlp_layernorm.weight', 'chained_0.optimizer.state.fp32_param.decoder.layers.self_attention.k_layernorm.weight', 'chained_0.optimizer.state.fp32_param.decoder.layers.self_attention.linear_proj.weight', 'chained_0.optimizer.state.fp32_param.decoder.layers.self_attention.linear_qkv.layer_norm_weight', 'chained_0.optimizer.state.fp32_param.decoder.layers.self_attention.linear_qkv.weight', 'chained_0.optimizer.state.fp32_param.decoder.layers.self_attention.q_layernorm.weight', 'chained_0.optimizer.state.fp32_param.embedding.word_embeddings.weight', 'chained_0.optimizer.state.fp32_param.output_layer.weight', 'chained_1.optimizer.state.exp_avg.decoder.layers.mlp.experts.experts.linear_fc1.weight', 'chained_1.optimizer.state.exp_avg.decoder.layers.mlp.experts.experts.linear_fc2.weight', 'chained_1.optimizer.state.exp_avg_sq.decoder.layers.mlp.experts.experts.linear_fc1.weight', 'chained_1.optimizer.state.exp_avg_sq.decoder.layers.mlp.experts.experts.linear_fc2.weight', 'chained_1.optimizer.state.fp32_param.decoder.layers.mlp.experts.experts.linear_fc1.weight', 'chained_1.optimizer.state.fp32_param.decoder.layers.mlp.experts.experts.linear_fc2.weight', 'rng_state/shard_0.0_1.2', 'rng_state/shard_0.1_1.2']"

Information

  • [x] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [x] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Run any Megatron GRPO/GSPO example on Qwen3 30B A3B MoE, then attempt to merge the model weights to safetensors.

Expected behavior

Produce a safetensors model folder
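
For reference, this is the kind of check I would expect to pass after a successful merge (a minimal sketch; the paths come from the command above, and the exact shard names depend on the model, so they are not asserted here):

import glob
import os

# Paths from the merge command above.
target_dir = "actor_merged"

# A successful merge should leave a HuggingFace-style folder: a config.json
# plus one or more *.safetensors weight shards.
print(sorted(os.listdir(target_dir)))
assert os.path.exists(os.path.join(target_dir, "config.json")), "missing config.json"
assert glob.glob(os.path.join(target_dir, "*.safetensors")), "no safetensors shards were written"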

BaiqingL avatar Dec 01 '25 07:12 BaiqingL

Try with mbridge; a single GPU is enough for exporting. See https://github.com/volcengine/verl/issues/3057#issuecomment-3190788317

ISEEKYAN avatar Dec 01 '25 09:12 ISEEKYAN

You don't need to merge if you train with mbridge. The merged HuggingFace model is already under the huggingface folder.
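
For example, something like this should load it directly (a minimal sketch; the global_step path is taken from the log above and may differ for your run):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# When training with mbridge, the HF model is saved under <step>/actor/huggingface.
path = "global_step_235/actor/huggingface"

tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)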

vermouth1992 avatar Dec 01 '25 11:12 vermouth1992

@vermouth1992 Hello, the saved HF model lacks the weight file. How can this be resolved?

(screenshot attached)
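
For reference, the folder contents can be checked like this (a minimal sketch; the path is an assumption based on the checkpoint layout shown earlier in this issue):

import json
import os

hf_dir = "actor/huggingface"
print(sorted(os.listdir(hf_dir)))

# A sharded HuggingFace checkpoint lists its expected shards in
# model.safetensors.index.json; compare against what is actually on disk.
index_path = os.path.join(hf_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
    with open(index_path) as f:
        expected = set(json.load(f)["weight_map"].values())
    missing = sorted(s for s in expected if not os.path.exists(os.path.join(hf_dir, s)))
    print("missing shards:", missing)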

dtl123456 avatar Dec 03 '25 11:12 dtl123456