Checklist
- [x] 1. I have searched related issues but cannot get the expected help.
- [ ] 2. The bug has not been fixed in the latest version.
- [ ] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- [ ] 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose. Otherwise, it will be closed.
- [ ] 5. Please use English, otherwise it will be closed.
Describe the bug
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/ubuntu/ocean/SpecForge/scripts/prepare_hidden_states.py", line 368, in <module>
[rank0]: main()
[rank0]: File "/home/ubuntu/ocean/SpecForge/scripts/prepare_hidden_states.py", line 364, in main
[rank0]: hidden_states_generator.generate(eagle3_dataset)
[rank0]: File "/home/ubuntu/ocean/SpecForge/scripts/prepare_hidden_states.py", line 288, in generate
[rank0]: self._save_tensor(
[rank0]: File "/home/ubuntu/ocean/SpecForge/scripts/prepare_hidden_states.py", line 186, in _save_tensor
[rank0]: assert not torch.any(
[rank0]: AssertionError: hidden_state is expected to be non-nan
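The assertion only says that some NaN exists somewhere in the hidden state. A more informative check could report which token positions are affected. The sketch below is hypothetical: `find_nan_positions` and `check_hidden_state` are not part of SpecForge. They work on a plain nested list shaped `[seq_len, hidden_dim]`, such as the result of `tensor.tolist()`, so they run without a GPU.

```python
import math
from typing import List, Sequence


def find_nan_positions(hidden_state: Sequence[Sequence[float]]) -> List[int]:
    """Return indices of rows (token positions) containing at least one NaN.

    Hypothetical diagnostic helper: expects a 2-D nested list shaped
    [seq_len, hidden_dim], e.g. the output of tensor.tolist().
    """
    return [
        pos
        for pos, row in enumerate(hidden_state)
        if any(math.isnan(value) for value in row)
    ]


def check_hidden_state(hidden_state: Sequence[Sequence[float]], sample_id: str) -> None:
    """Raise with the offending positions instead of a bare assertion."""
    bad = find_nan_positions(hidden_state)
    if bad:
        raise ValueError(
            f"sample {sample_id}: NaN in hidden_state at {len(bad)} position(s), "
            f"first few: {bad[:10]}"
        )
```

Logging the sample id and the first NaN positions shows whether the NaNs are confined to a few tokens or cover the whole sequence. That is more useful when comparing runs with different TP sizes.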
Reproduction
torchrun --nproc_per_node=8 \
    scripts/prepare_hidden_states.py \
    --model-path meta-llama/Llama-3.1-8B-Instruct \
    --enable-aux-hidden-states \
    --data-path cache/dataset/sharegpt.jsonl \
    --chat-template llama3 \
    --max-length 2048 \
    --tp-size 2 \
    --batch-size 1 \
    --mem-frac=0.8 \
    --num-samples 1000
Environment
CUDA 12.4, 8x RTX 4090 GPUs
Have you ever prepared hidden states for offline training?
@jiangtaozh, which sglang version are you using? I ran into this issue before, and it turned out to be a problem with my sglang installation at the time.
pip show sglang
Name: sglang
Version: 0.4.10.post1
Can you try the latest version of sglang?
Yes, I upgraded to:
pip show sglang
Name: sglang
Version: 0.5.1
I will let you know if we still have the nan problem.
After upgrading, I hit an "undefined symbol" error in sgl-kernel (https://github.com/sgl-project/sglang/issues/7033).
root@15e470876a14:/app/SpecForge# pip show sglang
Name: sglang
Version: 0.5.1
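Instead of running `pip show` once per package, the standard library's `importlib.metadata` can print the installed versions of several packages at once. This is a minimal sketch. The distribution names `sglang`, `sgl-kernel`, and `torch` are assumed to be the relevant ones here, and a missing package is reported as `None` instead of raising. Checking `sgl-kernel` against the installed `torch` is a reasonable first step for "undefined symbol" errors, because compiled PyTorch extensions generally need to match the torch version they were built against.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Assumed set of packages relevant to this issue.
    for pkg, ver in installed_versions(["sglang", "sgl-kernel", "torch"]).items():
        print(f"{pkg}: {ver or 'not installed'}")
```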
[rank0]: Traceback (most recent call last):
[rank0]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 387, in <module>
[rank0]: main()
[rank0]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 383, in main
[rank0]: hidden_states_generator.generate(eagle3_dataset)
[rank0]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 290, in generate
[rank0]: self._save_tensor(
[rank0]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 188, in _save_tensor
[rank0]: assert not torch.any(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: AssertionError: hidden_state is expected to be non-nan
6%|████▊ | 63/1000 [01:01<15:21, 1.02it/s]
[rank5]: Traceback (most recent call last):
[rank5]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 387, in <module>
[rank5]: main()
[rank5]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 383, in main
[rank5]: hidden_states_generator.generate(eagle3_dataset)
[rank5]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 290, in generate
[rank5]: self._save_tensor(
[rank5]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 188, in _save_tensor
[rank5]: assert not torch.any(
[rank5]: ^^^^^^^^^^^^^^
[rank5]: AssertionError: hidden_state is expected to be non-nan
[rank6]: Traceback (most recent call last):
[rank6]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 387, in <module>
[rank6]: main()
[rank6]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 383, in main
[rank6]: hidden_states_generator.generate(eagle3_dataset)
[rank6]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 290, in generate
[rank6]: self._save_tensor(
[rank6]: File "/app/SpecForge/scripts/prepare_hidden_states.py", line 188, in _save_tensor
[rank6]: assert not torch.any(
[rank6]: ^^^^^^^^^^^^^^
[rank6]: AssertionError: hidden_state is expected to be non-nan
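When several ranks fail at once, it can help to list which ranks raised the assertion and map them onto tensor-parallel groups. The sketch below parses the `[rankN]: AssertionError: ...` lines from a captured log. The helper names are hypothetical. The rank-to-group mapping `rank // tp_size` is an assumption (contiguous TP groups), because this thread does not show the layout that `prepare_hidden_states.py` actually uses.

```python
import re
from collections import defaultdict
from typing import Dict, List

# Matches lines such as "[rank5]: AssertionError: hidden_state is expected to be non-nan"
_FAIL_RE = re.compile(r"^\[rank(\d+)\]: AssertionError: (.*)$")


def failing_ranks(log_text: str) -> List[int]:
    """Return the sorted, de-duplicated ranks that logged an AssertionError."""
    ranks = set()
    for line in log_text.splitlines():
        match = _FAIL_RE.match(line.strip())
        if match:
            ranks.add(int(match.group(1)))
    return sorted(ranks)


def group_by_tp(ranks: List[int], tp_size: int) -> Dict[int, List[int]]:
    """Group ranks by TP group, ASSUMING contiguous groups (rank // tp_size)."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for rank in ranks:
        groups[rank // tp_size].append(rank)
    return dict(groups)
```

Grouping the failures this way shows whether they cluster in particular TP groups or are spread across all of them.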
The error only occurs when the TP size is greater than 2. I tested tp=1 and tp=2, and both work.
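To narrow down the TP-size dependence, the reproduction command can be rerun across several TP sizes. The minimal sketch below only builds the command strings and does not launch anything. It reuses the flags from the reproduction command in this issue and varies only `--tp-size`. `build_commands` is a hypothetical helper, and the rule that the TP size must evenly divide `--nproc_per_node` is an assumption, not something confirmed in this thread.

```python
import shlex
from typing import List

# Flags copied from the reproduction command in this issue; only --tp-size varies.
BASE_ARGS = [
    "scripts/prepare_hidden_states.py",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
    "--enable-aux-hidden-states",
    "--data-path", "cache/dataset/sharegpt.jsonl",
    "--chat-template", "llama3",
    "--max-length", "2048",
    "--batch-size", "1",
    "--mem-frac=0.8",
    "--num-samples", "1000",
]


def build_commands(nproc: int, tp_sizes: List[int]) -> List[str]:
    """Return one shell-quoted torchrun command per TP size.

    Assumes each tp_size must evenly divide the number of processes.
    """
    commands = []
    for tp in tp_sizes:
        if nproc % tp != 0:
            raise ValueError(f"tp_size={tp} does not divide nproc={nproc}")
        argv = ["torchrun", f"--nproc_per_node={nproc}", *BASE_ARGS, "--tp-size", str(tp)]
        commands.append(shlex.join(argv))
    return commands


if __name__ == "__main__":
    for cmd in build_commands(8, [1, 2, 4, 8]):
        print(cmd)
```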