
[torch_xla2] Functional collectives

Open · will-cromar opened this issue on Jun 26, 2024 · 0 comments

  • Move collective op implementations out of ProcessGroup and into registered __torch_dispatch__ ops (see the sketch after this list).
  • Rewrite ProcessGroup implementation using functional collectives
    • torch.distributed dispatches to c10d.* (e.g. torch.ops.c10d.allreduce_) instead of the functional op directly. The signatures of c10d.* and _c10d_functional.*(_inplace) differ somewhat (str reduce op vs enum, tensor vs list of tensors), so we can't just reuse the exact same function.
    • When I tried to implement c10d.* directly, all non-tensor objects were wrapped in ScriptObject, causing a bunch of errors. I could not figure out how to unwrap them.
    • The default implementation of these ops is a wrapper around PG, e.g. https://github.com/pytorch/pytorch/blob/fdd0a7f9b4fb80c4d4870569909505a9beb6ccb3/torch/csrc/distributed/c10d/Ops.cpp#L162-L180
  • Some functional collectives ignore the process group when given a group name. Others (namely all_gather) look up e.g. the world size from the registered PG by group name in the background. It's unclear if that's the intended long-term behavior, so let's keep the ProcessGroup for now.
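
For concreteness, here is a minimal sketch (not torch_xla2's actual code, and assuming a recent PyTorch that exposes the native _c10d_functional ops) of intercepting the functional collectives from __torch_dispatch__. The handler bodies are placeholders; a real backend would replace them with its own collectives (e.g. lowering to jax.lax.psum for torch_xla2):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class FunctionalCollectiveInterceptor(TorchDispatchMode):
  # Sketch: route the native functional collective ops to a backend-specific
  # implementation instead of going through a ProcessGroup.
  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    if func is torch.ops._c10d_functional.all_reduce.default:
      tensor, reduce_op, group_name = args
      # Placeholder: a real backend would run its own all-reduce here
      # (e.g. jax.lax.psum) instead of returning an unmodified copy.
      return tensor.clone()
    if func is torch.ops._c10d_functional.wait_tensor.default:
      # Placeholder: block until the pending collective has completed.
      return args[0]
    return func(*args, **kwargs)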

Aside from the traceable op implementations being cleaner, dynamo rewrites torch.distributed calls into their functional equivalents. E.g.

import torch
import torch.distributed as dist

@torch.compile(backend=my_backend)
def cc(index):
  dist.all_reduce(index)
  return index

generates

opcode         name         target                                args                   kwargs
-------------  -----------  ------------------------------------  ---------------------  --------
placeholder    arg0_1       arg0_1                                ()                     {}
call_function  all_reduce   _c10d_functional.all_reduce.default   (arg0_1, 'sum', '0')   {}
call_function  wait_tensor  _c10d_functional.wait_tensor.default  (all_reduce,)          {}
call_function  copy         aten.copy.default                     (arg0_1, wait_tensor)  {}
output         output       output                                ((copy, copy),)        {}

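For reference, the traced graph corresponds roughly to calling the functional ops directly, as in the sketch below (the 'sum' reduce op and group name '0' are the values dynamo baked into the trace above):

import torch

def cc_functional(index):
  # Out-of-place functional all-reduce plus an explicit wait; the final copy_
  # mirrors the aten.copy node dynamo inserts to preserve in-place semantics.
  reduced = torch.ops._c10d_functional.all_reduce(index, "sum", "0")
  reduced = torch.ops._c10d_functional.wait_tensor(reduced)
  index.copy_(reduced)
  return index
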
cc @qihqi

Depends on #7311
