SpecForge
[Bug] TP=16 or TP=32 Failed
Checklist
- [x] 1. I have searched related issues but cannot get the expected help.
- [x] 2. The bug has not been fixed in the latest version.
- [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- [x] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose. Otherwise, it will be closed.
- [x] 5. Please use English, otherwise it will be closed.
Describe the bug
When I train Llama-3.1-70B on 32 H100 GPUs (4 nodes x 8 GPUs) with TP=16 or TP=32, I get the following error.
[rank5]: Traceback (most recent call last):
[rank5]: File "SpecForge/scripts/train_eagle3_online.py", line 673, in <module>
[rank5]: main()
[rank5]: File "SpecForge/scripts/train_eagle3_online.py", line 516, in main
[rank5]: plosses, _, acces = eagle3_model(
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]: return self._call_impl(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]: return forward_call(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in f
orward
[rank5]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]: return self._call_impl(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]: return forward_call(*args, **kwargs)
[rank5]: File "SpecForge/specforge/core/eagle3.py", line 168, in forward
[rank5]: hidden_states, target, loss_mask, input_ids = self._prepare_data(
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank5]: return func(*args, **kwargs)
[rank5]: File "SpecForge/specforge/core/eagle3.py", line 99, in _prepare_data
[rank5]: outputs = self.target_model(
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]: return self._call_impl(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]: return forward_call(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank5]: output = func(self, *args, **kwargs)
[rank5]: File "SpecForge/specforge/modeling/target/llama.py", line 534, in forward
[rank5]: outputs: BaseModelOutputWithPast = self.model(
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]: return self._call_impl(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]: return forward_call(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank5]: output = func(self, *args, **kwargs)
[rank5]: File "SpecForge/specforge/modeling/target/llama.py", line 425, in forward
[rank5]: layer_outputs = decoder_layer(
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
[rank5]: return super().__call__(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]: return self._call_impl(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]: return forward_call(*args, **kwargs)
[rank5]: File "SpecForge/specforge/modeling/target/llama.py", line 258, in forward
[rank5]: hidden_states, self_attn_weights = self.self_attn(
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank5]: return self._call_impl(*args, **kwargs)
[rank5]: File "anaconda3/envs/SpecForge_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank5]: return forward_call(*args, **kwargs)
[rank5]: File "SpecForge/specforge/modeling/target/llama.py", line 180, in forward
[rank5]: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
[rank5]: RuntimeError: shape '[1, 1234, -1, 128]' is invalid for input of size 78976
78976 = 1234 * 64. That 64 looks like half a KV head: Llama-3.1-70B has num_key_value_heads = 8 and head_dim = 128, so with TP=16 each rank's k_proj outputs only 8 * 128 / 16 = 64 features, and the view into [1, 1234, -1, 128] cannot succeed. Both TP=16 and TP=32 exceed the 8 KV heads, so num_key_value_heads is not divisible across the TP ranks.
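For what it's worth, the numbers line up exactly with fractional KV-head sharding. A minimal sketch of the arithmetic (head counts are from the public Llama-3.1-70B config; the helper name is mine, not SpecForge's):

# Sketch of the per-rank k_proj output width under column-parallel TP.
# Assumes Llama-3.1-70B values: num_key_value_heads = 8, head_dim = 128.
def kv_proj_width_per_rank(num_kv_heads: int, head_dim: int, tp_size: int) -> int:
    return num_kv_heads * head_dim // tp_size

seq_len = 1234
for tp in (8, 16, 32):
    width = kv_proj_width_per_rank(8, 128, tp)
    # view(1, seq_len, -1, 128) requires width to be a multiple of head_dim
    print(f"TP={tp}: width={width}, numel={seq_len * width}, ok={width % 128 == 0}")
# TP=8:  width=128, ok
# TP=16: width=64,  numel=78976 (the exact size in the error), not a multiple of 128
# TP=32: width=32,  fails the same way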
Reproduction
torchrun \
--nnodes 4 \
--nproc_per_node=8 \
--node_rank 0 \
--master_addr 127.0.0.1 \
--master_port 8080 \
$ROOT_DIR/scripts/train_eagle3_online.py \
--target-model-path $model_path \
--draft-model-config $ROOT_DIR/configs/llama31-70B-eagle3.json \
--train-data-path $train_set \
--eval-data-path $test_set \
--output-dir $ROOT_DIR/outputs/ \
--num-epochs 30 \
--draft-global-batch-size 32 \
--draft-micro-batch-size 1 \
--learning-rate 2e-4 \
--max-length 1234 \
--ttt-length 7 \
--chat-template llama31 \
--cache-dir $ROOT_DIR/cache \
--attention-backend flex_attention \
--warmup-ratio 0.02 \
--eval-interval 1 \
--save-interval 1 \
--seed 1234 \
--dist-timeout 3600000 \
--report-to tensorboard \
--tp-size 16
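Not part of the original launch script, but a hedged pre-flight check one could run before torchrun to fail fast (the function name is hypothetical and the AutoConfig usage is my assumption; SpecForge may validate this elsewhere):

# Hypothetical pre-flight check (not part of SpecForge): verify that the
# target model's KV heads divide evenly across the requested TP size.
from transformers import AutoConfig

def check_tp_divisibility(model_path: str, tp_size: int) -> None:
    cfg = AutoConfig.from_pretrained(model_path)
    kv_heads = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads)
    if kv_heads % tp_size != 0:
        raise ValueError(
            f"num_key_value_heads={kv_heads} is not divisible by "
            f"tp_size={tp_size}; each rank would hold a fractional KV head."
        )

check_tp_divisibility("meta-llama/Llama-3.1-70B-Instruct", 16)  # raises: 8 % 16 != 0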
Environment
Code commit: 1dde031d1c746e007a0e4f368b5e1c0ddb87d9d5
@yilian49 @FrankLeeeee @yubofredwang Hi team, could you please help solve this problem? Thanks a lot.
Can anyone help with this issue? Thank you very much.
Hi @fan-niu, I ran into the same issue; tracebacks below (the first excerpt is truncated at the top).
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 522, in forward
[rank9]: outputs: BaseModelOutputWithPast = self.model(
[rank9]: ^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank9]: return self._call_impl(*args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank9]: return forward_call(*args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank9]: output = func(self, *args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 413, in forward
[rank9]: layer_outputs = decoder_layer(
[rank9]: ^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
[rank9]: return super().__call__(*args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank9]: return self._call_impl(*args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank9]: return forward_call(*args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 246, in forward
[rank9]: hidden_states, self_attn_weights = self.self_attn(
[rank9]: ^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank9]: return self._call_impl(*args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank9]: return forward_call(*args, **kwargs)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 175, in forward
[rank9]: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
[rank9]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank9]: RuntimeError: shape '[1, 1726, -1, 128]' is invalid for input of size 110464
[rank14]: Traceback (most recent call last):
[rank14]: File "/fsx/brayden/dev/SpecForge/scripts/train_eagle3_online.py", line 675, in <module>
[rank14]: main()
[rank14]: File "/fsx/brayden/dev/SpecForge/scripts/train_eagle3_online.py", line 512, in main
[rank14]: plosses, _, acces = eagle3_model(
[rank14]: ^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]: return self._call_impl(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]: return forward_call(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
[rank14]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]: return self._call_impl(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]: return forward_call(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/core/eagle3.py", line 170, in forward
[rank14]: hidden_states, target, loss_mask, input_ids = self._prepare_data(
[rank14]: ^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank14]: return func(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/core/eagle3.py", line 100, in _prepare_data
[rank14]: outputs = self.target_model(
[rank14]: ^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]: return self._call_impl(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]: return forward_call(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank14]: output = func(self, *args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 522, in forward
[rank14]: outputs: BaseModelOutputWithPast = self.model(
[rank14]: ^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]: return self._call_impl(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]: return forward_call(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 959, in wrapper
[rank14]: output = func(self, *args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 413, in forward
[rank14]: layer_outputs = decoder_layer(
[rank14]: ^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
[rank14]: return super().__call__(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]: return self._call_impl(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]: return forward_call(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 246, in forward
[rank14]: hidden_states, self_attn_weights = self.self_attn(
[rank14]: ^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank14]: return self._call_impl(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank14]: return forward_call(*args, **kwargs)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: File "/fsx/brayden/dev/SpecForge/.venv/lib/python3.12/site-packages/specforge/modeling/target/llama.py", line 175, in forward
[rank14]: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
[rank14]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank14]: RuntimeError: shape '[1, 1726, -1, 128]' is invalid for input of size 110464
[rank0]:[W1006 20:05:54.549071146 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Is there a fix for this yet?
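For reference, the usual way serving stacks handle tp_size > num_key_value_heads is to replicate KV heads across ranks instead of splitting them (vLLM does this, for example). A minimal sketch of that sharding rule, under the assumption that SpecForge's column-parallel k_proj/v_proj could adopt it; all names here are mine, not SpecForge's:

# Hypothetical sharding rule (not SpecForge's current code): give each TP rank
# at least one whole KV head, replicating heads when tp_size > num_kv_heads.
def local_kv_heads(num_kv_heads: int, tp_size: int, tp_rank: int) -> tuple[int, int]:
    """Return (heads_on_this_rank, index_of_first_head)."""
    if num_kv_heads >= tp_size:
        assert num_kv_heads % tp_size == 0
        n = num_kv_heads // tp_size
        return n, tp_rank * n
    # Replication case: tp_size / num_kv_heads ranks share each KV head.
    assert tp_size % num_kv_heads == 0
    group = tp_size // num_kv_heads
    return 1, tp_rank // group

# With 8 KV heads and TP=16, ranks 0-1 both hold head 0, ranks 2-3 head 1, ...
print(local_kv_heads(8, 16, 5))  # -> (1, 2)

With this rule each rank always holds whole 128-wide heads, so the failing view would succeed; the attention reduction across ranks would then need adjusting to avoid double counting, which is why this is a sketch rather than a drop-in fix.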