RFC-0042-torch-distributed-redesign
For a preview, please check https://github.com/youkaichao/rfcs/blob/master/RFC-0042-torch-distributed-redesign.md
cc @ezyang @wconstab @kwen2501
An important use case for this is dynamic prefill-decode disaggregation: prefill instances and decode instances dynamically join the group according to the workload, and they send/recv KV caches to/from each other.
There are other solutions, like using etcd to communicate metadata and directly using device communication libraries such as our own NCCL wrapper. That would mean completely dropping torch.distributed from our codebase, though, so it is our last resort. We do want to use PyTorch as much as we can.
The current global group is necessary for control plane operations over the cluster.
It's conflating the notion of a cluster with that of communication groups, so it would be great to separate the two.
One aspect to make this feasible is whether it's possible to implement torch.distributed in terms of torch.distributed2.
> One aspect to make this feasible is whether it's possible to implement torch.distributed in terms of torch.distributed2.
Do you mean we have a stateless version of process groups in torch.distributed2, and re-implement the global group in torch.distributed on top of it? That could be a great idea!
Thanks for posting this RFC!
I want to see if we can make changes to the existing torch.distributed APIs first to solve some/all of your problems. Then, if needed, we can consider a new set of APIs (e.g. torch.distributed2).
For the src/dst in send/recv, that is something that has been bugging us for a while, and I suppose we could fix it in the existing APIs without worrying about BC by simply adding new kwargs, group_src or group_dst, which would be mutually exclusive with src and dst - i.e. you can pass one or the other, but not both.
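For illustration, here is a minimal sketch of how the proposed kwargs might be used; this assumes the proposal lands roughly as described and is not a description of the API as it existed at the time of this thread:

```python
import torch
import torch.distributed as dist

# Assumed: `pg` is an already-created subgroup; 0 and 1 are ranks local to `pg`.
t = torch.ones(4)

# Proposed: address peers by their rank *within* the group ...
dist.send(t, group=pg, group_dst=1)
dist.recv(t, group=pg, group_src=0)

# ... instead of translating to a global rank first, as required today:
# dist.send(t, dst=dist.get_global_rank(pg, 1), group=pg)
```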
For the global group, I think this might be harder to solve but I'd like to get a document started with the different possibilities and pros/cons. cc @kwen2501
I think a lot of what's being asked here can be done with just a new entrypoint (rather than only init_process_group), avoiding the need to create a new package.
That's largely what I'm doing in the torchft ProcessGroups -- just initializing the underlying PG without setting the global state. It is definitely a bit clunky (since it operates on the store API), but it generally works just fine to instantiate a PG without calling init_process_group. https://github.com/pytorch-labs/torchft/blob/main/torchft/process_group.py
i.e. in current PyTorch you can do

```python
from torch.distributed import PrefixStore, ProcessGroupNCCL, TCPStore

# connect to an already-running TCPStore server at host:port
store = TCPStore(
    host_name=host,
    port=int(port),
    is_master=False,
    wait_for_workers=False,
)
# namespace the keys so this PG doesn't collide with others sharing the store
store = PrefixStore("my_custom_pg", store)
pg = ProcessGroupNCCL(store, rank=10, world_size=32)
pg.rank(), pg.size()
```
This can be used completely in an object-oriented way without relying on any "internal" APIs.
@youkaichao would you be happy to use the workflow @d4l3k proposed? Or is there still something missing?
@d4l3k is the PrefixStore needed so that each store can use a default UUID (does each store use UUID 0 or something)? I wonder if we should still provide a bit of a helper here: (a) we could allow reusing the globally initialized TCPStore if it exists (or accept one as an optional kwarg as an alternative), and (b) we could deal with the UUID automatically somehow, and ensure that each PG still gets a unique one?
@d4l3k that's a great idea. I actually tried it before. However, the problem is that you cannot do send/recv with such a group. There are some exceptions like torch.distributed.all_reduce that work with these standalone groups, but torch.distributed.send/recv do not.
I'm also exploring the idea of using the TCP store to directly implement a new set of send/recv/broadcast operations, in https://github.com/vllm-project/vllm/blob/377b74fe877c7eb4632c2ca0778b9da9a5db8ae6/vllm/distributed/utils.py#L127 . It works locally, but sometimes hangs during initialization in CI.
@youkaichao is the only reason that send/recv do not work because of the dst/src mapping issue? I started to prototype a possible fix for that today; I'll share it here shortly.
Send/recv via TCPStore feels like it would require polling and become unscalable at large numbers of ranks, but for certain use cases it could work. We have also been thinking about better support for control-plane communication cc @c-p-i-o
> is the only reason that send/recv do not work because of the dst/src mapping issue?
For send/recv, yes, kind of. There are other more complicated cases, though. For example, broadcast:
https://github.com/pytorch/pytorch/blob/659d2132be469a86ea34dcb7f79224c34ebb1685/torch/distributed/distributed_c10d.py#L2580
and broadcast_object_list:
https://github.com/pytorch/pytorch/blob/659d2132be469a86ea34dcb7f79224c34ebb1685/torch/distributed/distributed_c10d.py#L3239C5-L3239C26
They are quite difficult to use if I have a standalone group that is not part of the global group.
> Send/recv via TCPStore feels like it would require polling and become unscalable at large numbers of ranks.
For the TCP store (and any "store"), shouldn't polling be handled by default? I don't see any explicit polling in the example code at https://pytorch.org/docs/stable/distributed.html#torch.distributed.TCPStore .
> We have also been thinking about better support for control-plane communication
That would be great.
> For send/recv, yes, kind of. There are other more complicated cases, though. For example, broadcast:
OK, these look like the same thing to me. Basically, if we added support to all our APIs for 'group_src' and 'group_dst' wherever there is currently a 'src' and 'dst', it would fix the issue. That's what it looks like to me, at least.
> For the TCP store (and any "store"), shouldn't polling be handled by default? I don't see any explicit polling in the example code at https://pytorch.org/docs/stable/distributed.html#torch.distributed.TCPStore .
Well, I'm not sure what you mean by polling by default. But if I were to build send/recv on top of TCPStore, I think my two choices would be (1) naive: make the recv op 'synchronous' on the CPU and rely on the TCP timeout, or (2) implement a new polling thread on the recv side that keeps checking whether send-data has been posted. I was referring to path (2). I'm not sure whether (1) is actually practical for performance reasons, but we could check.
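As a rough illustration of option (1), here is a minimal sketch of a blocking recv built on TCPStore; the helper names and key scheme are hypothetical, not an existing API. TCPStore.get blocks until the key is set (or the store timeout fires), so no explicit polling loop is needed on the caller's side:

```python
import io
import torch
from torch.distributed import TCPStore

def store_send(store: TCPStore, tensor: torch.Tensor, src: int, dst: int, tag: int = 0):
    # serialize dtype/shape/data together and post it under a per-pair key
    buf = io.BytesIO()
    torch.save(tensor, buf)
    store.set(f"send/{src}->{dst}/{tag}", buf.getvalue())

def store_recv(store: TCPStore, src: int, dst: int, tag: int = 0) -> torch.Tensor:
    # TCPStore.get blocks until the key exists, relying on the store's timeout
    data = store.get(f"send/{src}->{dst}/{tag}")
    return torch.load(io.BytesIO(data))
```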
> (1) naive: make the recv op 'synchronous' on the CPU and rely on the TCP timeout
For my use case, (1) is enough.
> if we added support to all our APIs for 'group_src' and 'group_dst' wherever there is currently a 'src' and 'dst', it would fix the issue
We also need to take care of collective ops like all_reduce and all_gather. The goal is for subgroups to work on their own without any dependency on the global group.
@youkaichao we don't document the ProcessGroup object APIs (I'm not sure why not; we really should), but if you use them directly, it should work as expected for send/recv/broadcast, since the ranks are PG-local rather than global:
https://github.com/pytorch/pytorch/blob/f98c601efe9b426bf85d48d4949cddd01b744e55/torch/csrc/distributed/c10d/init.cpp#L2122-L2128
i.e.

```python
pg = ProcessGroupNCCL(store, rank=10, world_size=32)
pg.send(..., local_rank, 0).wait()  # rank is local to this PG; tag is an int
```

vs

```python
import torch.distributed as dist

dist.send(..., global_rank, group=pg)  # dist.send expects the global rank
```
@d4l3k this is great! I don't see broadcast_object_list in the ProcessGroup object APIs though.
Do you know why we don't directly use the TCP store to send the object in broadcast_object_list? Instead, we use tensor transport to send a single-element size tensor, and then a tensor with that number of bytes. See https://github.com/pytorch/pytorch/blob/476e0697f523bca1a39c6269137b3dacad66e306/torch/distributed/distributed_c10d.py#L3329 for the relevant code.
@youkaichao you're right -- that API isn't a process group API and is instead a layer on top
I think it's really a performance reason -- TCPStore has extra complexity since it's a persistent store and isn't designed for large data transfer. It is generally more performant/safer to use pickle + collectives.
We don't really recommend broadcast_object_list in general
It is often the case for us to send a tensor with unknown size. Currently, in inference, we have to:
- send the size of metadata
- send metadata (dtype, shape)
- send the tensor data
This takes three steps; it would be better if we could merge steps 1 and 2 into one by directly using an object store (see the sketch below).
Our use case involves sending many small tensors frequently, so reducing latency is critical here.
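For concreteness, a minimal sketch of the pattern being described, assuming a TCPStore side channel for the metadata and a process group for the tensor payload (this is not vLLM's actual implementation; `store`, `pg`, and the key naming are assumptions for illustration):

```python
import pickle
import torch
from torch.distributed import TCPStore

def send_tensor(pg, store: TCPStore, tensor: torch.Tensor, dst: int, key: str):
    # steps 1+2 merged: the object store carries (dtype, shape) as one pickled blob
    store.set(key, pickle.dumps((tensor.dtype, tuple(tensor.shape))))
    # step 3: only the raw tensor data goes over the device channel
    pg.send([tensor], dst, 0).wait()

def recv_tensor(pg, store: TCPStore, src: int, key: str, device: str = "cuda"):
    # blocks until the sender sets the key, then allocates a matching buffer
    dtype, shape = pickle.loads(store.get(key))
    tensor = torch.empty(shape, dtype=dtype, device=device)
    pg.recv([tensor], src, 0).wait()
    return tensor
```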
@youkaichao you can use broadcast_object_list by passing the group into it,
i.e.

```python
dist.broadcast_object_list(..., group=my_group)
```
FYI, we have a prototype of a similar API to this at https://github.com/pytorch/pytorch/blob/main/torch/distributed/_dist2.py