[Bug] Model weights saved incompletely under multi-TP training
Checklist
- [x] 1. I have searched related issues but cannot get the expected help.
- [x] 2. The bug has not been fixed in the latest version.
- [ ] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- [ ] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose. Otherwise, it will be closed.
- [x] 5. Please use English, otherwise it will be closed.
Describe the bug
Bug Description
When training with multiple TP ranks (e.g., --tp-size 4), the saved model weights are incomplete: each saved weight contains only a single TP rank's shard instead of the full tensor.
Model config: hidden_size=2048, intermediate_size=12288, num_attention_heads=32
Expected vs Actual Behavior
Expected (full model weights):
midlayer.mlp.down_proj.weight: torch.Size([2048, 12288])
midlayer.mlp.gate_proj.weight: torch.Size([12288, 2048])
midlayer.mlp.up_proj.weight: torch.Size([12288, 2048])
midlayer.self_attn.k_proj.weight: torch.Size([512, 4096])
midlayer.self_attn.o_proj.weight: torch.Size([2048, 4096])
midlayer.self_attn.q_proj.weight: torch.Size([4096, 4096])
midlayer.self_attn.v_proj.weight: torch.Size([512, 4096])
Actual (incomplete/sharded weights):
midlayer.mlp.down_proj.weight: torch.Size([2048, 3072])
midlayer.mlp.gate_proj.weight: torch.Size([3072, 2048])
midlayer.mlp.up_proj.weight: torch.Size([3072, 2048])
midlayer.self_attn.k_proj.weight: torch.Size([128, 4096])
midlayer.self_attn.o_proj.weight: torch.Size([2048, 1024])
midlayer.self_attn.q_proj.weight: torch.Size([1024, 4096])
midlayer.self_attn.v_proj.weight: torch.Size([128, 4096])
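Every "actual" shape above is the expected shape with exactly one dimension divided by the TP size (4). q/k/v_proj and gate/up_proj are split along dim 0 (output features, column-parallel). o_proj and down_proj are split along dim 1 (input features, row-parallel). The minimal pure-Python sketch below reproduces the "actual" listing from the "expected" one. The `SPLIT_DIM` mapping is an assumption inferred from these shapes, not taken from the SpecForge source, and `shard_shape` is a hypothetical helper.

```python
# Hypothetical mapping inferred from the shapes in this issue (not SpecForge code):
# 0 = column-parallel (output dim split), 1 = row-parallel (input dim split).
SPLIT_DIM = {
    "midlayer.mlp.down_proj.weight": 1,
    "midlayer.mlp.gate_proj.weight": 0,
    "midlayer.mlp.up_proj.weight": 0,
    "midlayer.self_attn.k_proj.weight": 0,
    "midlayer.self_attn.o_proj.weight": 1,
    "midlayer.self_attn.q_proj.weight": 0,
    "midlayer.self_attn.v_proj.weight": 0,
}

# Full (unsharded) shapes from the "Expected" listing.
FULL_SHAPES = {
    "midlayer.mlp.down_proj.weight": (2048, 12288),
    "midlayer.mlp.gate_proj.weight": (12288, 2048),
    "midlayer.mlp.up_proj.weight": (12288, 2048),
    "midlayer.self_attn.k_proj.weight": (512, 4096),
    "midlayer.self_attn.o_proj.weight": (2048, 4096),
    "midlayer.self_attn.q_proj.weight": (4096, 4096),
    "midlayer.self_attn.v_proj.weight": (512, 4096),
}


def shard_shape(full_shape, tp_size, split_dim):
    """Shape held by one TP rank when `split_dim` is divided evenly across ranks."""
    if full_shape[split_dim] % tp_size != 0:
        raise ValueError(f"dim {split_dim} of {full_shape} is not divisible by {tp_size}")
    shape = list(full_shape)
    shape[split_dim] //= tp_size
    return tuple(shape)


if __name__ == "__main__":
    # With tp_size=4 this prints the shapes seen in the "Actual" listing.
    for name, full in FULL_SHAPES.items():
        print(name, full, "->", shard_shape(full, 4, SPLIT_DIM[name]))
```

So a checkpoint whose shapes match `shard_shape(..., 4, ...)` rather than the config is a strong sign that one rank's local parameters were written without first being gathered.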
Impact
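This can cause SGLang to throw an error when loading the Eagle3 draft model weights, because the saved tensor shapes no longer match the shapes implied by the draft model config.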
Reproduction
Target model: Qwen3-30B-A3B
Environment
Hi, can you share the command you used to run your script? I will try to reproduce this.
I tried PR #117 for long-context offline training. My script is:
#!/bin/bash
ROOT_DIR="specforge_dir"
export PYTHONPATH=$ROOT_DIR:$PYTHONPATH
MAX_LENGTH=10240
TARGET_MODEL_PATH=""
OUTPUT_PATH=""
TRAIN_DATA_PATH=""
TRAIN_HIDDEN_PATH=""
HS_CACHE_PATH=""
NUM_GPUS=16
TP=4
torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3_offline.py \
--target-model-path $TARGET_MODEL_PATH \
--draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \
--train-data-path $TRAIN_DATA_PATH \
--train-hidden-states-path $TRAIN_HIDDEN_PATH \
--output-dir $OUTPUT_PATH \
--num-epochs 10 \
--batch-size 1 \
--learning-rate 1e-4 \
--max-length $MAX_LENGTH \
--chat-template qwen \
--cache-dir $HS_CACHE_PATH \
--embedding-key model.embed_tokens.weight \
--tp-size $TP \
--save-interval 1
It's possible that when the draft model is trained with TP > 1, the weight-saving logic needs to be modified to save the full weights rather than each rank's local shard.
Possibly. I suspect it might be related to CP. Let me verify this.
@Qin10 I updated the save_pretrained method in the Eagle3DraftModel base class (specforge/modeling/draft/base.py). You can try the updated #117.
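In general terms, a TP-aware save has to reassemble each sharded weight before writing it. The per-rank shards are concatenated, in rank order, along the same dimension they were split on. In a real multi-GPU job this would typically use `torch.distributed.all_gather` followed by `torch.cat(shards, dim=split_dim)`, with only one rank writing the file. The sketch below is not the change made in #117. It shows only the reassembly logic on nested Python lists, so it runs without GPUs or torch. `concat_shards` and `split_weight` are hypothetical helpers.

```python
def split_weight(full, tp_size, split_dim):
    """Hypothetical helper: the 2-D block each TP rank holds, in rank order."""
    size = len(full) if split_dim == 0 else len(full[0])
    if size % tp_size != 0:
        raise ValueError(f"dim {split_dim} (size {size}) is not divisible by {tp_size}")
    step = size // tp_size
    if split_dim == 0:
        # Column-parallel: each rank owns a contiguous block of rows.
        return [full[r * step:(r + 1) * step] for r in range(tp_size)]
    # Row-parallel: each rank owns a contiguous block of columns in every row.
    return [[row[r * step:(r + 1) * step] for row in full] for r in range(tp_size)]


def concat_shards(shards, split_dim):
    """Hypothetical helper: rebuild the full 2-D weight from per-rank shards."""
    if split_dim == 0:
        # Stack the row blocks of each rank one after another.
        return [row for shard in shards for row in shard]
    if split_dim == 1:
        n_rows = len(shards[0])
        if any(len(s) != n_rows for s in shards):
            raise ValueError("all shards must have the same number of rows")
        # Join row i of every rank side by side.
        return [[x for shard in shards for x in shard[i]] for i in range(n_rows)]
    raise ValueError("split_dim must be 0 or 1")


if __name__ == "__main__":
    full = [[r * 8 + c for c in range(8)] for r in range(4)]  # toy 4x8 weight
    for dim in (0, 1):
        shards = split_weight(full, tp_size=4, split_dim=dim)
        # Saving shards[rank] alone reproduces this bug; concatenating restores the weight.
        assert concat_shards(shards, dim) == full
    print("round trip ok")
```

The key design point is that the concatenation dimension must match the split dimension per layer: dim 0 for q/k/v/gate/up and dim 1 for o_proj/down_proj, judging by the shapes in this issue.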