
[Bug] Model weights saved incompletely under multi-TP training

Open · Qin10 opened this issue 5 months ago · 4 comments

Checklist

  • [x] 1. I have searched related issues but cannot get the expected help.
  • [x] 2. The bug has not been fixed in the latest version.
  • [ ] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • [ ] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose. Otherwise, it will be closed.
  • [x] 5. Please use English, otherwise it will be closed.

Describe the bug

Bug Description

When training with multiple TP ranks (e.g., --tp-size 4), the saved model weights are incomplete: each checkpoint contains only a single TP rank's shard instead of the full model.

Model config: hidden_size=2048, intermediate_size=12288, num_attention_heads=32

Expected vs Actual Behavior

Expected (full model weights):

midlayer.mlp.down_proj.weight: torch.Size([2048, 12288])
midlayer.mlp.gate_proj.weight: torch.Size([12288, 2048])
midlayer.mlp.up_proj.weight: torch.Size([12288, 2048])
midlayer.self_attn.k_proj.weight: torch.Size([512, 4096])
midlayer.self_attn.o_proj.weight: torch.Size([2048, 4096])
midlayer.self_attn.q_proj.weight: torch.Size([4096, 4096])
midlayer.self_attn.v_proj.weight: torch.Size([512, 4096])

Actual (incomplete/sharded weights):

midlayer.mlp.down_proj.weight: torch.Size([2048, 3072])
midlayer.mlp.gate_proj.weight: torch.Size([3072, 2048])
midlayer.mlp.up_proj.weight: torch.Size([3072, 2048])
midlayer.self_attn.k_proj.weight: torch.Size([128, 4096])
midlayer.self_attn.o_proj.weight: torch.Size([2048, 1024])
midlayer.self_attn.q_proj.weight: torch.Size([1024, 4096])
midlayer.self_attn.v_proj.weight: torch.Size([128, 4096])
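
For reference, every "actual" shape above is exactly the full shape with its tensor-parallel dimension divided by tp_size=4 (12288/4 = 3072, 4096/4 = 1024, 512/4 = 128), i.e. each rank appears to have saved only its own shard. A minimal arithmetic sketch (the column-/row-parallel split assignment below is the usual Megatron-style convention, assumed here rather than taken from SpecForge's code):

# Sketch: check that the saved shapes equal a single TP rank's shard (tp_size = 4).
# Column-parallel layers (gate/up/q/k/v) are assumed to be split on dim 0,
# row-parallel layers (down/o) on dim 1.
TP = 4
full_shapes = {
    "mlp.down_proj.weight":    ((2048, 12288), 1),  # (full shape, sharded dim)
    "mlp.gate_proj.weight":    ((12288, 2048), 0),
    "mlp.up_proj.weight":      ((12288, 2048), 0),
    "self_attn.k_proj.weight": ((512, 4096), 0),
    "self_attn.o_proj.weight": ((2048, 4096), 1),
    "self_attn.q_proj.weight": ((4096, 4096), 0),
    "self_attn.v_proj.weight": ((512, 4096), 0),
}
for name, (shape, dim) in full_shapes.items():
    shard = list(shape)
    shard[dim] //= TP
    print(name, tuple(shard))  # matches the "actual" sizes listed above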

Impact

It may cause SGLang to throw an error when loading the Eagle model weights (see the error screenshot attached in the original issue).

Reproduction

Qwen3-30B-A3B

Environment

Qin10 · Aug 10 '25

Hi, can you provide the command for your script? I will try to reproduce this.

FrankLeeeee · Aug 10 '25

> Hi, can you provide the command for your script? I will try to reproduce this.

I tried PR #117 for long-context offline training; my script is:

#!/bin/bash

ROOT_DIR="specforge_dir"
export PYTHONPATH=$ROOT_DIR:$PYTHONPATH


MAX_LENGTH=10240

# Fill in local paths (left blank here)
TARGET_MODEL_PATH=""
OUTPUT_PATH=""

TRAIN_DATA_PATH=""
TRAIN_HIDDEN_PATH=""

HS_CACHE_PATH=""


NUM_GPUS=16   # total processes launched by torchrun
TP=4          # tensor-parallel size passed as --tp-size


torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_offline.py \
    --target-model-path $TARGET_MODEL_PATH \
    --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.json \
    --train-data-path $TRAIN_DATA_PATH \
    --train-hidden-states-path $TRAIN_HIDDEN_PATH \
    --output-dir $OUTPUT_PATH \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length $MAX_LENGTH \
    --chat-template qwen \
    --cache-dir $HS_CACHE_PATH \
    --embedding-key model.embed_tokens.weight \
    --tp-size $TP \
    --save-interval 1

It’s possible that the model weight saving logic needs to be modified when the draft model is trained with multiple TP ranks.
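
If that's the cause, the general fix is to gather each sharded parameter across the TP group before writing the checkpoint. A minimal sketch of the idea (assuming torch.distributed is already initialized and tp_group is the tensor-parallel process group; this is not SpecForge's actual implementation):

import torch
import torch.distributed as dist

def gather_tp_weight(local_weight: torch.Tensor, shard_dim: int, tp_group) -> torch.Tensor:
    """Rebuild the full weight by all-gathering the per-rank shards along shard_dim."""
    world_size = dist.get_world_size(tp_group)
    shards = [torch.empty_like(local_weight) for _ in range(world_size)]
    dist.all_gather(shards, local_weight.contiguous(), group=tp_group)
    return torch.cat(shards, dim=shard_dim)

# Before saving (e.g. on TP rank 0 only): column-parallel weights such as gate/up/q/k/v
# would use shard_dim=0, row-parallel weights such as down/o would use shard_dim=1.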

Qin10 · Aug 12 '25

Possible, I guess it might be due to CP. Let me verify this.

FrankLeeeee · Aug 13 '25

@Qin10 I updated the save_pretrained method in the Eagle3DraftModel base class (specforge/modeling/draft/base.py). You can try the new #117.
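
If anyone wants to double-check the new checkpoints, one quick way is to print the tensor shapes in the saved file and compare them with the full sizes above. A small sketch, assuming the draft model is saved as a safetensors file (the path below is only an example):

from safetensors import safe_open

ckpt = "output_dir/epoch_0/model.safetensors"  # hypothetical checkpoint path

with safe_open(ckpt, framework="pt") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))
# e.g. midlayer.mlp.gate_proj.weight should now be (12288, 2048), not (3072, 2048).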

yd-oom · Aug 18 '25