
Allow collectives in manually sharded computations

Open copybara-service[bot] opened this issue 3 years ago • 0 comments


Collectives should be allowed at least when the manual sharding applies to the whole mesh, because that is all XLA can support right now. This is especially important when computing gradients of xmapped functions (with manual lowering enabled), since automatic differentiation (AD) often introduces many psums.

copybara-service[bot] · Aug 11 '22 12:08
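The `xmap` API named above has since been removed from JAX, and `shard_map` is now the manual-sharding entry point. The sketch below is an illustration of the situation the issue describes, not the change itself: it assumes a recent JAX, builds a 1-D mesh spanning every available device (the whole-mesh case), performs an explicit `psum` inside the manually sharded function, and takes a gradient with respect to a replicated parameter `w`. Under `shard_map`'s default replication checking, the backward pass for a replicated input needs a cross-device reduction, which is one way AD ends up inserting collectives. The names `local_loss`, `loss`, and the axis name `'i'` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# Hypothetical 1-D mesh covering every device: the "whole mesh" case.
mesh = Mesh(np.array(jax.devices()), ("i",))


def local_loss(w, x):
    # Each device sees only its shard of x; the explicit psum is a
    # collective that combines the per-shard partial losses.
    return jax.lax.psum(jnp.sum((x * w) ** 2), "i")


# w is replicated (P()), x is split along the mesh axis (P("i")),
# and the psum'd result is replicated on every device (P()).
loss = shard_map(local_loss, mesh=mesh, in_specs=(P(), P("i")), out_specs=P())

if __name__ == "__main__":
    w = jnp.float32(1.5)
    x = jnp.arange(4 * len(jax.devices()), dtype=jnp.float32)
    # Differentiating w.r.t. the replicated w requires summing per-shard
    # contributions across devices in the backward pass.
    print(jax.grad(loss)(w, x))
```

Comparing the result against the unsharded `jax.grad(lambda w: jnp.sum((x * w) ** 2))(w)` is a quick sanity check that the collectives in the forward and backward passes agree.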