[c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager `async_op=True` collective
This PR aims to support the following use case:
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y
where the collective is issued in eager mode (with async_op=True) but waited on inside the compiled region.
This is important for internal use cases such as TorchRec, where we issue the SparseArch all_to_all collectives in eager mode but want to wait on them in the compiled region at the beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.
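For concreteness, here is a minimal sketch of that overlap pattern (the function names and the use of all_to_all_single below are illustrative placeholders, not TorchRec's actual API):

import torch
import torch.distributed as dist

def sparse_forward(sparse_input):
    # SparseArch: kick off the all_to_all eagerly and do NOT wait here.
    out = torch.empty_like(sparse_input)
    dist.all_to_all_single(out, sparse_input, async_op=True)
    return out  # still in flight

@torch.compile(fullgraph=True)
def over_arch(sparse_out, dense_out):
    # OverArch: the compiled region waits on the in-flight tensor first,
    # then consumes it together with the dense features.
    torch.ops.c10d_functional.wait_tensor(sparse_out)
    return (sparse_out + dense_out).relu()

def train_step(sparse_input, dense_input, dense_arch):
    sparse_out = sparse_forward(sparse_input)  # comm starts
    dense_out = dense_arch(dense_input)        # DenseArch compute overlaps with the comm
    return over_arch(sparse_out, dense_out)    # wait happens inside the graph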
Test commands:
pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait
pytest -rA test/test_fx.py::TestDCE::test_keep_collectives
pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload
pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited
pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_work_registry
pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited
pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_work_registry
pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal
pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph
pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing
pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli
pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True
pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees
python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco
Stack from ghstack (oldest at bottom):
- -> #137763
- #135273
- #137161
- #138178
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec
Differential Revision: D64511994
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/137763
- :page_facing_up: Preview Python docs built from this PR
- :page_facing_up: Preview C++ docs built from this PR
- :question: Need help or want to give feedback on the CI? Visit the bot commands wiki or our office hours
Note: Links to docs will display an error until the docs builds have been completed.
:x: 1 New Failure
As of commit 26d81d68320ebc54449acdf2324d617037b68ef3 with merge base 02339e674d8891de9e5b772768324940cc4a4548:
NEW FAILURE - The following job has failed:
- xpu / win-vs2022-xpu-py3 / build (gh)
ninja: build stopped: subcommand failed
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hmm, I think this might be problematic. The registration mechanism is designed for the functional collective subsystem, in which we know in-flight tensors are meant to be waited on by _c10d_functional.wait_tensor, which removes the work from the registry on completion. Unconditionally registering work here would cause a memory leak, because in non-compiled use cases we would create a permanent entry for each collective call. It would affect all use cases of ProcessGroupNCCL, and they would all trigger the "unwaited c10d_functional collective calls" warning.
I think the general idea is sound. One thing though is that a user should never throw away the work associated with an in-flight tensor before waiting on it. So can we just require manual work registration?
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    # If we know this will be waited on by a graph, we can register the work like this
    torch._C._distributed_c10d._register_work(y, req)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y
Discussed offline:
- Keep WorkRegistry global; move the code from Functional.cpp into the base ProcessGroup.cpp file (since we now use it outside of functional collectives)
- Add register_work to each collective in ProcessGroup.hpp, to make it general to all backends, not just NCCL
- Remove the registry entry in Work::finish, so that cleanup is done properly for the pure-eager case
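Roughly, the registry agreed on above behaves like the following Python analogue (the real implementation is C++ in ProcessGroup.cpp and is keyed by the tensor's storage; the class shape and the data_ptr key below are illustrative):

import threading

class WorkRegistry:
    """Illustrative Python analogue of the C++ WorkRegistry discussed above."""

    def __init__(self):
        self._lock = threading.Lock()
        self._map = {}  # storage data_ptr -> Work

    def register(self, tensor, work):
        # Called from each collective in ProcessGroup, for all backends.
        with self._lock:
            self._map[tensor.untyped_storage().data_ptr()] = work

    def pop(self, tensor):
        # Called when the tensor is waited on (wait_tensor), or when the work
        # finishes in the pure-eager case, so entries do not accumulate.
        with self._lock:
            return self._map.pop(tensor.untyped_storage().data_ptr(), None)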
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased gh/yf225/137/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/137763)
The general approach looks good!
One issue is that we shouldn't, and don't need to, rely on WorkNCCL::isCompleted for entry removal. A work can be removed from the registry as soon as it's waited on (i.e. the EventWait has been issued from the CPU). However, unless the job is running with CUDA_LAUNCH_BLOCKING, a collective is most likely not finished on the GPU when work.wait() is called on the CPU. Relying on WorkNCCL::isCompleted extends the lifetime of the works in much the same way that record_stream extends the lifetime of allocations, which is not ideal.
Fortunately we don't need to rely on it. Please see my other comment for a suggestion.
@pytorchbot merge
Merge failed
Reason: 1 mandatory check(s) are pending/not yet run. The first few are:
- EasyCLA
Dig deeper by viewing the pending checks on hud
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Successfully rebased gh/yf225/137/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/137763)
@pytorchbot merge -f "test_eager_async_allreduce_inductor_wait is now skipped under rocm, all other CI jobs have passed in previous runs"
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Advanced Debugging: Check the merge workflow status here
@awgu - this PR adds weak storage -> work tracking for all collective calls. I wonder if this is useful for record_stream avoidance.
@pytorchmergebot revert -m 'this change is breaking our prod training pipeline (verified with bisect) by increasing memory consumption 4x and causing OOM' -c ghfirst
@pytorchbot successfully started a revert job. Check the current status here. Questions? Feedback? Please reach out to the PyTorch DevX Team
@yf225 your PR has been successfully reverted.
Discussed with @yifuwang: ~~1. For functional collectives, we need to retain a strong pointer to the Work in the registry; for normal collectives, we can retain a weak pointer. 2. Best to have two registries: one strong and one weak. 3. Double-registration (1 strong + 1 weak) for functional collectives is okay (this happens since a functional collective needs to call a normal collective underneath). We only need the warning for the strong registry right now.~~
Update: fundamentally, there is an issue.
If we assume we are holding weak pointers in the work registry, then:
dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
torch.ops.c10d_functional.wait_tensor(input1)
There is no strong reference to the dist.all_reduce work object anywhere, so the work object will be deallocated immediately after the dist.all_reduce call, and wait_tensor will not be able to call wait on the correct work object.
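In Python terms, the failure mode looks roughly like this (a hypothetical weak registry, written purely to illustrate the lifetime problem):

import weakref

weak_registry = {}  # storage data_ptr -> weakref to Work (hypothetical)

def register_weak(tensor, work):
    weak_registry[tensor.untyped_storage().data_ptr()] = weakref.ref(work)

def wait_via_registry(tensor):
    ref = weak_registry.pop(tensor.untyped_storage().data_ptr(), None)
    work = ref() if ref is not None else None
    if work is None:
        # The Work was already deallocated because nothing kept it alive --
        # exactly what happens when the all_reduce return value is discarded.
        raise RuntimeError("no live Work to wait on")
    work.wait()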
To address this, we probably need to change the collective call site to mark that "we want this non-func col to be waited later".
One idea is to introduce a with allow_inflight_collective_as_graph_input(): context manager under which:
- all non-functional collectives will have strong pointers in the registry, and will rely on .wait() / .wait_tensor() being called to release those strong pointers
- in the WorkRegistry destructor, we check whether there are any unwaited non-functional collectives and throw an error if so
This proposal will not affect the existing "not calling wait on a non-functional collective" use case, and the context manager is only needed when the user needs to mix eager and compile.
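A hedged sketch of how that would look from user code; the wrapper below is a placeholder for whatever Python binding this PR ends up exposing, not an existing API:

import contextlib
import torch
import torch.distributed as dist

@contextlib.contextmanager
def allow_inflight_collective_as_graph_input():
    # Placeholder for the proposed c10d-level toggle: while active, every
    # non-functional collective registers its work (strong ref) and relies on
    # .wait() / .wait_tensor() to release it.
    yield

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

def train_step(x):
    with allow_inflight_collective_as_graph_input():
        y = x * x
        dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)  # issued in eager
        return all_reduce_wait_compiled(y)                       # waited in the graph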
An alternative is to require the user to explicitly call register_work() after a non-functional collective if they want to wait on it later. But I feel this might be difficult to roll out: currently it's quite common for us to issue a collective in TorchRec and implicitly wait on it at a downstream use site of the collective output. If we unconditionally call register_work() for each TorchRec collective, all OSS users of TorchRec would need to know to call wait on the resulting work object; otherwise the strong reference in the work registry would cause a memory leak.
We agreed that it'd be better to have the context manager (wrapping the training loop), so that enablement is per trainer and per workflow, instead of requiring users to explicitly call register_work(), which would immediately affect all TorchRec users and require them to modify their usage of TorchRec accordingly.
We will also have only a single registry for both funcol and non-funcol.
I had a question based on the example in the PR description: what does it mean to torch.compile a function but mark it as fullgraph=True? Clearly you are trying to leave the eager ops out of the compiled region, which suggests not a full graph.
@wconstab It's only fullgraph within the function (i.e. we will not graph-break on the .wait_tensor or the y * y)
Discussed with @yifuwang offline on the API design - the new design will not affect existing non-functional collective use cases. Will land this version to unblock internal use cases.
@pytorchbot merge -f "unrelated failures"
Merge started
Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team
Advanced Debugging: Check the merge workflow status here
@pytorchbot revert
❌ 🤖 pytorchbot command failed:
@pytorchbot revert: error: the following arguments are required: -m/--message, -c/--classification
usage: @pytorchbot revert -m MESSAGE -c
{nosignal,ignoredsignal,landrace,weird,ghfirst}
Try @pytorchbot --help for more info.
@pytorchbot revert -m "Seems to have bad interaction with latest commits on trunk, reverting to be safe" -c landrace
@pytorchbot successfully started a revert job. Check the current status here. Questions? Feedback? Please reach out to the PyTorch DevX Team
@yf225 your PR has been successfully reverted.
Update: Did two things to prevent regressions to existing use cases:
- Added a memory-stressed test case test_unwaited to test_c10d_nccl.py to cover existing users' "not calling work.wait() on a non-functional collective" use case
- Gated all new register_work() / unregister_work() calls behind the c10d::allow_inflight_collective_as_graph_input() check, which is a new context manager that requires explicit user enablement (i.e. it is not on by default, so it should not affect existing users)
The risk of this new version of the PR causing a regression should be very low.