[torch_xla2] Functional collectives
- Move collective op implementations out of `ProcessGroup` and into registered `__torch_dispatch__` ops (see the sketch after this list).
- Rewrite the `ProcessGroup` implementation using functional collectives.
  - `torch.distributed` dispatches to `c10d.*` (e.g. `torch.ops.c10d.allreduce_`) instead of the functional op directly. The signatures of `c10d.*` and the in-place `_c10d_functional.*` variants differ somewhat (str reduce op vs. enum, tensor vs. list of tensors), so we can't just reuse the exact same function (see the schema check after this list).
  - When I tried to implement `c10d.*` directly, all non-tensor objects were wrapped in `ScriptObject`, causing a bunch of errors. I could not figure out how to unwrap them.
  - The default implementation of these ops is a wrapper around the PG, e.g. https://github.com/pytorch/pytorch/blob/fdd0a7f9b4fb80c4d4870569909505a9beb6ccb3/torch/csrc/distributed/c10d/Ops.cpp#L162-L180
- Some functional collectives ignore the process group when given a group name. Others look up e.g. the world size by the PG group name in the background (namely `all_gather`). It's unclear if that's the long-term intended behavior. Let's keep the `ProcessGroup` for now.
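For illustration, here is a minimal sketch of the dispatch-based approach (not the actual code in this PR): intercept the `_c10d_functional` ops in a `__torch_dispatch__` handler and route them to a backend-specific implementation. `backend_all_reduce` is a hypothetical placeholder, and the sketch assumes a PyTorch build with distributed support.

```python
import torch
from torch.utils._pytree import tree_map_only


def backend_all_reduce(tensor, reduce_op, group_name):
    # Hypothetical hook; a real implementation would lower to the runtime's collective.
    raise NotImplementedError


class CollectiveTensor(torch.Tensor):
    """Sketch: intercept functional collective ops via __torch_dispatch__."""

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops._c10d_functional.all_reduce.default:
            # Handle the collective here instead of deferring to ProcessGroup.
            return backend_all_reduce(*args, **kwargs)
        # Everything else falls through to the regular kernels on plain tensors.
        unwrap = lambda t: t.as_subclass(torch.Tensor)
        return func(*tree_map_only(CollectiveTensor, unwrap, args),
                    **tree_map_only(CollectiveTensor, unwrap, kwargs))
```

A tensor would opt in via e.g. `x.as_subclass(CollectiveTensor)`; this only shows the shape of the mechanism, not how the PR actually wires it into torch_xla2's op registration.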
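The signature mismatch can also be checked directly from the op schemas (again assuming a distributed-enabled build; output elided):

```python
import torch

# Per the notes above: the c10d op takes a list of tensors plus ProcessGroup/ReduceOp
# objects, while the functional op takes a single tensor with plain str arguments.
print(torch.ops.c10d.allreduce_.default._schema)
print(torch.ops._c10d_functional.all_reduce.default._schema)
```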
Aside from the traceable op implementations being cleaner, dynamo rewrites `torch.distributed` calls into their functional equivalents. E.g.
```python
import torch
import torch.distributed as dist

@torch.compile(backend=my_backend)  # my_backend: the custom dynamo backend under test
def cc(index):
    dist.all_reduce(index)
    return index
```
generates
```
opcode         name         target                                args                   kwargs
-------------  -----------  ------------------------------------  ---------------------  --------
placeholder    arg0_1       arg0_1                                ()                     {}
call_function  all_reduce   _c10d_functional.all_reduce.default   (arg0_1, 'sum', '0')   {}
call_function  wait_tensor  _c10d_functional.wait_tensor.default  (all_reduce,)          {}
call_function  copy         aten.copy.default                     (arg0_1, wait_tensor)  {}
output         output       output                                ((copy, copy),)        {}
```
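For reference, the eager equivalent of the rewritten graph uses `torch.distributed._functional_collectives` directly. A minimal sketch, assuming a process group has already been initialized (e.g. via `torchrun`):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def cc_functional(index: torch.Tensor) -> torch.Tensor:
    # all_reduce returns a new (async) tensor instead of mutating its input;
    # wait_tensor mirrors the wait_tensor node in the traced graph above.
    reduced = funcol.all_reduce(index, "sum", dist.group.WORLD)
    return funcol.wait_tensor(reduced)
```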
cc @qihqi
Depends on #7311