
Feat: Support TP for long-context draft model training

Open · yd-oom opened this issue 4 months ago · 2 comments

Motivation

Training draft models for Llama-3.1 (8B and 70B) in offline mode with long context lengths (e.g., 8K, 16K, or 32K) currently fails with out-of-memory (OOM) errors, even on multi-GPU setups.

Modifications

  1. Add tensor parallelism (TP) support in specforge/modeling/draft/llama3_eagle.py.

  2. Rewrite the AllReduce in linear.py to avoid the following warning, which is raised when backpropagating through c10d::allreduce_:

     UserWarning: c10d::allreduce_: an autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior. This behavior is deprecated and will be removed in a future version of PyTorch. If your operator is differentiable, please ensure you have registered an autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd). If your operator is not differentiable, or to squash this warning and use the previous behavior, please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd. (Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)

  3. Add correctness tests: new tests verify that the output of the TP-enabled implementation is numerically identical to that of the original single-GPU implementation.

  4. Implement a robust save_pretrained method in the Eagle3DraftModel base class (specforge/modeling/draft/base.py).
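The warning in item 2 appears when a raw c10d collective sits on the autograd path. A common way to avoid it (the Megatron-LM pattern) is to wrap the collective in a torch.autograd.Function. Below is a minimal sketch that assumes the all-reduce sums the partial outputs of a row-parallel layer. Its forward sums across the TP group, and its backward is the identity because every rank already holds the full upstream gradient. The names _AllReduceSum, all_reduce_sum and demo_single_rank are hypothetical. They are not the code in SpecForge's linear.py.

```python
import os
import tempfile

import torch
import torch.distributed as dist


class _AllReduceSum(torch.autograd.Function):
    """Hypothetical autograd-aware all-reduce for row-parallel layer outputs.

    Forward sums the partial results from every rank in the TP group.
    Backward is the identity: the gradient of the reduced output is already
    replicated on every rank, so no communication is needed.
    """

    @staticmethod
    def forward(ctx, x, group):
        out = x.clone()  # keep the caller's tensor untouched
        dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; the process group argument gets None.
        return grad_output, None


def all_reduce_sum(x, group=None):
    """Hypothetical helper: autograd-safe replacement for a raw dist.all_reduce."""
    return _AllReduceSum.apply(x, group)


def demo_single_rank():
    """Exercise the code path in a 1-rank gloo group (CPU only, no GPUs needed)."""
    init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        x = torch.randn(4, 8, requires_grad=True)
        y = all_reduce_sum(x)
        y.sum().backward()
        return x.detach(), y.detach(), x.grad
    finally:
        dist.destroy_process_group()
```

A full TP stack usually also needs the mirror-image function at the input of column-parallel layers: identity in forward, all-reduce in backward, so that input gradients are summed across ranks.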

Related Issues

#112

Accuracy Test

The correctness of the tensor-parallel implementation was verified on 2 GPUs by comparing the outputs of the attention and MLP layers against the original, non-parallelized model.
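The equivalence being tested can be illustrated on a single device. This sketch assumes a Llama-style SwiGLU MLP that uses the standard Megatron-style split: column-parallel gate/up projections and a row-parallel down projection. The PR does not show exactly how llama3_eagle.py partitions its weights. Summing the per-shard partial outputs stands in for the all-reduce, and the helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def llama_mlp(x, w_gate, w_up, w_down):
    """Llama-style SwiGLU MLP; weights use nn.Linear layout (out_features, in_features)."""
    return F.linear(F.silu(F.linear(x, w_gate)) * F.linear(x, w_up), w_down)


def tp_mlp_simulated(x, w_gate, w_up, w_down, tp_size):
    """Single-process simulation of a TP MLP.

    gate/up are column-parallel (each rank owns a slice of the output features),
    down is row-parallel (each rank owns the matching slice of its input features).
    Summing the per-rank partial outputs stands in for the all-reduce.
    """
    gate_shards = w_gate.chunk(tp_size, dim=0)
    up_shards = w_up.chunk(tp_size, dim=0)
    down_shards = w_down.chunk(tp_size, dim=1)
    partials = [
        llama_mlp(x, g, u, d) for g, u, d in zip(gate_shards, up_shards, down_shards)
    ]
    return torch.stack(partials).sum(dim=0)


torch.manual_seed(0)
hidden, intermediate = 16, 64
x = torch.randn(2, 5, hidden, dtype=torch.float64)
w_gate = torch.randn(intermediate, hidden, dtype=torch.float64)
w_up = torch.randn(intermediate, hidden, dtype=torch.float64)
w_down = torch.randn(hidden, intermediate, dtype=torch.float64)

reference = llama_mlp(x, w_gate, w_up, w_down)
for tp in (1, 2, 4):
    # Sharded result matches the unsharded MLP up to floating-point tolerance.
    assert torch.allclose(reference, tp_mlp_simulated(x, w_gate, w_up, w_down, tp))
```

The split works because SiLU and the elementwise product act independently on each intermediate feature. The check uses allclose rather than exact equality because the all-reduce changes the order of the summation. As a result, bitwise-identical outputs are not guaranteed in general, especially in bf16.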

Benchmark & Profiling

Before (original): training the EAGLE-3 draft model for Llama-3.1-8B with an 8192-token context length on 2× H20 GPUs fails with an OOM error.

    torchrun \
        --standalone \
        --nproc_per_node $NUM_GPUS \
        $ROOT_DIR/scripts/train_eagle3_offline.py \
        --target-model-path /mnt/model/Meta-Llama-3.1-8B-Instruct \
        --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
        --train-data-path $ROOT_DIR/cache/dataset/longwriter.jsonl \
        --train-hidden-states-path $ROOT_DIR/cache/hidden_states/longwriter \
        --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3 \
        --num-epochs 1 \
        --batch-size 1 \
        --learning-rate 1e-4 \
        --max-length 8192 \
        --chat-template llama3 \
        --cache-dir $ROOT_DIR/cache \
        --report-to swanlab \
        --swanlab-project eagle3 \
        --swanlab-key xxx

This run fails with OOM:

    [rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 95.22 GiB of which 302.56 MiB is free. Including non-PyTorch memory, this process has 94.91 GiB memory in use. Of the allocated memory 86.89 GiB is allocated by PyTorch, and 6.61 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
    [rank0]:[W807 04:12:51.599551331 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
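Here is a rough illustration of why long contexts exhaust memory and why sharding attention heads helps. It rests on two assumptions that this PR does not confirm: an eager attention path that materializes a (batch, heads, L, L) score tensor, and a draft attention layer with 32 query heads, as in Llama-3.1-8B. Under those assumptions, one fp32 score tensor at L = 8192 is exactly 8 GiB, which is the size of the failed allocation in the log, although the log does not say which tensor was being allocated. The function name is hypothetical.

```python
def attn_scores_gib(num_heads, seq_len, bytes_per_elem, tp_size=1, batch_size=1):
    """Per-rank size, in GiB, of one eager-attention score tensor (batch, heads, L, L).

    Assumes heads are split evenly across TP ranks, so per-rank memory
    scales as 1 / tp_size and quadratically with seq_len.
    """
    if num_heads % tp_size != 0:
        raise ValueError("num_heads must divide evenly across TP ranks")
    heads_per_rank = num_heads // tp_size
    n_bytes = batch_size * heads_per_rank * seq_len * seq_len * bytes_per_elem
    return n_bytes / 2**30


print(attn_scores_gib(32, 8192, 4))             # 8.0 (fp32, no TP)
print(attn_scores_gib(32, 8192, 4, tp_size=2))  # 4.0 (fp32, TP=2)
print(attn_scores_gib(32, 32768, 4))            # 128.0 (32K context, no TP)
```

Because this term grows quadratically with L, going from 8K to 32K multiplies it by 16, while splitting heads across TP ranks divides it by tp_size.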

After (with --tp-size added):

    torchrun \
        --standalone \
        --nproc_per_node $NUM_GPUS \
        $ROOT_DIR/scripts/train_eagle3_offline.py \
        --target-model-path /mnt/model/Meta-Llama-3.1-8B-Instruct/main \
        --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
        --train-data-path $ROOT_DIR/cache/dataset/longwriter.jsonl \
        --train-hidden-states-path $ROOT_DIR/cache/hidden_states/longwriter \
        --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3 \
        --num-epochs 1 \
        --batch-size 1 \
        --learning-rate 1e-4 \
        --max-length 8192 \
        --chat-template llama3 \
        --cache-dir $ROOT_DIR/cache \
        --tp-size $NUM_GPUS

With --tp-size set, training runs successfully without OOM.
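In the TP run, --tp-size equals --nproc_per_node, so every rank belongs to a single TP group and the data-parallel degree is 1. The sketch below shows how ranks could be divided when tp_size is smaller than the world size. It assumes the common Megatron-style convention: consecutive ranks form a TP group, and ranks sharing a TP index form a DP group. SpecForge's actual grouping is not shown in this PR, and tp_dp_layout is a hypothetical name.

```python
def tp_dp_layout(world_size, tp_size):
    """Hypothetical rank layout for tensor + data parallelism.

    Consecutive ranks form a TP group; ranks with the same position inside
    their TP group form a DP group.
    """
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(dp_size)]
    dp_groups = [list(range(j, world_size, tp_size)) for j in range(tp_size)]
    return tp_groups, dp_groups


# 2 GPUs with --tp-size 2: one TP group, no data parallelism.
print(tp_dp_layout(2, 2))  # ([[0, 1]], [[0], [1]])
# 8 GPUs with TP=2: four TP pairs, two DP groups of four ranks each.
print(tp_dp_layout(8, 2))
```

In a real run, each rank list would be passed to torch.distributed.new_group(ranks=...) to create the corresponding communicator.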


Todo

Add comprehensive benchmark results for several TP training scenarios.

Checklist

  • [ ] Format your code according to Code Formatting with Pre-Commit.
  • [ ] Add unit tests as outlined in Running Unit Tests.
  • [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • [ ] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.

yd-oom · Aug 06 '25 20:08

@yd-oom This feature is really exciting! Could you please resolve the conflicts? Also, did you test it with Llama 3.1? Is the acceptance length good?

zyksir · Sep 09 '25 08:09
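The acceptance length asked about above is the usual speculative-decoding quality metric. Here is a minimal sketch of one common definition, which is assumed here because the thread does not define it: the average number of tokens produced per draft-and-verify step, counting the accepted draft tokens plus the one token the target model contributes at each step. The function name is hypothetical.

```python
def average_acceptance_length(accepted_per_step):
    """accepted_per_step[i]: draft tokens accepted by the target at verify step i.

    Each step also yields one token from the target model itself, so the
    tokens produced at step i are accepted_per_step[i] + 1.
    """
    if not accepted_per_step:
        raise ValueError("need at least one verification step")
    return sum(n + 1 for n in accepted_per_step) / len(accepted_per_step)


print(average_acceptance_length([3, 1, 0, 4]))  # 3.0
```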

@zyksir Hi, the conflicts are resolved. This was tested on Llama 3.1 8B: after two epochs on ShareGPT, the results with TP=2 are identical to the non-TP baseline.

[Figure: speedup ratio]

Our team has been using this feature internally for a month.

yd-oom · Sep 18 '25 11:09