Pipeline Schedule confused

Open wuhouming opened this issue 1 year ago • 1 comments

When I run "torchrun --rdzv-backend=c10d --rdzv-endpoint=localhost:29500 --nnodes=1 --nproc-per-node=4 test_pipeline_schedule.py --schedules gpipe", I get the following output:

[2023-12-03 08:40:53,722] torch.distributed.run: [WARNING] 
[2023-12-03 08:40:53,722] torch.distributed.run: [WARNING] *****************************************
[2023-12-03 08:40:53,722] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2023-12-03 08:40:53,722] torch.distributed.run: [WARNING] *****************************************
{'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
MY KWARGS ARE {'rank': 0, 'local_rank': 0, 'world_size': 4, 'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
{'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
MY KWARGS ARE {'rank': 3, 'local_rank': 3, 'world_size': 4, 'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
{'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
MY KWARGS ARE {'rank': 1, 'local_rank': 1, 'world_size': 4, 'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
[W Utils.hpp:133] Warning: Environment variable TORCH_NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt)
[W Utils.hpp:133] Warning: Environment variable TORCH_NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt)
{'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
MY KWARGS ARE {'rank': 2, 'local_rank': 2, 'world_size': 4, 'no_trace': False, 'trace_dir': './traces', 'schedules': ['gpipe'], 'device': 'cuda'}
[W Utils.hpp:133] Warning: Environment variable TORCH_NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt)
[W Utils.hpp:133] Warning: Environment variable TORCH_NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt)
finished pipeline stage init, self.stage_id=3, self.is_first_stage=False, self.is_last_stage=True, self.num_stages=4, self.fwd_input.shape=torch.Size([8, 4000]), self.fwd_output_grads.shape=torch.Size([8, 4000])
finished pipeline stage init, self.stage_id=1, self.is_first_stage=False, self.is_last_stage=False, self.num_stages=4, self.fwd_input.shape=torch.Size([8, 4000]), self.fwd_output_grads.shape=torch.Size([8, 4000])
finished pipeline stage init, self.stage_id=2, self.is_first_stage=False, self.is_last_stage=False, self.num_stages=4, self.fwd_input.shape=torch.Size([8, 4000]), self.fwd_output_grads.shape=torch.Size([8, 4000])
finished pipeline stage init, self.stage_id=0, self.is_first_stage=True, self.is_last_stage=False, self.num_stages=4, self.fwd_input.shape=torch.Size([8, 4000]), self.fwd_output_grads.shape=torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb True is_last_mb False
[1 FORWARD 1] is_first_mb True is_last_mb False
0 forward 0 finished, microbatch: torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb False is_last_mb False
0 forward 1 finished, microbatch: torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb False is_last_mb False
0 forward 2 finished, microbatch: torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb False is_last_mb False
0 forward 3 finished, microbatch: torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb False is_last_mb False
0 forward 4 finished, microbatch: torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb False is_last_mb False
0 forward 5 finished, microbatch: torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb False is_last_mb False
0 forward 6 finished, microbatch: torch.Size([8, 4000])
[0 FORWARD 0] is_first_mb False is_last_mb True
0 forward 7 finished, microbatch: torch.Size([8, 4000])
1 forward 0 finished, microbatch: torch.Size([8, 4000])
[1 FORWARD 1] is_first_mb False is_last_mb False
1 forward 1 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb True is_last_mb False
[1 FORWARD 1] is_first_mb False is_last_mb False
1 forward 2 finished, microbatch: torch.Size([8, 4000])
[1 FORWARD 1] is_first_mb False is_last_mb False
1 forward 3 finished, microbatch: torch.Size([8, 4000])
[1 FORWARD 1] is_first_mb False is_last_mb False
1 forward 4 finished, microbatch: torch.Size([8, 4000])
[1 FORWARD 1] is_first_mb False is_last_mb False
1 forward 5 finished, microbatch: torch.Size([8, 4000])
[1 FORWARD 1] is_first_mb False is_last_mb False
1 forward 6 finished, microbatch: torch.Size([8, 4000])
[1 FORWARD 1] is_first_mb False is_last_mb True
1 forward 7 finished, microbatch: torch.Size([8, 4000])
2 forward 0 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb False is_last_mb False
2 forward 1 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb False is_last_mb False
2 forward 2 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb False is_last_mb False
[3 FORWARD 3] is_first_mb True is_last_mb False
2 forward 3 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb False is_last_mb False
2 forward 4 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb False is_last_mb False
2 forward 5 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb False is_last_mb False
2 forward 6 finished, microbatch: torch.Size([8, 4000])
[2 FORWARD 2] is_first_mb False is_last_mb True
2 forward 7 finished, microbatch: torch.Size([8, 4000])
3 forward 0 finished, microbatch: torch.Size([8, 4000])
[3 FORWARD 3] is_first_mb False is_last_mb False
3 forward 1 finished, microbatch: torch.Size([8, 4000])
[3 FORWARD 3] is_first_mb False is_last_mb False
3 forward 2 finished, microbatch: torch.Size([8, 4000])
[3 FORWARD 3] is_first_mb False is_last_mb False
3 forward 3 finished, microbatch: torch.Size([8, 4000])
[3 FORWARD 3] is_first_mb False is_last_mb False
3 forward 4 finished, microbatch: torch.Size([8, 4000])
[3 FORWARD 3] is_first_mb False is_last_mb False
3 forward 5 finished, microbatch: torch.Size([8, 4000])
[3 FORWARD 3] is_first_mb False is_last_mb False
3 forward 6 finished, microbatch: torch.Size([8, 4000])
[3 FORWARD 3] is_first_mb False is_last_mb True
3 forward 7 finished, microbatch: torch.Size([8, 4000])
[3 BACKWARD 3] is_first_mb True is_last_mb False
3 backward 0 finished
[3 BACKWARD 3] is_first_mb False is_last_mb False
[2 BACKWARD 2] is_first_mb True is_last_mb False
3 backward 1 finished
[3 BACKWARD 3] is_first_mb False is_last_mb False
3 backward 2 finished
[3 BACKWARD 3] is_first_mb False is_last_mb False
3 backward 3 finished
[3 BACKWARD 3] is_first_mb False is_last_mb False
3 backward 4 finished
[3 BACKWARD 3] is_first_mb False is_last_mb False
3 backward 5 finished
[3 BACKWARD 3] is_first_mb False is_last_mb False
3 backward 6 finished
[3 BACKWARD 3] is_first_mb False is_last_mb True
3 backward 7 finished
2 backward 0 finished
[2 BACKWARD 2] is_first_mb False is_last_mb False
[1 BACKWARD 1] is_first_mb True is_last_mb False
2 backward 1 finished
[2 BACKWARD 2] is_first_mb False is_last_mb False
2 backward 2 finished
[2 BACKWARD 2] is_first_mb False is_last_mb False
2 backward 3 finished
[2 BACKWARD 2] is_first_mb False is_last_mb False
2 backward 4 finished
[2 BACKWARD 2] is_first_mb False is_last_mb False
2 backward 5 finished
[2 BACKWARD 2] is_first_mb False is_last_mb False
2 backward 6 finished
[2 BACKWARD 2] is_first_mb False is_last_mb True
2 backward 7 finished
1 backward 0 finished
[0 BACKWARD 0] is_first_mb True is_last_mb False
[1 BACKWARD 1] is_first_mb False is_last_mb False
0 backward 0 finished
[0 BACKWARD 0] is_first_mb False is_last_mb False
1 backward 1 finished
[1 BACKWARD 1] is_first_mb False is_last_mb False
1 backward 2 finished
[1 BACKWARD 1] is_first_mb False is_last_mb False
1 backward 3 finished
[1 BACKWARD 1] is_first_mb False is_last_mb False
1 backward 4 finished
[1 BACKWARD 1] is_first_mb False is_last_mb False
1 backward 5 finished
[1 BACKWARD 1] is_first_mb False is_last_mb False
1 backward 6 finished
[1 BACKWARD 1] is_first_mb False is_last_mb True
1 backward 7 finished
0 backward 1 finished
[0 BACKWARD 0] is_first_mb False is_last_mb False
0 backward 2 finished
[0 BACKWARD 0] is_first_mb False is_last_mb False
0 backward 3 finished
[0 BACKWARD 0] is_first_mb False is_last_mb False
0 backward 4 finished
[0 BACKWARD 0] is_first_mb False is_last_mb False
0 backward 5 finished
[0 BACKWARD 0] is_first_mb False is_last_mb False
0 backward 6 finished
[0 BACKWARD 0] is_first_mb False is_last_mb True
0 backward 7 finished

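According to this log, the work runs stage by stage in microbatch order: each stage appears to finish all of its microbatches before the next stage finishes any, so different stages never seem to work on different microbatches at the same time. I'm a bit confused, because this does not seem to match the GPipe schedule.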
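
For reference, the overlap expected from a GPipe-style forward pass can be sketched as follows. This is an idealized illustration, not PiPPy code: it assumes every forward takes one time step and communication is free, and the function name `gpipe_forward_steps` is made up for this example. The stage and microbatch counts match the run above (4 stages, 8 microbatches).

```python
def gpipe_forward_steps(num_stages: int, num_microbatches: int) -> dict:
    """Map each idealized time step to the (stage, microbatch) forwards running in it.

    Hypothetical helper for illustration only; assumes unit-time forwards.
    """
    timeline = {}
    for stage in range(num_stages):
        for mb in range(num_microbatches):
            # Stage s can start microbatch m only after stage s-1 has finished m
            # and stage s itself has finished m-1, so the earliest step is s + m.
            timeline.setdefault(stage + mb, []).append((stage, mb))
    return timeline


timeline = gpipe_forward_steps(num_stages=4, num_microbatches=8)
for step in sorted(timeline):
    cells = ", ".join(f"stage {s} fwd mb {m}" for s, m in timeline[step])
    print(f"step {step:2d}: {cells}")
```

Under these assumptions, from step 3 onward all four stages are busy at once, each on a different microbatch. That is the overlap the log above does not appear to show.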

wuhouming, Dec 03 '23 04:12

Hi @wuhouming, thank you for trying out the pipeline schedule. Fair warning: we are actively developing it, so expect changes and some rough edges.

For your particular question, the logging is misleading because it happens on the CPU side, whereas the computation and communication run on CUDA streams. The logs may print that something has "finished" when it has only been enqueued onto the CUDA stream and is still pending. As an illustrative example, here is a PR, https://github.com/pytorch/PiPPy/pull/896, which adds an artificial sleep in rank 0's forward. In its output you'll see that ranks 1, 2, and 3 still log that they have "finished their forwards" even before the microbatch from rank 0 has been sent. Furthermore, adding a torch.cuda.synchronize() makes ranks 1, 2, and 3 run in lockstep with rank 0, since it forces the CPU side to wait for the pending CUDA operations (batch_isend_irecv) to finish. Let me know if this explanation makes sense.
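
The effect can be reproduced without a GPU. The sketch below is a pure-Python simulation, not PiPPy or CUDA code: a worker thread stands in for a CUDA stream, a queue stands in for the stream's pending work, and `synchronize()` is a hypothetical stand-in for torch.cuda.synchronize(). The host "logs" a forward as finished as soon as it has been enqueued, which is the same misleading pattern as in the output above.

```python
import queue
import threading
import time

events = []           # ordered record of what actually happened
work = queue.Queue()  # stand-in for work pending on a CUDA stream


def device_worker():
    """Stand-in for a CUDA stream: runs queued work in order, one item at a time."""
    while True:
        name = work.get()
        if name is None:  # shutdown signal
            break
        time.sleep(0.05)  # pretend the kernel takes a while
        events.append(f"device ran {name}")
        work.task_done()


def synchronize():
    """Stand-in for torch.cuda.synchronize(): block until all queued work is done."""
    work.join()


worker = threading.Thread(target=device_worker)
worker.start()

for mb in range(3):
    work.put(f"forward mb {mb}")  # returns immediately: the work is only enqueued
    events.append(f"host logged 'forward {mb} finished'")

synchronize()
events.append("host passed synchronize()")

work.put(None)
worker.join()

print("\n".join(events))
```

The host-side "finished" lines are recorded before the simulated device has run anything, and only the line after `synchronize()` is guaranteed to follow the real work.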

A better way to check the real order of operations is to look at the traces. We will try to integrate the traces and logging so that this is clearer. Happy to review suggestions and PRs if you are interested in helping.

H-Huang, Dec 04 '23 18:12