jax
Allow collectives in manually sharded computations
... at least when the manual sharding applies to the whole mesh, because
that's all that XLA can support right now. This is especially important
when computing gradients of xmapped functions (when manual lowering is
enabled), since AD often introduces many psums.
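As a rough illustration of why collectives and gradients go together, here is a minimal sketch that differentiates through a function containing a `psum` over a named axis. It does not use xmap or manual sharding at all: `jax.vmap` with `axis_name` is swapped in as a stand-in for a named-axis map, and the function names (`normalize`, `loss`, `reference_loss`) are hypothetical, chosen only for this example.

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Per-element computation that uses a collective over the named axis "i".
    return x / jax.lax.psum(x, axis_name="i")


def loss(x):
    # Assumption: vmap with axis_name stands in for a named-axis map such as xmap.
    y = jax.vmap(normalize, axis_name="i")(x)
    return jnp.sum(y ** 2)


def reference_loss(x):
    # Same math written without named axes, used only as a cross-check.
    return jnp.sum((x / jnp.sum(x)) ** 2)


x = jnp.arange(1.0, 5.0)
grad_named = jax.grad(loss)(x)          # differentiates through the psum
grad_ref = jax.grad(reference_loss)(x)  # plain jnp reference gradient
```

The two gradients should agree, which shows that the backward pass has to reduce over the named axis as well. Under a manually sharded lowering, that reduction is the kind of collective this change needs to permit.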