[Bug] Eagle3 training for gpt-oss-120b fails with OOM
Checklist
- [x] 1. I have searched related issues but cannot get the expected help.
- [x] 2. The bug has not been fixed in the latest version.
- [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- [x] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose Otherwise, it will be closed.
- [x] 5. Please use English, otherwise it will be closed.
Describe the bug
I am trying to train Eagle3 heads for gpt-oss-120b on a single H100 node with NUM_GPUS=8. When I run run_gpt_oss_120b_eagle3_online.sh as is, I get the error below:
[rank2]: server_args = ServerArgs(
[rank2]: ^^^^^^^^^^^
[rank2]: File "<string>", line 275, in __init__
[rank2]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/server_args.py", line 589, in __post_init__
[rank2]: self._handle_model_specific_adjustments()
[rank2]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/server_args.py", line 958, in _handle_model_specific_adjustments
[rank2]: prefill_attn_backend in supported_backends
[rank2]: AssertionError: GptOssForCausalLM requires one of ['triton', 'trtllm_mha', 'fa3', 'fa4'] attention backend, but got the following backends
[rank2]: - Prefill: flashinfer
[rank2]: - Decode: flashinfer
To get past the assertion, I removed the attention_backend key from the kwargs passed to ServerArgs. That resulted in the OOM error shown after the sketch below.
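Roughly, the change looked like this; the variable names and values here are only illustrative of how the kwargs reach ServerArgs, not the exact code in SpecForge:

```python
from sglang.srt.server_args import ServerArgs

# Illustrative kwargs; the real dict is assembled inside SpecForge's sglang
# target-model wrapper before ServerArgs is constructed.
kwargs = dict(
    model_path="/mnt/models/gpt-oss-120b",
    tp_size=8,
    attention_backend="flashinfer",  # default that trips the GptOssForCausalLM assertion
)

# Workaround: drop the key so sglang picks a backend it supports for gpt-oss.
kwargs.pop("attention_backend", None)
server_args = ServerArgs(**kwargs)
```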
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 23.79 GiB. GPU 0 has a total capacity of 79.19 GiB of which 15.15 GiB is free. Including non-PyTorch memory, this process has 64.03 GiB memory in use. Of the allocated memory 59.10 GiB is allocated by PyTorch, and 1.33 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
I still get an OOM even if I use --target-model-backend hf.
Reproduction
cd /mnt/git/SpecForge/
pip install -r requirements.txt
pip install -e .
EXP_NAME=test-gpt-oss-120b
TARGET_MODEL_PATH=/mnt/models/gpt-oss-120b
EXP_PATH=/mnt/git/SpecForge/exp/$EXP_NAME
NUM_GPUS=8
MAX_LENGTH=8192
CHAT_TEMPLATE=gpt-oss-naive
python scripts/build_eagle3_dataset_cache.py \
--target-model-path $TARGET_MODEL_PATH \
--draft-model-config ./configs/gpt-oss-120B-eagle3.json \
--train-data-path $EXP_PATH/dataset/all_train.jsonl \
--cache-dir $EXP_PATH/cache/ \
--chat-template $CHAT_TEMPLATE \
--max-length $MAX_LENGTH
torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
scripts/train_eagle3.py \
--target-model-path $TARGET_MODEL_PATH \
--draft-model-config ./configs/gpt-oss-120B-eagle3.json \
--train-data-path $EXP_PATH/dataset/all_train.jsonl \
--output-dir $EXP_PATH/outputs \
--tp-size 8 \
--num-epochs 10 \
--batch-size 1 \
--learning-rate 1e-4 \
--max-length $MAX_LENGTH \
--chat-template $CHAT_TEMPLATE \
--cache-dir $EXP_PATH/cache/ \
--target-model-backend sglang \
--dist-timeout 60
Environment
Main branch of https://github.com/sgl-project/SpecForge
Maybe you can add a parameter such as --attention-backend fa3 to scripts/train_eagle3.py.
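If such a flag were added, it might look roughly like the sketch below; the flag does not exist in scripts/train_eagle3.py today, so the name, default, and choices are only a suggestion (the choices mirror the assertion message above):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    type=str,
    default="fa3",
    choices=["triton", "trtllm_mha", "fa3", "fa4"],  # backends accepted for GptOssForCausalLM
    help="Attention backend forwarded to sglang's ServerArgs",
)

# Example: parse the proposed flag and hand it to the ServerArgs kwargs.
args = parser.parse_args(["--attention-backend", "fa3"])
print(args.attention_backend)  # -> "fa3"
```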
Thanks for looking into this. Setting attention_backend to fa3 in ServerArgs still gives an OOM:
[rank4]: Traceback (most recent call last):
[rank4]: File "/mnt/git/SpecForge/scripts/train_eagle3.py", line 775, in <module>
[rank4]: main()
[rank4]: File "/mnt/git/SpecForge/scripts/train_eagle3.py", line 701, in main
[rank4]: plosses, acces = run_forward(
[rank4]: ^^^^^^^^^^^^
[rank4]: File "/mnt/git/SpecForge/scripts/train_eagle3.py", line 489, in run_forward
[rank4]: eagle3_data = target_model.generate_eagle3_data(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank4]: return func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/mnt/git/SpecForge/specforge/modeling/target/eagle3_target_model.py", line 376, in generate_eagle3_data
[rank4]: self.extend(
[rank4]: File "/mnt/git/SpecForge/specforge/modeling/target/eagle3_target_model.py", line 350, in extend
[rank4]: logits_list, aux_hidden_states_list, last_hidden_states_list = self._extend(
[rank4]: ^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank4]: return func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/mnt/git/SpecForge/specforge/modeling/target/eagle3_target_model.py", line 268, in _extend
[rank4]: eagle3_output, _ = self.model_runner.forward(forward_batch)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/model_executor/model_runner.py", line 2147, in forward
[rank4]: output = self._forward_raw(
[rank4]: ^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/model_executor/model_runner.py", line 2204, in _forward_raw
[rank4]: ret = self.forward_extend(
[rank4]: ^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/model_executor/model_runner.py", line 2092, in forward_extend
[rank4]: return self.model.forward(
[rank4]: ^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank4]: return func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/models/gpt_oss.py", line 624, in forward
[rank4]: return self.logits_processor(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank4]: return self._call_impl(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank4]: return forward_call(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/mnt/git/SpecForge/specforge/modeling/target/sglang_backend/utils.py", line 142, in forward
[rank4]: ret = replaced_logits_processor_forward_for_eagle3(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/mnt/git/SpecForge/specforge/modeling/target/sglang_backend/utils.py", line 78, in replaced_logits_processor_forward_for_eagle3
[rank4]: logits = self._get_logits(pruned_states, lm_head, logits_metadata)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/layers/logits_processor.py", line 881, in _get_logits
[rank4]: logits = tensor_model_parallel_all_gather(logits)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/distributed/communication_op.py", line 20, in tensor_model_parallel_all_gather
[rank4]: return get_tp_group().all_gather(input_, dim)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/usr/local/lib/python3.12/dist-packages/sglang/srt/distributed/parallel_state.py", line 817, in all_gather
[rank4]: output_tensor = output_tensor.reshape(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^
[rank4]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 23.79 GiB. GPU 4 has a total capacity of 79.19 GiB of which 15.15 GiB is free. Including non-PyTorch memory, this process has 64.03 GiB memory in use. Of the allocated memory 59.10 GiB is allocated by PyTorch, and 1.33 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Regarding the memory issue, you might try the --sglang-mem-fraction-static parameter to reduce memory usage; see the sketch after this comment. Since I train very large models offline, these are the only suggestions I have.
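For example, assuming --sglang-mem-fraction-static is forwarded to sglang's ServerArgs.mem_fraction_static, a lower value pre-reserves less memory for the KV cache and leaves headroom for the logits all-gather in the traceback above (the value below is only an example to tune per node):

```python
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="/mnt/models/gpt-oss-120b",
    tp_size=8,
    attention_backend="fa3",
    mem_fraction_static=0.6,  # example value; the default reserves considerably more
)
```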
I met the same error, but I want to train Qwen3 4B.
You can use the hf backend.
At least for gpt-oss-120b, using --target-model-backend hf did not work for me, but it may work for a smaller model.
Thanks, I succeeded.