
[c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager `async_op=True` collective

Open yf225 opened this issue 1 year ago • 12 comments

This PR aims to support the following use case:

import torch
import torch.distributed as dist

def all_reduce_eager(x):
    y = x * x
    # Issue the collective asynchronously in eager mode and leave it in flight.
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    # Wait on the in-flight eager collective from inside the compiled region.
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

where the collective is issued in eager mode (with async_op=True) but waited on inside a compiled region.

This is important for internal use cases such as TorchRec, where we issue collectives in eager mode for the SparseArch all_to_all but want to wait on them in the compiled region at the beginning of the OverArch, so that the all_to_all can overlap with the DenseArch compute that runs in parallel.
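
To make the intended flow concrete, here is a hedged sketch of how the two functions above are meant to compose; dense_compute and the inputs are placeholders and not part of this PR:

def train_step(x, w):
    y = all_reduce_eager(x)            # async all_reduce issued in eager mode, left in flight
    z = dense_compute(w)               # independent compute overlaps with the collective
    out = all_reduce_wait_compiled(y)  # compiled region waits on y before consuming it
    return out, z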


Test commands:

  • pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait
  • pytest -rA test/test_fx.py::TestDCE::test_keep_collectives
  • pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload
  • pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited
  • pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_work_registry
  • pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited
  • pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_work_registry
  • pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal
  • pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph
  • pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing
  • pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli
  • pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True
  • pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees
  • python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco

Stack from ghstack (oldest at bottom):

  • -> #137763
  • #135273
  • #137161
  • #138178

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

Differential Revision: D64511994

yf225 avatar Oct 11 '24 06:10 yf225

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/137763

Note: Links to docs will display an error until the docs builds have been completed.

:x: 1 New Failure

As of commit 26d81d68320ebc54449acdf2324d617037b68ef3 with merge base 02339e674d8891de9e5b772768324940cc4a4548:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Oct 11 '24 06:10 pytorch-bot[bot]

Hmm, I think this might be problematic. The registration mechanism is designed for the functional collective subsystem, where we know in-flight tensors are meant to be waited on via _c10d_functional.wait_tensor, which removes the work from the registry upon completion. Unconditionally registering work here would cause a memory leak, because in non-compiled use cases we would create a permanent entry for each collective call. This would affect all uses of ProcessGroupNCCL, and they would trigger the "unwaited c10d_functional collective calls" warning.

I think the general idea is sound. One thing, though, is that a user should never throw away the work associated with an in-flight tensor before waiting on it. So can we just require manual work registration?

def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    # If we know this tensor will be waited on by a graph, we can register the work like this
    torch._C._distributed_c10d._register_work(y, req)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

yifuwang avatar Oct 11 '24 22:10 yifuwang

Discussed offline:

  1. Keep WorkRegistry global and move the code from Functional.cpp into the base ProcessGroup.cpp file (since we now use it outside of functional collectives)
  2. Add register_work to each collective in ProcessGroup.hpp, to make it general to all backends, not just NCCL (see the sketch below)
  3. Remove the registry entry in Work::finish, so that cleanup happens properly for the pure-eager case
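
To make item 2 concrete, here is a minimal Python sketch of the registry idea; the real implementation lives in C++ (ProcessGroup / Work), and all names below are illustrative only:

import threading

class WorkRegistry:
    """Illustrative only: maps a tensor's storage pointer to its in-flight work."""

    def __init__(self):
        self._lock = threading.Lock()
        self._entries = {}

    def register(self, tensor, work):
        # Called when a collective is issued (item 2 above).
        with self._lock:
            self._entries[tensor.untyped_storage().data_ptr()] = work

    def pop(self, tensor):
        # Called from wait_tensor (and, per item 3, from Work::finish for the
        # pure-eager case) so that entries do not outlive the collective.
        with self._lock:
            return self._entries.pop(tensor.untyped_storage().data_ptr(), None)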

yf225 avatar Oct 11 '24 23:10 yf225

@pytorchbot rebase

yf225 avatar Oct 15 '24 08:10 yf225

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot avatar Oct 15 '24 08:10 pytorchmergebot

Successfully rebased gh/yf225/137/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/137763)

pytorchmergebot avatar Oct 15 '24 08:10 pytorchmergebot

The general approach looks good!

One issue is that we shouldn't, and don't need to, rely on WorkNCCL::isCompleted for entry removal. A work can be removed from the registry as soon as it is waited on (i.e. the event wait has been issued from the CPU). However, unless a job is running with CUDA_LAUNCH_BLOCKING, a collective is most likely not yet finished on the GPU when its work.wait() is called on the CPU. Relying on WorkNCCL::isCompleted extends the lifetime of the works in a similar way that record_stream extends the lifetime of allocations, which is not ideal.

Fortunately we don't need to rely on it. Please see my other comment for a suggestion.

yifuwang avatar Oct 17 '24 17:10 yifuwang

@pytorchbot merge

yf225 avatar Oct 18 '24 18:10 yf225

Merge failed

Reason: 1 mandatory check(s) are pending/not yet run. The first few are:

  • EasyCLA

Dig deeper by viewing the pending checks on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

pytorchmergebot avatar Oct 18 '24 18:10 pytorchmergebot

@pytorchbot rebase

yf225 avatar Oct 18 '24 19:10 yf225

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot avatar Oct 18 '24 19:10 pytorchmergebot

Successfully rebased gh/yf225/137/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/137763)

pytorchmergebot avatar Oct 18 '24 19:10 pytorchmergebot

@pytorchbot merge -f "test_eager_async_allreduce_inductor_wait is now skipped under rocm, all other CI jobs have passed in previous runs"

yf225 avatar Oct 21 '24 06:10 yf225

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging Check the merge workflow status here

pytorchmergebot avatar Oct 21 '24 06:10 pytorchmergebot

@awgu - this PR adds weak storage -> work tracking for all collective calls. I wonder if this is useful for record_stream avoidance.

yifuwang avatar Oct 21 '24 21:10 yifuwang

@pytorchmergebot revert -m 'this change is breaking our prod training pipeline (verified with bisect) by increasing memory consumption 4x and causing OOM' -c ghfirst

wdvr avatar Oct 24 '24 17:10 wdvr

@pytorchbot successfully started a revert job. Check the current status here. Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot avatar Oct 24 '24 17:10 pytorchmergebot

@yf225 your PR has been successfully reverted.

pytorchmergebot avatar Oct 24 '24 17:10 pytorchmergebot

Discussed with @yifuwang : ~~1. For functional collective, we need to retain strong pointer to Work in registry; for normal collective, we can retain weak pointer. 2. Best to have two registries: 1 strong and 1 weak. 3. Double-registration (1 strong + 1 weak) for functional collective is okay (this happens since functional collective needs to call normal collective underneath). We only need warning for the strong registry right now.~~

Update: Fundamentally, there is an issue.

If we assume we are holding only weak pointers in the work registry, then consider:

dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
torch.ops.c10d_functional.wait_tensor(input1)

There is no strong reference to the work object returned by dist.all_reduce anywhere, so it will be deallocated immediately after the dist.all_reduce call, and wait_tensor will therefore not be able to call wait on the correct work object.
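
The lifetime problem can be illustrated with plain Python weak references; this is only an analogy (the actual registry would hold weak pointers on the C++ side), and FakeWork is a made-up stand-in for torch.distributed.Work:

import weakref

class FakeWork:
    """Made-up stand-in for an async collective's work handle."""
    def wait(self):
        pass

registry = {}

def issue_collective():
    work = FakeWork()
    registry["output"] = weakref.ref(work)  # registry keeps only a weak reference
    # The caller never holds on to `work`, mirroring dist.all_reduce(..., async_op=True)
    # whose return value is discarded.

issue_collective()
print(registry["output"]())  # None: the work was deallocated before anyone could wait on it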

To address this, we probably need to change the collective call site to mark that "we want this non-functional collective to be waited on later".

One idea is to introduce a with allow_inflight_collective_as_graph_input(): context manager under which:

  1. all non-functional collectives will have strong pointers in the registry and will rely on .wait() / .wait_tensor() being called to release those strong pointers
  2. in the WorkRegistry destructor, we check whether there are any unwaited non-functional collectives and throw an error if so

This proposal will not affect the existing "not calling wait on a non-functional collective" use case, and the context manager is only needed when the user wants to mix eager and compiled execution.

An alternative is to require the user to explicitly call register_work() after a non-functional collective if they want to wait on it later. But I feel this might be difficult to roll out: currently it is quite common to issue a collective in TorchRec and implicitly wait on it at the downstream use site of the collective output. If we unconditionally called register_work() for each TorchRec collective, all OSS users of TorchRec would need to know to call wait on the resulting work object, otherwise the strong reference in the work registry would cause a memory leak.

We agreed that it is better to have the context manager (wrapping the training loop), so that enablement is per trainer and per workflow, instead of requiring explicit register_work() calls, which would immediately affect all TorchRec users and force them to change how they use TorchRec.

We will also have only a single registry for both functional and non-functional collectives.
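
A hedged usage sketch of the context-manager approach, reusing all_reduce_eager / all_reduce_wait_compiled from the PR description; the exact name and import path of allow_inflight_collective_as_graph_input may differ in the final implementation, and dataloader is a placeholder:

with allow_inflight_collective_as_graph_input():
    for x in dataloader:
        y = all_reduce_eager(x)            # eager async collective gets a strong registry entry
        out = all_reduce_wait_compiled(y)  # wait_tensor in the compiled region releases the entry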

yf225 avatar Oct 24 '24 18:10 yf225

I had a question based on the example in the PR description: what does it mean to torch.compile a function but mark it as fullgraph=True? Clearly you are trying to leave the eager ops out of the compiled region, which suggests it is not a full graph.

wconstab avatar Oct 25 '24 13:10 wconstab

> I had a question based on the example in the PR description: what does it mean to torch.compile a function but mark it as fullgraph=True? Clearly you are trying to leave the eager ops out of the compiled region, which suggests it is not a full graph.

@wconstab It's only a full graph within the function (i.e. we will not graph-break on the .wait_tensor or the y * y).

yf225 avatar Oct 25 '24 17:10 yf225

Discussed the API design with @yifuwang offline: the new design will not affect existing non-functional collective use cases. Will land this version to unblock internal use cases.

yf225 avatar Oct 28 '24 18:10 yf225

@pytorchbot merge -f "unrelated failures"

yf225 avatar Oct 28 '24 18:10 yf225

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging Check the merge workflow status here

pytorchmergebot avatar Oct 28 '24 18:10 pytorchmergebot

@pytorchbot revert

yf225 avatar Oct 28 '24 20:10 yf225

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -m/--message, -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

pytorch-bot[bot] avatar Oct 28 '24 20:10 pytorch-bot[bot]

@pytorchbot revert -m "Seems to have bad interaction with latest commits on trunk, reverting to be safe" -c landrace

yf225 avatar Oct 28 '24 20:10 yf225

@pytorchbot successfully started a revert job. Check the current status here. Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot avatar Oct 28 '24 20:10 pytorchmergebot

@yf225 your PR has been successfully reverted.

pytorchmergebot avatar Oct 28 '24 20:10 pytorchmergebot

Update: Did two things to prevent regressions in existing use cases:

  1. Added a memory-stress test case to test_unwaited in test_c10d_nccl.py to cover the existing "not calling work.wait() on a non-functional collective" use case
  2. Gated all new register_work() / unregister_work() calls behind the c10d::allow_inflight_collective_as_graph_input() check, which requires explicit user enablement via the new context manager (i.e. it is off by default, so it should not affect existing users)

The risk of this new version of the PR causing a regression should be very low.

yf225 avatar Oct 29 '24 02:10 yf225