
[Bug] TP=16 or TP=32 fails

Open · fan-niu opened this issue 4 months ago · 3 comments

Checklist

  • [x] 1. I have searched related issues but cannot get the expected help.
  • [x] 2. The bug has not been fixed in the latest version.
  • [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • [x] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose Otherwise, it will be closed.
  • [x] 5. Please use English, otherwise it will be closed.

Describe the bug

When I train Llama 3.1 70B on 32 H100 GPUs with TP=16 or TP=32, I get the following error.

[rank5]: Traceback (most recent call last):
[rank5]:   File "SpecForge/scripts/train_eagle3_online.py", line 673, in <module>
[rank5]:     main()
[rank5]:   File "SpecForge/scripts/train_eagle3_online.py", line 516, in main
[rank5]:     plosses, _, acces = eagle3_model(
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in f
orward
[rank5]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "SpecForge/specforge/core/eagle3.py", line 168, in forward
[rank5]:     hidden_states, target, loss_mask, input_ids = self._prepare_data(
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank5]:     return func(*args, **kwargs)
[rank5]:   File "SpecForge/specforge/core/eagle3.py", line 99, in _prepare_data
[rank5]:     outputs = self.target_model(
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank5]:     output = func(self, *args, **kwargs)
[rank5]:   File "SpecForge/specforge/modeling/target/llama.py", line 534, in forward
[rank5]:     outputs: BaseModelOutputWithPast = self.model(
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank5]:     output = func(self, *args, **kwargs)
[rank5]:   File "SpecForge/specforge/modeling/target/llama.py", line 425, in forward
[rank5]:     layer_outputs = decoder_layer(
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
[rank5]:     return super().__call__(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "SpecForge/specforge/modeling/target/llama.py", line 258, in forward
[rank5]:     hidden_states, self_attn_weights = self.self_attn(
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "SpecForge/specforge/modeling/target/llama.py", line 180, in forward
[rank5]:     key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
[rank5]: RuntimeError: shape '[1, 1234, -1, 128]' is invalid for input of size 78976

78976 = 1234 * 64, i.e. k_proj produced only 64 features per token on this rank, which cannot be viewed as [..., 128]. Llama 3.1 70B has 8 KV heads with head_dim = 128, so with TP=16 each rank gets 8 * 128 / 16 = 64 KV features per token, half a head; it looks like KV heads are not replicated (or padded) when tp_size exceeds num_key_value_heads.
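To spell out the arithmetic (assuming the stock Llama 3.1 70B config: 64 attention heads, 8 KV heads, head_dim 128), a quick sketch reproduces the failing input size exactly:

# Arithmetic behind the error, assuming the standard Llama 3.1 70B config.
num_kv_heads = 8     # config.num_key_value_heads for Llama 3.1 70B
head_dim = 128       # hidden_size // num_attention_heads = 8192 // 64
seq_len = 1234       # --max-length from the repro command below

for tp_size in (8, 16, 32):
    # features per token that k_proj emits on one TP rank
    per_rank = num_kv_heads * head_dim // tp_size
    total = seq_len * per_rank
    whole_heads = per_rank % head_dim == 0  # can it be viewed as [..., head_dim]?
    print(f"tp={tp_size}: {per_rank} features/token, input size {total}, whole heads: {whole_heads}")

# tp=8:  128 features/token -> one whole KV head per rank, view succeeds
# tp=16:  64 features/token -> input size 78976, view([1, 1234, -1, 128]) fails
# tp=32:  32 features/token -> fails for the same reason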

Reproduction

torchrun \
    --nnodes 4 \
    --nproc_per_node=8 \
    --node_rank 0 \
    --master_addr 127.0.0.1 \
    --master_port 8080 \
    $ROOT_DIR/scripts/train_eagle3_online.py \
    --target-model-path $model_path \
    --draft-model-config $ROOT_DIR/configs/llama31-70B-eagle3.json \
    --train-data-path $train_set \
    --eval-data-path $test_set \
    --output-dir $ROOT_DIR/outputs/ \
    --num-epochs 30 \
    --draft-global-batch-size 32 \
    --draft-micro-batch-size 1 \
    --learning-rate 2e-4 \
    --max-length 1234 \
    --ttt-length 7 \
    --chat-template llama31 \
    --cache-dir $ROOT_DIR/cache \
    --attention-backend flex_attention \
    --warmup-ratio 0.02 \
    --eval-interval 1 \
    --save-interval 1 \
    --seed 1234 \
    --dist-timeout 3600000 \
    --report-to tensorboard \
    --tp-size 16
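Until the sharding is fixed, a small pre-flight check along these lines (a hypothetical helper, not part of SpecForge) would fail fast before the 70B weights are even loaded:

# Hypothetical pre-flight check (not part of SpecForge): reject TP sizes
# that would give each rank a fractional KV head.
from transformers import AutoConfig

def assert_tp_compatible(model_path: str, tp_size: int) -> None:
    cfg = AutoConfig.from_pretrained(model_path)
    num_kv_heads = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads)
    if num_kv_heads % tp_size != 0:
        raise ValueError(
            f"tp_size={tp_size} does not evenly divide "
            f"num_key_value_heads={num_kv_heads}; each rank would get a "
            f"fractional KV head. For Llama 3.1 70B (8 KV heads), tp_size "
            f"must be one of 1, 2, 4, 8 unless KV heads are replicated."
        )

# Usage (model_path = the --target-model-path from the command above):
# assert_tp_compatible(model_path, 16)  # raises for Llama 3.1 70B

Inference engines commonly handle tp_size > num_key_value_heads by replicating KV heads across ranks; presumably the target-model sharding in specforge/modeling/target/llama.py would need similar treatment, though I have not verified this.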

Environment

code commit: 1dde031d1c746e007a0e4f368b5e1c0ddb87d9d5

fan-niu · Sep 08 '25

@yilian49 @FrankLeeeee @yubofredwang Hi team, could you please help solve this problem? Thanks a lot.

fan-niu · Sep 08 '25

Can anyone help with this issue? Thank you very much.

fan-niu · Sep 14 '25

Hi @fan-niu, I ran into the same issue.

[rank9]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 522, in forward
[rank9]:     outputs: BaseModelOutputWithPast = self.model(
[rank9]:                                        ^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank9]:     return self._call_impl(*args, **kwargs)
[rank9]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank9]:     return forward_call(*args, **kwargs)
[rank9]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank9]:     output = func(self, *args, **kwargs)
[rank9]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 413, in forward
[rank9]:     layer_outputs = decoder_layer(
[rank9]:                     ^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
[rank9]:     return super().__call__(*args, **kwargs)
[rank9]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank9]:     return self._call_impl(*args, **kwargs)
[rank9]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank9]:     return forward_call(*args, **kwargs)
[rank9]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 246, in forward
[rank9]:     hidden_states, self_attn_weights = self.self_attn(
[rank9]:                                        ^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank9]:     return self._call_impl(*args, **kwargs)
[rank9]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank9]:     return forward_call(*args, **kwargs)
[rank9]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 175, in forward
[rank9]:     key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
[rank9]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: RuntimeError: shape '[1, 1726, -1, 128]' is invalid for input of size 110464
[rank14]: Traceback (most recent call last):
[rank14]:   File "/fsx/brayden/dev/SpecForge/scripts/train_eagle3_online.py", line 675, in <module>
[rank14]:     main()
[rank14]:   File "/fsx/brayden/dev/SpecForge/scripts/train_eagle3_online.py", line 512, in main
[rank14]:     plosses, _, acces = eagle3_model(
[rank14]:                         ^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]:     return self._call_impl(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]:     return forward_call(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
[rank14]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank14]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]:     return self._call_impl(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]:     return forward_call(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/core/eagle3.py", line 170, in forward
[rank14]:     hidden_states, target, loss_mask, input_ids = self._prepare_data(
[rank14]:                                                   ^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank14]:     return func(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/core/eagle3.py", line 100, in _prepare_data
[rank14]:     outputs = self.target_model(
[rank14]:               ^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]:     return self._call_impl(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]:     return forward_call(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank14]:     output = func(self, *args, **kwargs)
[rank14]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 522, in forward
[rank14]:     outputs: BaseModelOutputWithPast = self.model(
[rank14]:                                        ^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]:     return self._call_impl(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]:     return forward_call(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank14]:     output = func(self, *args, **kwargs)
[rank14]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 413, in forward
[rank14]:     layer_outputs = decoder_layer(
[rank14]:                     ^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
[rank14]:     return super().__call__(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]:     return self._call_impl(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]:     return forward_call(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 246, in forward
[rank14]:     hidden_states, self_attn_weights = self.self_attn(
[rank14]:                                        ^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]:     return self._call_impl(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]:     return forward_call(*args, **kwargs)
[rank14]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]:   File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 175, in forward
[rank14]:     key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
[rank14]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: RuntimeError: shape '[1, 1726, -1, 128]' is invalid for input of size 110464
[rank0]:[W1006 20:05:54.549071146 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

I wonder if there is a fix for this yet.

b8zhong · Oct 06 '25