PiPPy
[BUG] num_stages incorrect and some assertions
Hi,
First of all, thank you for the great work.
I am trying the llama example script with llama2-7b-hf and the following key packages:
torch 2.5.0
torchpippy 0.2.0
torchtext 0.6.0
torchview 0.2.6
When I run torchrun --nproc-per-node 4 pippy_llama.py, I get the following error on rank 0:
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank0]: stage = pipe.build_stage(rank, device=device)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1150, in build_stage
[rank0]: return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 799, in __init__
[rank0]: _PipelineStageBase.__init__(
[rank0]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 138, in __init__
[rank0]: raise RuntimeError(
[rank0]: RuntimeError: Pipeline group size 4 cannot be larger than number of stages 1
I traced this back to _number_and_count_forward_stages in _IR.py, and indeed num_stages = 1, because there is only one node with node.op == "call_module"; all the other nodes have node.op == "call_function".
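For reference, here is my understanding of why num_stages comes out as 1. This is a simplified pure-Python sketch of the counting logic, not the actual PiPPy source, and the node lists are made up to illustrate the two cases:

```python
def count_forward_stages(nodes):
    """nodes: list of (op, target) pairs standing in for fx.Node.

    Sketch of _number_and_count_forward_stages: count the
    call_module nodes in the split graph, i.e. the submod_<i>
    pipeline stages.
    """
    num_stages = 0
    for op, target in nodes:
        if op == "call_module" and target.startswith("submod_"):
            num_stages += 1
    return num_stages

# What a correctly split graph for 4 stages should look like:
split_ok = [("placeholder", "x"),
            ("call_module", "submod_0"),
            ("call_module", "submod_1"),
            ("call_module", "submod_2"),
            ("call_module", "submod_3"),
            ("output", "output")]

# What I observe instead: a single call_module node, with
# everything else traced as call_function.
split_bad = [("placeholder", "x"),
             ("call_module", "submod_0"),
             ("call_function", "getitem"),
             ("output", "output")]

print(count_forward_stages(split_ok))   # 4
print(count_forward_stages(split_bad))  # 1 -> "Pipeline group size 4
                                        #  cannot be larger than number
                                        #  of stages 1"
```

So the split step seems to produce only one submodule, and everything downstream (build_stage, getattr on submod_1/2/3) fails as a consequence.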
Just to dig deeper, I hard-coded the return value of _number_and_count_forward_stages to 4. Then I got the following errors:
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank0]: stage = pipe.build_stage(rank, device=device)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1150, in build_stage
[rank0]: return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 816, in __init__
[rank0]: raise AssertionError(
[rank0]: AssertionError: Number of submodules in pipe graph 1 does not match number of stages 4
[rank2]: Traceback (most recent call last):
[rank2]: File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank2]: stage = pipe.build_stage(rank, device=device)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank2]: stage_module = self.get_stage_module(stage_index)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank2]: return getattr(self.split_gm, f"submod_{stage_idx}")
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank2]: raise AttributeError(
[rank2]: AttributeError: 'GraphModule' object has no attribute 'submod_2'. Did you mean: 'submod_0'?
[rank1]: Traceback (most recent call last):
[rank1]: File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank1]: stage = pipe.build_stage(rank, device=device)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank1]: stage_module = self.get_stage_module(stage_index)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank1]: return getattr(self.split_gm, f"submod_{stage_idx}")
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank1]: raise AttributeError(
[rank1]: AttributeError: 'GraphModule' object has no attribute 'submod_1'. Did you mean: 'submod_0'?
[rank3]: Traceback (most recent call last):
[rank3]: File "/mnt/disk1/w84373270/test_pippy.py", line 48, in <module>
[rank3]: stage = pipe.build_stage(rank, device=device)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 1126, in build_stage
[rank3]: stage_module = self.get_stage_module(stage_index)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/distributed/pipelining/_IR.py", line 643, in get_stage_module
[rank3]: return getattr(self.split_gm, f"submod_{stage_idx}")
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/home/llm/anaconda3/envs/llamapython/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
[rank3]: raise AttributeError(
[rank3]: AttributeError: 'GraphModule' object has no attribute 'submod_3'. Did you mean: 'submod_0'?
It seems the version-compatibility problem is still there. By the way, the same errors occur if I uninstall torchpippy.
Could you give me some hints?
Thank you very much!