Part 2. Introduce multi-node SPMD initialization for Neuron
In this PR, we add a new initialization path that supports multi-node SPMD on Neuron. To keep the change small, we retain the xla.init() API and introduce a reinitialization of PJRT alone once SPMD is enabled. Because SPMD is enabled after the initial Neuron initialization, the environment must be reconfigured at that point whenever the user did not explicitly set XLA_USE_SPMD beforehand and instead enabled SPMD through use_spmd(), as is currently recommended. Under the hood, both APIs guarantee that the environment is correctly configured when SPMD is enabled.
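The control flow described above can be sketched roughly as follows. This is a minimal, self-contained illustration and not the code in this PR: the class `NeuronRuntimeSketch`, its methods `init` and `enable_spmd`, and the `pjrt_configs` log are hypothetical names chosen for the example, and the "configuration" step only records what would be configured. The only identifier taken from the description is the `XLA_USE_SPMD` environment variable.

```python
class NeuronRuntimeSketch:
    """Hypothetical model of the init / re-init flow; not a torch_xla API."""

    def __init__(self, env):
        self.env = env            # dict standing in for os.environ
        self.spmd_enabled = False
        self.pjrt_configs = []    # log of every PJRT configuration pass

    def _configure_pjrt(self, spmd):
        # Stand-in for the real PJRT setup; here we only record the mode.
        self.pjrt_configs.append("spmd" if spmd else "default")

    def init(self):
        # Initial Neuron initialization. SPMD is only known at this point
        # if the user exported XLA_USE_SPMD=1 ahead of time.
        self.spmd_enabled = self.env.get("XLA_USE_SPMD") == "1"
        self._configure_pjrt(self.spmd_enabled)

    def enable_spmd(self):
        # SPMD is switched on after init. If the first pass already
        # configured PJRT for SPMD, there is nothing left to do.
        if self.spmd_enabled:
            return
        self.env["XLA_USE_SPMD"] = "1"
        self.spmd_enabled = True
        # Reinitialize PJRT alone; the rest of the setup is left untouched.
        self._configure_pjrt(spmd=True)


# Case 1: SPMD enabled programmatically after init -> PJRT is reconfigured.
late = NeuronRuntimeSketch(env={})
late.init()
late.enable_spmd()

# Case 2: XLA_USE_SPMD set up front -> a single configuration pass suffices.
early = NeuronRuntimeSketch(env={"XLA_USE_SPMD": "1"})
early.init()
early.enable_spmd()
```

In this sketch, `late.pjrt_configs` ends up with two entries (the default pass followed by the SPMD pass), while `early.pjrt_configs` has a single SPMD entry, which mirrors the claim that either route leaves the environment correctly configured.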