
RFC-0042-torch-distributed-redesign

Open youkaichao opened this issue 1 year ago • 20 comments

youkaichao, Nov 08 '24 05:11

For a preview, please see https://github.com/youkaichao/rfcs/blob/master/RFC-0042-torch-distributed-redesign.md

youkaichao, Nov 08 '24 05:11

cc @ezyang @wconstab @kwen2501

youkaichao, Nov 08 '24 05:11

An important use case for this is dynamic prefill/decode disaggregation: prefill instances and decode instances dynamically join the group according to the workload, and they send/recv KV caches to/from each other.

There are other solutions, such as using etcd to communicate metadata and directly using device communication libraries like our own NCCL wrapper. However, that would mean dropping torch.distributed from our codebase entirely, so it is our last resort. We want to use PyTorch as much as we can.

youkaichao, Nov 08 '24 06:11

The current global group is necessary for control plane operations over the cluster.

It conflates the notion of a cluster with that of communication groups, so it would be great to separate the two.

One aspect to make this feasible is whether it's possible to implement torch.distributed in terms of torch.distributed2.

kumpera, Nov 11 '24 17:11

> One aspect to make this feasible is whether it's possible to implement torch.distributed in terms of torch.distributed2.

Do you mean we would have a stateless version of process groups in torch.distributed2, and then re-implement the global group in torch.distributed on top of it? That could be a great idea!
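The layering being discussed can be sketched in plain Python. This is a toy model, not PyTorch code: `StatelessGroup`, `init_default_group`, and `get_rank` are hypothetical names. The group object carries its own rank and size, and only the thin compatibility layer holds a default group, so standalone groups keep working even if no default is ever initialized.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class StatelessGroup:
    """Hypothetical stand-in for a torch.distributed2 group: it carries its
    own rank and size and never touches module-level state."""
    name: str
    rank: int
    size: int


# The only global state lives in this thin compatibility layer.
_default_group: Optional[StatelessGroup] = None


def init_default_group(group: StatelessGroup) -> None:
    """Register a group as the default, similar in spirit to init_process_group."""
    global _default_group
    if _default_group is not None:
        raise RuntimeError("default group already initialized")
    _default_group = group


def get_rank(group: Optional[StatelessGroup] = None) -> int:
    """Old-style API: fall back to the default group only when none is passed."""
    if group is None:
        if _default_group is None:
            raise RuntimeError("default group not initialized")
        group = _default_group
    return group.rank
```

Passing an explicit group never consults `_default_group`, which is the property the thread is after: subgroups that work without any global group.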

youkaichao, Nov 11 '24 17:11

Thanks for posting this RFC!

I want to see first whether we can change the existing torch.distributed APIs to solve some or all of your problems. Then, if needed, we can consider a new set of APIs (e.g. torch.distributed2).

The src/dst arguments of send/recv have been bugging us for a while. I think we could fix this in the existing APIs without worrying about backward compatibility (BC) by simply adding new kwargs, group_src and group_dst, which would be mutually exclusive with src and dst, i.e. you can pass one or the other but not both.
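A minimal sketch of the proposed mutual exclusion, assuming a hypothetical helper `resolve_peer` and a `group_ranks` list that maps each group rank to its global rank (neither is an existing torch.distributed function):

```python
from typing import Optional, Sequence


def resolve_peer(
    group_ranks: Sequence[int],
    src: Optional[int] = None,
    group_src: Optional[int] = None,
) -> int:
    """Return the peer's global rank given exactly one of src (a global rank)
    or group_src (a rank within the group). group_ranks[i] is the global
    rank of group rank i."""
    # Mutual exclusion: exactly one of the two kwargs must be provided.
    if (src is None) == (group_src is None):
        raise ValueError("pass exactly one of src or group_src")
    if group_src is not None:
        if not 0 <= group_src < len(group_ranks):
            raise ValueError(f"group_src {group_src} out of range")
        return group_ranks[group_src]
    # Legacy path: src is already global, but it must belong to the group.
    if src not in group_ranks:
        raise ValueError(f"global rank {src} is not in this group")
    return src
```

With a group made of global ranks `[4, 5, 6, 7]`, `group_src=1` and `src=5` refer to the same peer, while passing both (or neither) is rejected.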

For the global group, I think this might be harder to solve, but I'd like to get a document started with the different possibilities and their pros/cons. cc @kwen2501

wconstab, Nov 12 '24 04:11

I think much of what's being asked here can be done with just a new entry point (instead of init_process_group), without having to create a new package.

That's largely what I'm doing in the torchft ProcessGroups: just initializing the underlying PG without setting the global state. It is a bit clunky (since it operates on the store API), but instantiating a PG without calling init_process_group generally works fine. https://github.com/pytorch-labs/torchft/blob/main/torchft/process_group.py

For example, in current PyTorch you can do:

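```python
from torch.distributed import PrefixStore, ProcessGroupNCCL, TCPStore

host, port = "localhost", "29500"  # placeholder: address of an existing TCPStore server

# Connect to the store as a client (this process does not host the server).
store = TCPStore(
    host_name=host,
    port=int(port),
    is_master=False,
    wait_for_workers=False,
)
# Namespace all keys under a prefix so this PG does not collide with other users of the store.
store = PrefixStore("my_custom_pg", store)

# Construct the PG directly as (store, rank, size); no global state is touched.
pg = ProcessGroupNCCL(store, 10, 32)

pg.rank(), pg.size()  # ranks are local to this PG
```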

This can be used in a fully object-oriented way without relying on any "internal" APIs.

d4l3k, Nov 12 '24 18:11

@youkaichao, would you be happy with the workflow @d4l3k proposed, or is something still missing?

@d4l3k, is the PrefixStore needed so that each store can use a default UUID (does each store use UUID 0 or something)? I wonder if we should still provide a small helper here: (a) it could reuse the globally initialized TCPStore if one exists (or accept one as an optional kwarg instead), and (b) it could handle UUIDs automatically, ensuring that each PG still gets a unique UUID.
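One possible shape for point (b), as a hedged sketch: a hypothetical `PrefixAllocator` hands out a unique prefix per new PG, which would then be passed as the first argument to `torch.distributed.PrefixStore` around the shared store. It assumes every rank creates its groups in the same order, so the n-th group gets the same prefix everywhere without extra communication.

```python
import itertools
import threading


class PrefixAllocator:
    """Hypothetical helper that hands out a unique store prefix per new PG.

    Assumption: all ranks create groups in the same order, so the n-th call
    returns the same prefix on every rank (no communication is involved)."""

    def __init__(self, namespace: str = "pg") -> None:
        self._namespace = namespace
        self._counter = itertools.count()
        self._lock = threading.Lock()  # safe if groups are created from several threads

    def next_prefix(self, name: str = "") -> str:
        with self._lock:
            uid = next(self._counter)
        # The counter guarantees uniqueness; the name is only for readability.
        return f"{self._namespace}/{uid}/{name}"
```

Counter-based naming is cheap but fragile if ranks diverge in how many groups they create; an alternative would be deriving the prefix from the sorted member ranks plus a per-membership counter.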

wconstab, Nov 12 '24 19:11

@d4l3k, that's a great idea. I actually tried it before. The problem, however, is that you cannot use send/recv with such a pg. Some functions, like torch.distributed.all_reduce, do work with these standalone groups, but torch.distributed.send/recv do not.

youkaichao, Nov 12 '24 23:11

I'm also exploring the idea of using the TCP store directly to implement a new set of send/recv/broadcast operations, in https://github.com/vllm-project/vllm/blob/377b74fe877c7eb4632c2ca0778b9da9a5db8ae6/vllm/distributed/utils.py#L127 . It works locally, but it sometimes hangs during initialization in CI.

youkaichao, Nov 12 '24 23:11

@youkaichao, is the only reason that send/recv do not work because of the dst/src mapping issue? I started prototyping a possible fix for that today; I'll share it here shortly.

Send/recv via tcpstore feels like it would require polling and become unscalable at large numbers of ranks. But for certain use cases it could work. We have also been thinking about better support for control-plane communication. cc @c-p-i-o

wconstab, Nov 13 '24 03:11

> is the only reason that send/recv do not work because of the dst/src mapping issue?

For send/recv, yes, kind of. There are other more complicated cases, though. For example, broadcast:

https://github.com/pytorch/pytorch/blob/659d2132be469a86ea34dcb7f79224c34ebb1685/torch/distributed/distributed_c10d.py#L2580

and broadcast_object_list:

https://github.com/pytorch/pytorch/blob/659d2132be469a86ea34dcb7f79224c34ebb1685/torch/distributed/distributed_c10d.py#L3239C5-L3239C26

They are quite difficult to use if I have a standalone group that is not part of the global group.

youkaichao, Nov 13 '24 04:11

> Send/recv via tcpstore feels like it would require polling and become unscalable at large numbers of ranks.

For TCPStore (and any "store"), shouldn't polling already be built in by default? I don't see any polling in the example code https://pytorch.org/docs/stable/distributed.html#torch.distributed.TCPStore .

> We have also been thinking about better support for control-plane communication

That would be great.

youkaichao, Nov 13 '24 04:11

> For send/recv, yes, kind of. There are other more complicated cases, though. For example, broadcast:

OK, these look like the same thing to me. Basically, if we added support to all our APIs for 'group_src' and 'group_dst' wherever there is currently a 'src' and 'dst', it would fix the issue. That's how it looks to me, at least.

> For tcp store (and any "store"), it should have polling by default? I don't see any polling in the example code https://pytorch.org/docs/stable/distributed.html#torch.distributed.TCPStore .

Well, I'm not sure what you mean by polling by default. But if I were to build send/recv on top of TCPStore, I think my two choices would be: (1) naive, make the recv op 'synchronous' on the CPU, and rely on the TCP timeout; or (2) implement a new polling thread on the recv side that keeps checking whether send data has been posted. I was referring to path (2). I'm not sure whether (1) is practical for performance reasons, but we could check.
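Option (1) can be sketched in a few lines. `ToyStore` below is an in-memory stand-in that assumes TCPStore-like semantics (`set` stores a value, `get` blocks until the key appears or a timeout fires); `StoreP2P` is a hypothetical wrapper, and it never deletes keys, which a real implementation would need to do.

```python
import pickle
import threading
from typing import Any, Dict


class ToyStore:
    """In-memory stand-in for a store: get() blocks until the key is set
    or the timeout expires."""

    def __init__(self, timeout: float = 5.0) -> None:
        self._data: Dict[str, bytes] = {}
        self._cond = threading.Condition()
        self._timeout = timeout

    def set(self, key: str, value: bytes) -> None:
        with self._cond:
            self._data[key] = value
            self._cond.notify_all()  # wake any blocked get()

    def get(self, key: str) -> bytes:
        with self._cond:
            if not self._cond.wait_for(lambda: key in self._data, self._timeout):
                raise TimeoutError(f"timed out waiting for {key!r}")
            return self._data[key]


class StoreP2P:
    """Option (1): send/recv on a store, with recv blocking the calling thread.

    Each (src, dst) pair keeps its own sequence number, so messages are
    matched in order without a background polling thread."""

    def __init__(self, store: ToyStore, rank: int) -> None:
        self.store = store
        self.rank = rank
        self._send_seq: Dict[int, int] = {}
        self._recv_seq: Dict[int, int] = {}

    def send(self, obj: Any, dst: int) -> None:
        seq = self._send_seq.get(dst, 0)
        self._send_seq[dst] = seq + 1
        self.store.set(f"p2p/{self.rank}->{dst}/{seq}", pickle.dumps(obj))

    def recv(self, src: int) -> Any:
        seq = self._recv_seq.get(src, 0)
        self._recv_seq[src] = seq + 1
        # Blocks until the matching send lands, or raises on timeout.
        return pickle.loads(self.store.get(f"p2p/{src}->{self.rank}/{seq}"))
```

Because `recv` simply waits inside `get`, the cost is one blocked thread per outstanding receive rather than a polling loop; whether that is fast enough is exactly the open question above.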

wconstab, Nov 13 '24 12:11

> (1) naive, make the recv op 'synchronous' on the CPU, and rely on the TCP timeout

For my use case, (1) is enough.

> if we added support to all our APIs for 'group_src' and 'group_dst' wherever there is currently a 'src' and 'dst', it would fix the issue

We also need to take care of collective ops like allreduce and allgather. The goal is for subgroups to work on their own without any dependency on the global group.

youkaichao, Nov 13 '24 22:11

@youkaichao, we don't document the ProcessGroup object APIs (I'm not sure why not; we really should), but if you use them directly, send/recv/broadcast should work as expected, since the ranks are PG-local rather than global.

https://github.com/pytorch/pytorch/blob/f98c601efe9b426bf85d48d4949cddd01b744e55/torch/csrc/distributed/c10d/init.cpp#L2122-L2128

For example:

```python
# Object API: dst is the rank *within* pg; tensors go in a list and tag is an int.
pg = ProcessGroupNCCL(store, 10, 32)
pg.send(..., local_rank, 0).wait()
```

versus

```python
import torch.distributed as dist

# Module-level API: dst is a *global* rank, even when group=pg is passed.
dist.send(..., global_rank, group=pg)
```

d4l3k, Nov 14 '24 15:11

@d4l3k, this is great! However, I don't see broadcast_object_list in the ProcessGroup object APIs.

Do you know why broadcast_object_list doesn't use the TCP store directly to send the object, and instead uses the tensor transport to send a single-element size tensor followed by a tensor with that many bytes? See https://github.com/pytorch/pytorch/blob/476e0697f523bca1a39c6269137b3dacad66e306/torch/distributed/distributed_c10d.py#L3329 for the relevant code.
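For reference, the linked code roughly does two broadcasts: first the per-object byte sizes, so receivers can allocate a buffer of the right size, then the concatenated pickled bytes. The sketch below models only the serialization half in plain Python, with bytes standing in for tensors and no actual communication; `pack_objects` and `unpack_objects` are hypothetical names.

```python
import pickle
from typing import Any, List, Tuple


def pack_objects(objs: List[Any]) -> Tuple[List[int], bytes]:
    """Sender side: pickle each object and return (per-object sizes, payload).
    The sizes would go out in the first broadcast, the payload in the second."""
    blobs = [pickle.dumps(o) for o in objs]
    return [len(b) for b in blobs], b"".join(blobs)


def unpack_objects(sizes: List[int], payload: bytes) -> List[Any]:
    """Receiver side: split the payload using the sizes received first."""
    out: List[Any] = []
    offset = 0
    for n in sizes:
        out.append(pickle.loads(payload[offset:offset + n]))
        offset += n
    return out
```

The size message exists because a receiver cannot allocate the payload buffer before it knows how many bytes are coming.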

youkaichao, Nov 15 '24 20:11

@youkaichao, you're right: that API isn't a ProcessGroup API; it's a layer on top.

I think it's really for performance: TCPStore has extra complexity since it's a persistent store, and it isn't designed for large data transfers. It's generally more performant and safer to use pickle + collectives.

We don't really recommend broadcast_object_list in general.

d4l3k, Nov 15 '24 21:11

We often need to send tensors whose sizes the receiver doesn't know in advance. Currently, in inference, we have to:

  1. send the size of the metadata
  2. send the metadata (dtype, shape)
  3. send the tensor data

That is 3 steps; it would be better if we could merge steps 1 and 2 into one by using an object store directly.

Our use case involves frequently sending many small tensors, so reducing latency is critical here.
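One common workaround, not something proposed in this thread, is to merge steps 1 and 2 by always sending a fixed-size header: a length prefix plus the pickled metadata, padded to an assumed upper bound `HEADER_BYTES`. The receiver can preallocate that buffer, so only step 3 remains variable-sized. Plain bytes stand in for the uint8 tensor that would actually be sent; `encode_header` and `decode_header` are hypothetical names.

```python
import pickle
import struct
from typing import Any

HEADER_BYTES = 1024  # assumed upper bound on the encoded metadata size
_LEN = struct.Struct("<I")  # 4-byte little-endian length prefix


def encode_header(meta: Any) -> bytes:
    """Pack metadata into one fixed-size buffer: length prefix + pickle + zero padding."""
    body = pickle.dumps(meta)
    if _LEN.size + len(body) > HEADER_BYTES:
        raise ValueError("metadata too large for the fixed-size header")
    return (_LEN.pack(len(body)) + body).ljust(HEADER_BYTES, b"\0")


def decode_header(buf: bytes) -> Any:
    """Read the length prefix, then unpickle exactly that many bytes."""
    (n,) = _LEN.unpack_from(buf, 0)
    return pickle.loads(buf[_LEN.size:_LEN.size + n])
```

The trade-off is wasted bandwidth when the metadata is tiny and a hard failure if it ever exceeds the bound.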

youkaichao, Nov 15 '24 21:11

@youkaichao, you can use broadcast_object_list by passing the group to it.

For example:

```python
dist.broadcast_object_list(..., group=my_group)
```

d4l3k, Nov 18 '24 18:11

FYI, we have a prototype of an API similar to this at https://github.com/pytorch/pytorch/blob/main/torch/distributed/_dist2.py

d4l3k, Jul 14 '25 23:07