PiPPy
Skip connection + batch_isend_irecv can hang
Test case:
torchrun --nproc-per-node 4 test_fwd.py
Reason:
When stage 0 finishes its computation and hits batch_send, all the corresponding comms on the other ranks have already been issued, except the dotted line (1->2), because rank 1 only sends to rank 2 AFTER its compute. But rank 2 batched all of its recvs together (from 0 + from 1). Thus rank 2 keeps waiting on the recv from rank 1, which in turn leaves the first batch_isend_irecv waiting.
Today's PipelineStage implementation does not hang because it uses separate sends and recvs rather than batching. But that has its drawbacks too.
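The dependency cycle described above can be sketched with a small pure-Python toy model (no torch needed). It rests on two loud assumptions that are not taken from the issue or from NCCL internals: a batch completes all-or-nothing, and a batch cannot complete before the peer batches it communicates with. The names `runs_to_completion`, `BATCHED` and `SEPARATE` are hypothetical, and only the three ranks named above (0, 1, 2) are modeled.

```python
from typing import Dict, List, Tuple

Op = Tuple[str, int]               # ("send" | "recv", peer rank)
Batch = List[Op]                   # ops submitted together in one call
Program = Dict[int, List[Batch]]   # rank -> batches in issue order


def _matched(rank: int, op: Op, current: Dict[int, Batch]) -> bool:
    """True if the peer's currently posted batch holds the counterpart op."""
    kind, peer = op
    counterpart = ("recv" if kind == "send" else "send", rank)
    return peer in current and counterpart in current[peer]


def runs_to_completion(program: Program) -> bool:
    """Toy model (assumption): batches are atomic and complete together
    with the peer batches they talk to. Returns False on a stall."""
    next_idx = {rank: 0 for rank in program}
    while True:
        # The batch each rank is currently blocked on, if any.
        current = {r: program[r][i] for r, i in next_idx.items()
                   if i < len(program[r])}
        if not current:
            return True  # every rank finished all of its batches
        # Candidates: batches whose every op has a posted counterpart.
        ready = {r for r, batch in current.items()
                 if all(_matched(r, op, current) for op in batch)}
        # Drop batches that depend on a peer batch that cannot complete.
        changed = True
        while changed:
            changed = False
            for r in list(ready):
                if any(peer not in ready for _, peer in current[r]):
                    ready.discard(r)
                    changed = True
        if not ready:
            return False  # nobody can make progress: hang
        for r in ready:
            next_idx[r] += 1


# Skip connection 0->2 plus the regular chain 0->1->2, batched per rank.
BATCHED: Program = {
    0: [[("send", 1), ("send", 2)]],
    1: [[("recv", 0)], [("send", 2)]],   # send to 2 only after compute
    2: [[("recv", 0), ("recv", 1)]],     # both recvs in one batch
}

# Same traffic, but every send/recv is issued on its own.
SEPARATE: Program = {
    0: [[("send", 1)], [("send", 2)]],
    1: [[("recv", 0)], [("send", 2)]],
    2: [[("recv", 0)], [("recv", 1)]],
}

if __name__ == "__main__":
    print("batched completes:", runs_to_completion(BATCHED))
    print("separate completes:", runs_to_completion(SEPARATE))
```

Under these assumptions the batched schedule stalls: rank 2's batch waits on rank 1's send, rank 1's send waits on its recv from rank 0, and that recv is tied to rank 0's batch, which also contains the send to rank 2. The separate schedule completes because the 0->1 transfer can finish on its own.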