Feat: Support TP for long-context draft model training
## Motivation
Training Llama-3.1 models (8B and 70B) in offline mode with long context lengths (e.g., 8K, 16K, or 32K) currently fails with Out-of-Memory (OOM) errors, even on multi-GPU setups.
## Modifications
- Add tensor-parallel (TP) support in `specforge/modeling/draft/llama3_eagle.py`.
- Rewrite the AllReduce in `linear.py` to avoid the PyTorch `UserWarning` that `c10d::allreduce_` has no autograd kernel registered for the Autograd dispatch key, meaning backpropagation through it may be silently incorrect and is deprecated behavior (see the sketch after this list).
- Add correctness tests that verify the output of the TP-enabled implementation is numerically identical to the original single-GPU implementation (a test sketch follows the list).
- Implement a robust `save_pretrained` method in the `Eagle3DraftModel` base class (`specforge/modeling/draft/base.py`); see the sketch after this list.
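For reference, here is a minimal sketch of how an AllReduce can be wrapped in a `torch.autograd.Function` so that gradients flow without triggering the unregistered-autograd-kernel warning. The class and function names are illustrative, not the actual names used in `linear.py`:

```python
import torch
import torch.distributed as dist


class _AllReduceSum(torch.autograd.Function):
    """All-reduce (sum) in forward, identity in backward.

    Calling dist.all_reduce inside an autograd.Function keeps the raw
    c10d::allreduce_ op out of the recorded autograd graph, so PyTorch no
    longer warns about backpropagating through an op without an autograd
    kernel. For a row-parallel linear output, the gradient w.r.t. each
    rank's partial sum is just the upstream gradient, so backward is a
    pass-through.
    """

    @staticmethod
    def forward(ctx, tensor, group):
        tensor = tensor.clone()  # avoid in-place mutation of an autograd-tracked input
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def tp_all_reduce(tensor, group=None):
    """Differentiable sum all-reduce across the tensor-parallel group (hypothetical helper)."""
    if not dist.is_initialized() or dist.get_world_size(group) == 1:
        return tensor
    return _AllReduceSum.apply(tensor, group)
```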
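A hedged, self-contained sketch of the kind of numerical-equivalence check the correctness tests aim at, here comparing a column-parallel matmul against a single-GPU `nn.Linear` reference (this is not the actual test code; shapes and tolerances are placeholders):

```python
import torch
import torch.distributed as dist
import torch.nn as nn

# Launch with: torchrun --nproc_per_node 2 test_tp_linear.py

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)
    torch.manual_seed(0)  # same seed on every rank -> identical reference weights and inputs

    ref = nn.Linear(1024, 4096, bias=False).cuda()
    x = torch.randn(2, 8, 1024, device="cuda")

    # Shard the reference weight along the output dimension for this rank.
    shard = ref.weight.chunk(world_size, dim=0)[rank].contiguous()

    # Column-parallel forward: local matmul, then all-gather the output shards.
    local_out = x @ shard.t()
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    tp_out = torch.cat(gathered, dim=-1)

    torch.testing.assert_close(tp_out, ref(x), rtol=1e-5, atol=1e-5)
    if rank == 0:
        print("TP output matches single-GPU reference")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```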
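For the `save_pretrained` change, a checkpoint written from a TP run has to reassemble full weights from the per-rank shards before a single-GPU load can succeed. A rough sketch of that idea follows; the helper name, sharding metadata, and file layout are assumptions, not the actual `base.py` API:

```python
import torch
import torch.distributed as dist


def gather_tp_weight(local_weight: torch.Tensor, shard_dim: int, group=None) -> torch.Tensor:
    """Reassemble a full weight tensor from per-rank TP shards (hypothetical helper).

    shard_dim is 0 for column-parallel layers (output dim sharded) and
    1 for row-parallel layers (input dim sharded).
    """
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return local_weight
    shards = [torch.empty_like(local_weight) for _ in range(world_size)]
    dist.all_gather(shards, local_weight.contiguous(), group=group)
    return torch.cat(shards, dim=shard_dim)

# Inside a save_pretrained-style method, only rank 0 would write to disk, e.g.:
#   full_state = {name: gather_tp_weight(p, shard_dim_for(name)) for name, p in sharded_params}
#   if dist.get_rank() == 0:
#       torch.save(full_state, os.path.join(save_dir, "pytorch_model.bin"))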
## Related Issues
#112
## Accuracy Test
## Benchmark & Profiling
Before (original, no TP): training Llama-3.1-8B with an 8192 context length on 2×H20 GPUs fails with an OOM error.
```bash
torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_offline.py \
    --target-model-path /mnt/model/Meta-Llama-3.1-8B-Instruct \
    --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/longwriter.jsonl \
    --train-hidden-states-path $ROOT_DIR/cache/hidden_states/longwriter \
    --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3 \
    --num-epochs 1 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 8192 \
    --chat-template llama3 \
    --cache-dir $ROOT_DIR/cache \
    --report-to swanlab \
    --swanlab-project eagle3 \
    --swanlab-key xxx
```
This run fails with an OOM error:

```
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 95.22 GiB of which 302.56 MiB is free. Including non-PyTorch memory, this process has 94.91 GiB memory in use. Of the allocated memory 86.89 GiB is allocated by PyTorch, and 6.61 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank0]:[W807 04:12:51.599551331 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
```
After (with `--tp-size`):

```bash
torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_offline.py \
    --target-model-path /mnt/model/Meta-Llama-3.1-8B-Instruct/main \
    --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/longwriter.jsonl \
    --train-hidden-states-path $ROOT_DIR/cache/hidden_states/longwriter \
    --output-dir $ROOT_DIR/outputs/llama3-8b-eagle3 \
    --num-epochs 1 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 8192 \
    --chat-template llama3 \
    --cache-dir $ROOT_DIR/cache \
    --tp-size $NUM_GPUS
```

This run completes successfully.
## Todo
- Add comprehensive benchmark results for several TP training scenarios.
## Checklist
- [ ] Format your code according to the Code Formatting with Pre-Commit.
- [ ] Add unit tests as outlined in the Running Unit Tests.
- [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
- [ ] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.
@yd-oom This feature is really exciting! Could you please resolve the conflicts? Also, did you test it with Llama 3.1 8B? Is the accept length good?
@zyksir Hi, the conflicts are resolved. This was tested on Llama 3.1 8B: with TP=2, the results are identical to the non-TP baseline after two epochs on ShareGPT.
Our team has been using this feature internally for a month.